Skip to content

Commit 2717a0e

Browse files
uglide authored and vladvildanov committed
Decode search results at field level (#3309)
Make it possible to configure at field level how search results are decoded. Fixes: #2772, #2275
1 parent 826035f commit 2717a0e

File tree

8 files changed

+178
-28
lines changed

8 files changed

+178
-28
lines changed

dev_requirements.txt

+1
Original file line number | Diff line number | Diff line change
@@ -16,3 +16,4 @@ urllib3<2
1616
uvloop
1717
vulture>=2.3.0
1818
wheel>=0.30.0
19+
numpy>=1.24.0

docs/examples/search_vector_similarity_examples.ipynb

+3-2
Large diffs are not rendered by default.

redis/commands/search/_util.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
1-
def to_string(s):
1+
def to_string(s, encoding: str = "utf-8"):
22
if isinstance(s, str):
33
return s
44
elif isinstance(s, bytes):
5-
return s.decode("utf-8", "ignore")
5+
return s.decode(encoding, "ignore")
66
else:
77
return s # Not a string we care about

redis/commands/search/commands.py

+14-3
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import time
33
from typing import Dict, List, Optional, Union
44

5-
from redis.client import Pipeline
5+
from redis.client import NEVER_DECODE, Pipeline
66
from redis.utils import deprecated_function
77

88
from ..helpers import get_protocol_version, parse_to_dict
@@ -82,6 +82,7 @@ def _parse_search(self, res, **kwargs):
8282
duration=kwargs["duration"],
8383
has_payload=kwargs["query"]._with_payloads,
8484
with_scores=kwargs["query"]._with_scores,
85+
field_encodings=kwargs["query"]._return_fields_decode_as,
8586
)
8687

8788
def _parse_aggregate(self, res, **kwargs):
@@ -499,7 +500,12 @@ def search(
499500
""" # noqa
500501
args, query = self._mk_query_args(query, query_params=query_params)
501502
st = time.time()
502-
res = self.execute_command(SEARCH_CMD, *args)
503+
504+
options = {}
505+
if get_protocol_version(self.client) not in ["3", 3]:
506+
options[NEVER_DECODE] = True
507+
508+
res = self.execute_command(SEARCH_CMD, *args, **options)
503509

504510
if isinstance(res, Pipeline):
505511
return res
@@ -926,7 +932,12 @@ async def search(
926932
""" # noqa
927933
args, query = self._mk_query_args(query, query_params=query_params)
928934
st = time.time()
929-
res = await self.execute_command(SEARCH_CMD, *args)
935+
936+
options = {}
937+
if get_protocol_version(self.client) not in ["3", 3]:
938+
options[NEVER_DECODE] = True
939+
940+
res = await self.execute_command(SEARCH_CMD, *args, **options)
930941

931942
if isinstance(res, Pipeline):
932943
return res

redis/commands/search/query.py

+19-4
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@ def __init__(self, query_string: str) -> None:
3535
self._in_order: bool = False
3636
self._sortby: Optional[SortbyField] = None
3737
self._return_fields: List = []
38+
self._return_fields_decode_as: dict = {}
3839
self._summarize_fields: List = []
3940
self._highlight_fields: List = []
4041
self._language: Optional[str] = None
@@ -53,13 +54,27 @@ def limit_ids(self, *ids) -> "Query":
5354

5455
def return_fields(self, *fields) -> "Query":
5556
"""Add fields to return fields."""
56-
self._return_fields += fields
57+
for field in fields:
58+
self.return_field(field)
5759
return self
5860

59-
def return_field(self, field: str, as_field: Optional[str] = None) -> "Query":
60-
"""Add field to return fields (Optional: add 'AS' name
61-
to the field)."""
61+
def return_field(
62+
self,
63+
field: str,
64+
as_field: Optional[str] = None,
65+
decode_field: Optional[bool] = True,
66+
encoding: Optional[str] = "utf8",
67+
) -> "Query":
68+
"""
69+
Add a field to the list of fields to return.
70+
71+
- **field**: The field to include in query results
72+
- **as_field**: The alias for the field
73+
- **decode_field**: Whether to decode the field from bytes to string
74+
- **encoding**: The encoding to use when decoding the field
75+
"""
6276
self._return_fields.append(field)
77+
self._return_fields_decode_as[field] = encoding if decode_field else None
6378
if as_field is not None:
6479
self._return_fields += ("AS", as_field)
6580
return self

redis/commands/search/result.py

+29-15
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
from ._util import to_string
24
from .document import Document
35

@@ -9,11 +11,19 @@ class Result:
911
"""
1012

1113
def __init__(
12-
self, res, hascontent, duration=0, has_payload=False, with_scores=False
14+
self,
15+
res,
16+
hascontent,
17+
duration=0,
18+
has_payload=False,
19+
with_scores=False,
20+
field_encodings: Optional[dict] = None,
1321
):
1422
"""
15-
- **snippets**: An optional dictionary of the form
16-
{field: snippet_size} for snippet formatting
23+
- duration: the execution time of the query
24+
- has_payload: whether the query has payloads
25+
- with_scores: whether the query has scores
26+
- field_encodings: a dictionary of field encodings if any is provided
1727
"""
1828

1929
self.total = res[0]
@@ -39,18 +49,22 @@ def __init__(
3949

4050
fields = {}
4151
if hascontent and res[i + fields_offset] is not None:
42-
fields = (
43-
dict(
44-
dict(
45-
zip(
46-
map(to_string, res[i + fields_offset][::2]),
47-
map(to_string, res[i + fields_offset][1::2]),
48-
)
49-
)
50-
)
51-
if hascontent
52-
else {}
53-
)
52+
keys = map(to_string, res[i + fields_offset][::2])
53+
values = res[i + fields_offset][1::2]
54+
55+
for key, value in zip(keys, values):
56+
if field_encodings is None or key not in field_encodings:
57+
fields[key] = to_string(value)
58+
continue
59+
60+
encoding = field_encodings[key]
61+
62+
# If the encoding is None, we don't need to decode the value
63+
if encoding is None:
64+
fields[key] = value
65+
else:
66+
fields[key] = to_string(value, encoding=encoding)
67+
5468
try:
5569
del fields["id"]
5670
except KeyError:

tests/test_asyncio/test_search.py

+60-2
Original file line number | Diff line number | Diff line change
@@ -4,22 +4,30 @@
44
import time
55
from io import TextIOWrapper
66

7+
import numpy as np
78
import pytest
89
import pytest_asyncio
910
import redis.asyncio as redis
1011
import redis.commands.search
1112
import redis.commands.search.aggregation as aggregations
1213
import redis.commands.search.reducers as reducers
1314
from redis.commands.search import AsyncSearch
14-
from redis.commands.search.field import GeoField, NumericField, TagField, TextField
15-
from redis.commands.search.indexDefinition import IndexDefinition
15+
from redis.commands.search.field import (
16+
GeoField,
17+
NumericField,
18+
TagField,
19+
TextField,
20+
VectorField,
21+
)
22+
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
1623
from redis.commands.search.query import GeoFilter, NumericFilter, Query
1724
from redis.commands.search.result import Result
1825
from redis.commands.search.suggestion import Suggestion
1926
from tests.conftest import (
2027
assert_resp_response,
2128
is_resp2_connection,
2229
skip_if_redis_enterprise,
30+
skip_if_resp_version,
2331
skip_ifmodversion_lt,
2432
)
2533

@@ -1560,3 +1568,53 @@ async def test_query_timeout(decoded_r: redis.Redis):
15601568
q2 = Query("foo").timeout("not_a_number")
15611569
with pytest.raises(redis.ResponseError):
15621570
await decoded_r.ft().search(q2)
1571+
1572+
1573+
@pytest.mark.redismod
1574+
@skip_if_resp_version(3)
1575+
async def test_binary_and_text_fields(decoded_r: redis.Redis):
1576+
fake_vec = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)
1577+
1578+
index_name = "mixed_index"
1579+
mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()}
1580+
await decoded_r.hset(f"{index_name}:1", mapping=mixed_data)
1581+
1582+
schema = (
1583+
TagField("first_name"),
1584+
VectorField(
1585+
"embeddings_bio",
1586+
algorithm="HNSW",
1587+
attributes={
1588+
"TYPE": "FLOAT32",
1589+
"DIM": 4,
1590+
"DISTANCE_METRIC": "COSINE",
1591+
},
1592+
),
1593+
)
1594+
1595+
await decoded_r.ft(index_name).create_index(
1596+
fields=schema,
1597+
definition=IndexDefinition(
1598+
prefix=[f"{index_name}:"], index_type=IndexType.HASH
1599+
),
1600+
)
1601+
1602+
query = (
1603+
Query("*")
1604+
.return_field("vector_emb", decode_field=False)
1605+
.return_field("first_name")
1606+
)
1607+
result = await decoded_r.ft(index_name).search(query=query, query_params={})
1608+
docs = result.docs
1609+
1610+
decoded_vec_from_search_results = np.frombuffer(
1611+
docs[0]["vector_emb"], dtype=np.float32
1612+
)
1613+
1614+
assert np.array_equal(
1615+
decoded_vec_from_search_results, fake_vec
1616+
), "The vectors are not equal"
1617+
1618+
assert (
1619+
docs[0]["first_name"] == mixed_data["first_name"]
1620+
), "The text field is not decoded correctly"

tests/test_search.py

+50
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import time
55
from io import TextIOWrapper
66

7+
import numpy as np
78
import pytest
89
import redis
910
import redis.commands.search
@@ -29,6 +30,7 @@
2930
assert_resp_response,
3031
is_resp2_connection,
3132
skip_if_redis_enterprise,
33+
skip_if_resp_version,
3234
skip_ifmodversion_lt,
3335
)
3436

@@ -1705,6 +1707,54 @@ def test_search_return_fields(client):
17051707
assert "telmatosaurus" == total["results"][0]["extra_attributes"]["txt"]
17061708

17071709

1710+
@pytest.mark.redismod
1711+
@skip_if_resp_version(3)
1712+
def test_binary_and_text_fields(client):
1713+
fake_vec = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)
1714+
1715+
index_name = "mixed_index"
1716+
mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()}
1717+
client.hset(f"{index_name}:1", mapping=mixed_data)
1718+
1719+
schema = (
1720+
TagField("first_name"),
1721+
VectorField(
1722+
"embeddings_bio",
1723+
algorithm="HNSW",
1724+
attributes={
1725+
"TYPE": "FLOAT32",
1726+
"DIM": 4,
1727+
"DISTANCE_METRIC": "COSINE",
1728+
},
1729+
),
1730+
)
1731+
1732+
client.ft(index_name).create_index(
1733+
fields=schema,
1734+
definition=IndexDefinition(
1735+
prefix=[f"{index_name}:"], index_type=IndexType.HASH
1736+
),
1737+
)
1738+
1739+
query = (
1740+
Query("*")
1741+
.return_field("vector_emb", decode_field=False)
1742+
.return_field("first_name")
1743+
)
1744+
docs = client.ft(index_name).search(query=query, query_params={}).docs
1745+
decoded_vec_from_search_results = np.frombuffer(
1746+
docs[0]["vector_emb"], dtype=np.float32
1747+
)
1748+
1749+
assert np.array_equal(
1750+
decoded_vec_from_search_results, fake_vec
1751+
), "The vectors are not equal"
1752+
1753+
assert (
1754+
docs[0]["first_name"] == mixed_data["first_name"]
1755+
), "The text field is not decoded correctly"
1756+
1757+
17081758
@pytest.mark.redismod
17091759
def test_synupdate(client):
17101760
definition = IndexDefinition(index_type=IndexType.HASH)

0 commit comments

Comments
 (0)