|
4 | 4 | import time
|
5 | 5 | from io import TextIOWrapper
|
6 | 6 |
|
| 7 | +import numpy as np |
7 | 8 | import pytest
|
8 | 9 | import pytest_asyncio
|
9 | 10 | import redis.asyncio as redis
|
10 | 11 | import redis.commands.search
|
11 | 12 | import redis.commands.search.aggregation as aggregations
|
12 | 13 | import redis.commands.search.reducers as reducers
|
13 | 14 | from redis.commands.search import AsyncSearch
|
14 |
| -from redis.commands.search.field import GeoField, NumericField, TagField, TextField |
15 |
| -from redis.commands.search.indexDefinition import IndexDefinition |
| 15 | +from redis.commands.search.field import ( |
| 16 | + GeoField, |
| 17 | + NumericField, |
| 18 | + TagField, |
| 19 | + TextField, |
| 20 | + VectorField, |
| 21 | +) |
| 22 | +from redis.commands.search.indexDefinition import IndexDefinition, IndexType |
16 | 23 | from redis.commands.search.query import GeoFilter, NumericFilter, Query
|
17 | 24 | from redis.commands.search.result import Result
|
18 | 25 | from redis.commands.search.suggestion import Suggestion
|
19 | 26 | from tests.conftest import (
|
20 | 27 | assert_resp_response,
|
21 | 28 | is_resp2_connection,
|
22 | 29 | skip_if_redis_enterprise,
|
| 30 | + skip_if_resp_version, |
23 | 31 | skip_ifmodversion_lt,
|
24 | 32 | )
|
25 | 33 |
|
@@ -1560,3 +1568,53 @@ async def test_query_timeout(decoded_r: redis.Redis):
|
1560 | 1568 | q2 = Query("foo").timeout("not_a_number")
|
1561 | 1569 | with pytest.raises(redis.ResponseError):
|
1562 | 1570 | await decoded_r.ft().search(q2)
|
| 1571 | + |
| 1572 | + |
| 1573 | +@pytest.mark.redismod |
| 1574 | +@skip_if_resp_version(3) |
| 1575 | +async def test_binary_and_text_fields(decoded_r: redis.Redis): |
| 1576 | + fake_vec = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32) |
| 1577 | + |
| 1578 | + index_name = "mixed_index" |
| 1579 | + mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()} |
| 1580 | + await decoded_r.hset(f"{index_name}:1", mapping=mixed_data) |
| 1581 | + |
| 1582 | + schema = ( |
| 1583 | + TagField("first_name"), |
| 1584 | + VectorField( |
| 1585 | + "embeddings_bio", |
| 1586 | + algorithm="HNSW", |
| 1587 | + attributes={ |
| 1588 | + "TYPE": "FLOAT32", |
| 1589 | + "DIM": 4, |
| 1590 | + "DISTANCE_METRIC": "COSINE", |
| 1591 | + }, |
| 1592 | + ), |
| 1593 | + ) |
| 1594 | + |
| 1595 | + await decoded_r.ft(index_name).create_index( |
| 1596 | + fields=schema, |
| 1597 | + definition=IndexDefinition( |
| 1598 | + prefix=[f"{index_name}:"], index_type=IndexType.HASH |
| 1599 | + ), |
| 1600 | + ) |
| 1601 | + |
| 1602 | + query = ( |
| 1603 | + Query("*") |
| 1604 | + .return_field("vector_emb", decode_field=False) |
| 1605 | + .return_field("first_name") |
| 1606 | + ) |
| 1607 | + result = await decoded_r.ft(index_name).search(query=query, query_params={}) |
| 1608 | + docs = result.docs |
| 1609 | + |
| 1610 | + decoded_vec_from_search_results = np.frombuffer( |
| 1611 | + docs[0]["vector_emb"], dtype=np.float32 |
| 1612 | + ) |
| 1613 | + |
| 1614 | + assert np.array_equal( |
| 1615 | + decoded_vec_from_search_results, fake_vec |
| 1616 | + ), "The vectors are not equal" |
| 1617 | + |
| 1618 | + assert ( |
| 1619 | + docs[0]["first_name"] == mixed_data["first_name"] |
| 1620 | + ), "The text field is not decoded correctly" |
0 commit comments