Skip to content

Commit 7ca2f29

Browse files
committed Jul 9, 2024
Decode search results at field level
Fixes: #2772, #2275
1 parent 0be67bf commit 7ca2f29

File tree

7 files changed

+118
-23
lines changed

7 files changed

+118
-23
lines changed
 

‎dev_requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ urllib3<2
1616
uvloop
1717
vulture>=2.3.0
1818
wheel>=0.30.0
19+
numpy>=1.24.0

‎docs/examples/search_vector_similarity_examples.ipynb

+3-2
Large diffs are not rendered by default.

‎redis/commands/search/_util.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
def to_string(s):
1+
def to_string(s, encoding: str = "utf-8"):
22
if isinstance(s, str):
33
return s
44
elif isinstance(s, bytes):
5-
return s.decode("utf-8", "ignore")
5+
return s.decode(encoding, "ignore")
66
else:
77
return s # Not a string we care about

‎redis/commands/search/commands.py

+1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def _parse_search(self, res, **kwargs):
8282
duration=kwargs["duration"],
8383
has_payload=kwargs["query"]._with_payloads,
8484
with_scores=kwargs["query"]._with_scores,
85+
field_encodings=kwargs["query"]._return_fields_decode_as,
8586
)
8687

8788
def _parse_aggregate(self, res, **kwargs):

‎redis/commands/search/query.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(self, query_string: str) -> None:
3535
self._in_order: bool = False
3636
self._sortby: Optional[SortbyField] = None
3737
self._return_fields: List = []
38+
self._return_fields_decode_as: dict = {}
3839
self._summarize_fields: List = []
3940
self._highlight_fields: List = []
4041
self._language: Optional[str] = None
@@ -53,13 +54,27 @@ def limit_ids(self, *ids) -> "Query":
5354

5455
def return_fields(self, *fields) -> "Query":
5556
"""Add fields to return fields."""
56-
self._return_fields += fields
57+
for field in fields:
58+
self.return_field(field)
5759
return self
5860

59-
def return_field(self, field: str, as_field: Optional[str] = None) -> "Query":
60-
"""Add field to return fields (Optional: add 'AS' name
61-
to the field)."""
61+
def return_field(
62+
self,
63+
field: str,
64+
as_field: Optional[str] = None,
65+
decode_field: Optional[bool] = True,
66+
encoding: Optional[str] = "utf8",
67+
) -> "Query":
68+
"""
69+
Add a field to the list of fields to return.
70+
71+
- **field**: The field to include in query results
72+
- **as_field**: The alias for the field
73+
- **decode_field**: Whether to decode the field from bytes to string
74+
- **encoding**: The encoding to use when decoding the field
75+
"""
6276
self._return_fields.append(field)
77+
self._return_fields_decode_as[field] = encoding if decode_field else None
6378
if as_field is not None:
6479
self._return_fields += ("AS", as_field)
6580
return self

‎redis/commands/search/result.py

+29-15
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
from ._util import to_string
24
from .document import Document
35

@@ -9,11 +11,19 @@ class Result:
911
"""
1012

1113
def __init__(
12-
self, res, hascontent, duration=0, has_payload=False, with_scores=False
14+
self,
15+
res,
16+
hascontent,
17+
duration=0,
18+
has_payload=False,
19+
with_scores=False,
20+
field_encodings: Optional[dict] = None,
1321
):
1422
"""
15-
- **snippets**: An optional dictionary of the form
16-
{field: snippet_size} for snippet formatting
23+
- duration: the execution time of the query
24+
- has_payload: whether the query has payloads
25+
- with_scores: whether the query has scores
26+
- field_encodings: a dictionary of field encodings if any is provided
1727
"""
1828

1929
self.total = res[0]
@@ -39,18 +49,22 @@ def __init__(
3949

4050
fields = {}
4151
if hascontent and res[i + fields_offset] is not None:
42-
fields = (
43-
dict(
44-
dict(
45-
zip(
46-
map(to_string, res[i + fields_offset][::2]),
47-
map(to_string, res[i + fields_offset][1::2]),
48-
)
49-
)
50-
)
51-
if hascontent
52-
else {}
53-
)
52+
keys = map(to_string, res[i + fields_offset][::2])
53+
values = res[i + fields_offset][1::2]
54+
55+
for key, value in zip(keys, values):
56+
if field_encodings is None or key not in field_encodings:
57+
fields[key] = to_string(value)
58+
continue
59+
60+
encoding = field_encodings[key]
61+
62+
# If the encoding is None, we don't need to decode the value
63+
if encoding is None:
64+
fields[key] = value
65+
else:
66+
fields[key] = to_string(value, encoding=encoding)
67+
5468
try:
5569
del fields["id"]
5670
except KeyError:

‎tests/test_search.py

+63
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import time
55
from io import TextIOWrapper
66

7+
import numpy as np
78
import pytest
89
import redis
910
import redis.commands.search
@@ -113,6 +114,13 @@ def client(request, stack_url):
113114
return r
114115

115116

117+
@pytest.fixture
118+
def binary_client(request, stack_url):
119+
r = _get_client(redis.Redis, request, decode_responses=False, from_url=stack_url)
120+
r.flushdb()
121+
return r
122+
123+
116124
@pytest.mark.redismod
117125
def test_client(client):
118126
num_docs = 500
@@ -1705,6 +1713,61 @@ def test_search_return_fields(client):
17051713
assert "telmatosaurus" == total["results"][0]["extra_attributes"]["txt"]
17061714

17071715

1716+
@pytest.mark.redismod
1717+
def test_binary_and_text_fields(binary_client):
1718+
assert (
1719+
binary_client.get_connection_kwargs()["decode_responses"] is False
1720+
), "This feature is only available when decode_responses is False"
1721+
1722+
fake_vec = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)
1723+
1724+
index_name = "mixed_index"
1725+
mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()}
1726+
binary_client.hset(f"{index_name}:1", mapping=mixed_data)
1727+
1728+
schema = (
1729+
TagField("first_name"),
1730+
VectorField(
1731+
"embeddings_bio",
1732+
algorithm="HNSW",
1733+
attributes={
1734+
"TYPE": "FLOAT32",
1735+
"DIM": 4,
1736+
"DISTANCE_METRIC": "COSINE",
1737+
},
1738+
),
1739+
)
1740+
1741+
binary_client.ft(index_name).create_index(
1742+
fields=schema,
1743+
definition=IndexDefinition(
1744+
prefix=[f"{index_name}:"], index_type=IndexType.HASH
1745+
),
1746+
)
1747+
1748+
bytes_person_1 = binary_client.hget(f"{index_name}:1", "vector_emb")
1749+
decoded_vec_from_hash = np.frombuffer(bytes_person_1, dtype=np.float32)
1750+
assert np.array_equal(decoded_vec_from_hash, fake_vec), "The vectors are not equal"
1751+
1752+
query = (
1753+
Query("*")
1754+
.return_field("vector_emb", decode_field=False)
1755+
.return_field("first_name", decode_field=True)
1756+
)
1757+
docs = binary_client.ft(index_name).search(query=query, query_params={}).docs
1758+
decoded_vec_from_search_results = np.frombuffer(
1759+
docs[0]["vector_emb"], dtype=np.float32
1760+
)
1761+
1762+
assert np.array_equal(
1763+
decoded_vec_from_search_results, fake_vec
1764+
), "The vectors are not equal"
1765+
1766+
assert (
1767+
docs[0]["first_name"] == mixed_data["first_name"]
1768+
), "The first is not decoded correctly"
1769+
1770+
17081771
@pytest.mark.redismod
17091772
def test_synupdate(client):
17101773
definition = IndexDefinition(index_type=IndexType.HASH)

0 commit comments

Comments
 (0)