From cbc96af526be3484683bd424bcebd9b8420026f5 Mon Sep 17 00:00:00 2001
From: Gabriel Erzse <gabriel.erzse@redis.com>
Date: Wed, 29 May 2024 19:02:56 +0300
Subject: [PATCH 1/3] Make it possible to return raw fields in search results

In some use cases (e.g. vector search), the fields in the search results
must not be decoded. Add an optional parameter to the search methods,
which makes it possible to disable decoding.

To preserve existing behavior, the new decode_fields parameter of the
search methods defaults to True, i.e. fields are decoded as before.
---
 dev_requirements.txt              |  1 +
 redis/commands/search/commands.py | 63 ++++++++++++++++++++++---------
 redis/commands/search/result.py   | 28 +++++++++-----
 tests/test_search.py              | 35 +++++++++++++++++
 4 files changed, 99 insertions(+), 28 deletions(-)

diff --git a/dev_requirements.txt b/dev_requirements.txt
index 3715599af0..f394314db2 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -4,6 +4,7 @@ flake8==5.0.4
 flake8-isort==6.0.0
 flynt~=0.69.0
 mock==4.0.3
+numpy>=1.24.4
 packaging>=20.4
 pytest==7.2.0
 pytest-timeout==2.1.0
diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py
index 2df2b5a754..b308c68f5f 100644
--- a/redis/commands/search/commands.py
+++ b/redis/commands/search/commands.py
@@ -80,6 +80,7 @@ def _parse_search(self, res, **kwargs):
             duration=kwargs["duration"],
             has_payload=kwargs["query"]._with_payloads,
             with_scores=kwargs["query"]._with_scores,
+            decode_fields=kwargs["decode_fields"],
         )
 
     def _parse_aggregate(self, res, **kwargs):
@@ -484,18 +485,27 @@ def search(
         self,
         query: Union[str, Query],
         query_params: Union[Dict[str, Union[str, int, float, bytes]], None] = None,
+        decode_fields: bool = True,
     ):
         """
-        Search the index for a given query, and return a result of documents
+        Search the index for a given query, and return a result of documents.
+
+        Args:
+            query: The search query. This can be a simple text string for basic queries,
+                   or a Query object for more complex queries. Refer to RediSearch's
+                   documentation for details on the query format.
+            query_params: Additional parameters for the query. These parameters are used
+                          to replace placeholders in the query string. This is useful
+                          for safely including user input in a search query.
+            decode_fields: If `True`, which is the default, decodes the fields in the
+                           search results. If `False`, fields are returned in their raw
+                           binary form.
 
-        ### Parameters
-
-        - **query**: the search query. Either a text for simple queries with
-                     default parameters, or a Query object for complex queries.
-                     See RediSearch's documentation on query format
+        Returns:
+            A result set of documents matching the query.
 
-        For more information see `FT.SEARCH <https://redis.io/commands/ft.search>`_.
-        """  # noqa
+        For more information see https://redis.io/commands/ft.search
+        """
         args, query = self._mk_query_args(query, query_params=query_params)
         st = time.time()
         res = self.execute_command(SEARCH_CMD, *args)
@@ -504,7 +514,11 @@ def search(
             return res
 
         return self._parse_results(
-            SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0
+            SEARCH_CMD,
+            res,
+            query=query,
+            duration=(time.time() - st) * 1000.0,
+            decode_fields=decode_fields,
         )
 
     def explain(
@@ -911,18 +925,27 @@ async def search(
         self,
         query: Union[str, Query],
         query_params: Dict[str, Union[str, int, float]] = None,
+        decode_fields: bool = True,
     ):
         """
-        Search the index for a given query, and return a result of documents
+        Search the index for a given query, and return a result of documents.
+
+        Args:
+            query: The search query. This can be a simple text string for basic queries,
+                   or a Query object for more complex queries. Refer to RediSearch's
+                   documentation for details on the query format.
+            query_params: Additional parameters for the query. These parameters are used
+                          to replace placeholders in the query string. This is useful
+                          for safely including user input in a search query.
+            decode_fields: If `True`, which is the default, decodes the fields in the
+                           search results. If `False`, fields are returned in their raw
+                           binary form.
 
-        ### Parameters
-
-        - **query**: the search query. Either a text for simple queries with
-                     default parameters, or a Query object for complex queries.
-                     See RediSearch's documentation on query format
+        Returns:
+            A result set of documents matching the query.
 
-        For more information see `FT.SEARCH <https://redis.io/commands/ft.search>`_.
-        """  # noqa
+        For more information see https://redis.io/commands/ft.search
+        """
         args, query = self._mk_query_args(query, query_params=query_params)
         st = time.time()
         res = await self.execute_command(SEARCH_CMD, *args)
@@ -931,7 +954,11 @@ async def search(
             return res
 
         return self._parse_results(
-            SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0
+            SEARCH_CMD,
+            res,
+            query=query,
+            duration=(time.time() - st) * 1000.0,
+            decode_fields=decode_fields,
         )
 
     async def aggregate(
diff --git a/redis/commands/search/result.py b/redis/commands/search/result.py
index 5b19e6faa4..36bb5802a9 100644
--- a/redis/commands/search/result.py
+++ b/redis/commands/search/result.py
@@ -9,7 +9,13 @@ class Result:
     """
 
     def __init__(
-        self, res, hascontent, duration=0, has_payload=False, with_scores=False
+        self,
+        res,
+        hascontent,
+        duration=0,
+        has_payload=False,
+        with_scores=False,
+        decode_fields=False,
     ):
         """
         - **snippets**: An optional dictionary of the form
@@ -32,24 +38,26 @@ def __init__(
 
         for i in range(1, len(res), step):
             id = to_string(res[i])
-            payload = to_string(res[i + offset]) if has_payload else None
+            if has_payload:
+                payload_data = res[i + offset]
+                payload = to_string(payload_data) if decode_fields else payload_data
+            else:
+                payload = None
             # fields_offset = 2 if has_payload else 1
             fields_offset = offset + 1 if has_payload else offset
             score = float(res[i + 1]) if with_scores else None
 
             fields = {}
             if hascontent and res[i + fields_offset] is not None:
-                fields = (
+                keys = res[i + fields_offset][::2]
+                values = res[i + fields_offset][1::2]
+                fields = dict(
                     dict(
-                        dict(
-                            zip(
-                                map(to_string, res[i + fields_offset][::2]),
-                                map(to_string, res[i + fields_offset][1::2]),
-                            )
+                        zip(
+                            map(to_string, keys),
+                            map(to_string, values) if decode_fields else values,
                         )
                     )
-                    if hascontent
-                    else {}
                 )
             try:
                 del fields["id"]
diff --git a/tests/test_search.py b/tests/test_search.py
index bfe204254c..694021afcf 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -4,6 +4,7 @@
 import time
 from io import TextIOWrapper
 
+import numpy as np
 import pytest
 import redis
 import redis.commands.search
@@ -2284,3 +2285,37 @@ def test_geoshape(client: redis.Redis):
     assert result.docs[0]["id"] == "small"
     result = client.ft().search(q2, query_params=qp2)
     assert len(result.docs) == 2
+
+
+@pytest.mark.redismod
+def test_vector_storage_and_retrieval(r: redis.Redis):
+    r.ft("vector_index").create_index(
+        (
+            VectorField(
+                "my_vector",
+                "FLAT",
+                {
+                    "TYPE": "FLOAT32",
+                    "DIM": 4,
+                    "DISTANCE_METRIC": "COSINE",
+                },
+            ),
+        ),
+        definition=IndexDefinition(prefix=["doc:"], index_type=IndexType.HASH),
+    )
+
+    vector_data = [0.1, 0.2, 0.3, 0.4]
+    r.hset(
+        f"doc:1",
+        mapping={"my_vector": np.array(vector_data, dtype=np.float32).tobytes()},
+    )
+
+    query = Query("*").with_payloads().return_fields("my_vector").dialect(2)
+    res = r.ft("vector_index").search(query, decode_fields=False)
+
+    assert res.total == 1
+    assert res.docs[0].id == f"doc:1"
+    retrieved_vector_data = np.frombuffer(
+        res.docs[0].__dict__["my_vector"], dtype=np.float32
+    )
+    assert np.allclose(retrieved_vector_data, vector_data)

From 76688034a2b6bb5f25db4b199ff8cc89526b309a Mon Sep 17 00:00:00 2001
From: Gabriel Erzse <gabriel.erzse@redis.com>
Date: Wed, 29 May 2024 19:12:34 +0300
Subject: [PATCH 2/3] Use a NumPy that supports Python 3.7

---
 dev_requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dev_requirements.txt b/dev_requirements.txt
index f394314db2..7daefb8ca6 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -4,7 +4,7 @@ flake8==5.0.4
 flake8-isort==6.0.0
 flynt~=0.69.0
 mock==4.0.3
-numpy>=1.24.4
+numpy>=1.21.0
 packaging>=20.4
 pytest==7.2.0
 pytest-timeout==2.1.0

From c64692bf3b39df75847b206d1966b2b85d69f8ea Mon Sep 17 00:00:00 2001
From: Gabriel Erzse <gabriel.erzse@redis.com>
Date: Wed, 29 May 2024 19:35:33 +0300
Subject: [PATCH 3/3] Fix linter errors

---
 tests/test_search.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_search.py b/tests/test_search.py
index 694021afcf..f74975a914 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -2306,7 +2306,7 @@ def test_vector_storage_and_retrieval(r: redis.Redis):
 
     vector_data = [0.1, 0.2, 0.3, 0.4]
     r.hset(
-        f"doc:1",
+        "doc:1",
         mapping={"my_vector": np.array(vector_data, dtype=np.float32).tobytes()},
     )
 
@@ -2314,7 +2314,7 @@ def test_vector_storage_and_retrieval(r: redis.Redis):
     res = r.ft("vector_index").search(query, decode_fields=False)
 
     assert res.total == 1
-    assert res.docs[0].id == f"doc:1"
+    assert res.docs[0].id == "doc:1"
     retrieved_vector_data = np.frombuffer(
         res.docs[0].__dict__["my_vector"], dtype=np.float32
     )