Skip to content

Commit fd959fd

Browse files
committed
change to allow metadata fields to be pk
Signed-off-by: Theo van Kraay <theo.van@microsoft.com>
1 parent 971a58c commit fd959fd

File tree

2 files changed

+294
-7
lines changed

2 files changed

+294
-7
lines changed

vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java

+70-7
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.ArrayList;
2020
import java.util.Collections;
2121
import java.util.HashMap;
22+
import java.util.Iterator;
2223
import java.util.List;
2324
import java.util.Map;
2425
import java.util.Optional;
@@ -32,6 +33,7 @@
3233
import com.azure.cosmos.models.CosmosBulkOperations;
3334
import com.azure.cosmos.models.CosmosContainerProperties;
3435
import com.azure.cosmos.models.CosmosItemOperation;
36+
import com.azure.cosmos.models.CosmosItemResponse;
3537
import com.azure.cosmos.models.CosmosQueryRequestOptions;
3638
import com.azure.cosmos.models.CosmosVectorDataType;
3739
import com.azure.cosmos.models.CosmosVectorDistanceFunction;
@@ -40,6 +42,7 @@
4042
import com.azure.cosmos.models.CosmosVectorIndexSpec;
4143
import com.azure.cosmos.models.CosmosVectorIndexType;
4244
import com.azure.cosmos.models.ExcludedPath;
45+
import com.azure.cosmos.models.FeedResponse;
4346
import com.azure.cosmos.models.IncludedPath;
4447
import com.azure.cosmos.models.IndexingMode;
4548
import com.azure.cosmos.models.IndexingPolicy;
@@ -235,10 +238,30 @@ public void doAdd(List<Document> documents) {
235238
// Create a list to hold both the CosmosItemOperation and the corresponding
236239
// document ID
237240
List<ImmutablePair<String, CosmosItemOperation>> itemOperationsWithIds = documents.stream().map(doc -> {
241+
String partitionKeyValue;
242+
243+
if ("/id".equals(this.partitionKeyPath)) {
244+
partitionKeyValue = doc.getId();
245+
}
246+
else if (this.partitionKeyPath.startsWith("/metadata/")) {
247+
// Extract the key, e.g. "/metadata/country" -> "country"
248+
String metadataKey = this.partitionKeyPath.substring("/metadata/".length());
249+
Object value = doc.getMetadata() != null ? doc.getMetadata().get(metadataKey) : null;
250+
if (value == null) {
251+
throw new IllegalArgumentException(
252+
"Partition key '" + metadataKey + "' not found in document metadata.");
253+
}
254+
partitionKeyValue = value.toString();
255+
}
256+
else {
257+
throw new IllegalArgumentException("Unsupported partition key path: " + this.partitionKeyPath);
258+
}
259+
238260
CosmosItemOperation operation = CosmosBulkOperations.getCreateItemOperation(
239-
mapCosmosDocument(doc, embeddings.get(documents.indexOf(doc))), new PartitionKey(doc.getId()));
240-
return new ImmutablePair<>(doc.getId(), operation); // Pair the document ID
261+
mapCosmosDocument(doc, embeddings.get(documents.indexOf(doc))),
262+
new PartitionKey(partitionKeyValue)); // Pair the document ID
241263
// with the operation
264+
return new ImmutablePair<>(doc.getId(), operation);
242265
}).toList();
243266

244267
try {
@@ -285,20 +308,60 @@ public void doAdd(List<Document> documents) {
285308
public void doDelete(List<String> idList) {
286309
try {
287310
// Convert the list of IDs into bulk delete operations
288-
List<CosmosItemOperation> itemOperations = idList.stream()
289-
.map(id -> CosmosBulkOperations.getDeleteItemOperation(id, new PartitionKey(id)))
290-
.collect(Collectors.toList());
311+
List<CosmosItemOperation> itemOperations = idList.stream().map(id -> {
312+
String partitionKeyValue;
313+
314+
if ("/id".equals(this.partitionKeyPath)) {
315+
partitionKeyValue = id;
316+
}
317+
318+
else if (this.partitionKeyPath.startsWith("/metadata/")) {
319+
// Will be inefficient for large numbers of documents but there is no
320+
// other way to get the partition key value
321+
// with current method signature. Ideally, we should be able to pass
322+
// the partition key value directly.
323+
String metadataKey = this.partitionKeyPath.substring("/metadata/".length());
324+
325+
// Run a reactive query to fetch the document by ID
326+
String query = String.format("SELECT * FROM c WHERE c.id = '%s'", id);
327+
CosmosPagedFlux<JsonNode> queryFlux = this.container.queryItems(query,
328+
new CosmosQueryRequestOptions(), JsonNode.class);
329+
330+
// Block to retrieve the first page synchronously
331+
List<JsonNode> documents = queryFlux.byPage(1).blockFirst().getResults();
332+
333+
if (documents == null || documents.isEmpty()) {
334+
throw new IllegalArgumentException("No document found for id: " + id);
335+
}
336+
337+
JsonNode document = documents.get(0);
338+
JsonNode metadataNode = document.get("metadata");
339+
340+
if (metadataNode == null || metadataNode.get(metadataKey) == null) {
341+
throw new IllegalArgumentException("Partition key '" + metadataKey
342+
+ "' not found in metadata for document with id: " + id);
343+
}
344+
345+
partitionKeyValue = metadataNode.get(metadataKey).asText();
346+
}
347+
else {
348+
throw new IllegalArgumentException("Unsupported partition key path: " + this.partitionKeyPath);
349+
}
350+
351+
return CosmosBulkOperations.getDeleteItemOperation(id, new PartitionKey(partitionKeyValue));
352+
}).collect(Collectors.toList());
291353

292354
// Execute bulk delete operations synchronously by using blockLast() on the
293355
// Flux
294356
this.container.executeBulkOperations(Flux.fromIterable(itemOperations))
295357
.doOnNext(response -> logger.info("Document deleted with status: {}",
296358
response.getResponse().getStatusCode()))
297359
.doOnError(error -> logger.error("Error deleting document: {}", error.getMessage()))
298-
.blockLast(); // This will block until all operations have finished
360+
.blockLast();
299361
}
300362
catch (Exception e) {
301-
logger.error("Exception while deleting documents: {}", e.getMessage());
363+
logger.error("Exception while deleting documents: {}", e.getMessage(), e);
364+
throw e;
302365
}
303366
}
304367

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.vectorstore.cosmosdb;
18+
19+
import com.azure.cosmos.CosmosAsyncClient;
20+
import com.azure.cosmos.CosmosAsyncContainer;
21+
import com.azure.cosmos.CosmosClientBuilder;
22+
import com.azure.identity.DefaultAzureCredentialBuilder;
23+
import org.junit.jupiter.api.BeforeEach;
24+
import org.junit.jupiter.api.Test;
25+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
26+
import org.springframework.ai.document.Document;
27+
import org.springframework.ai.embedding.EmbeddingModel;
28+
import org.springframework.ai.transformers.TransformersEmbeddingModel;
29+
import org.springframework.ai.vectorstore.SearchRequest;
30+
import org.springframework.ai.vectorstore.VectorStore;
31+
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
32+
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
33+
import org.springframework.boot.SpringBootConfiguration;
34+
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
35+
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
36+
import org.springframework.context.annotation.Bean;
37+
38+
import java.util.HashMap;
39+
import java.util.List;
40+
import java.util.Map;
41+
import java.util.Optional;
42+
import java.util.UUID;
43+
44+
import static org.assertj.core.api.Assertions.assertThat;
45+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
46+
47+
/**
48+
* @author Theo van Kraay
49+
* @author Thomas Vitale
50+
* @since 1.0.0
51+
*/
52+
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_ENDPOINT", matches = ".+")
53+
public class CosmosDBVectorStoreWithMetadataPartitionKeyIT {
54+
55+
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
56+
.withUserConfiguration(TestApplication.class);
57+
58+
private VectorStore vectorStore;
59+
60+
@BeforeEach
61+
public void setup() {
62+
this.contextRunner.run(context -> this.vectorStore = context.getBean(VectorStore.class));
63+
}
64+
65+
@Test
66+
public void testAddSearchAndDeleteDocuments() {
67+
68+
// Create a sample document
69+
Document document1 = new Document(UUID.randomUUID().toString(), "Sample content1", Map.of("key1", "value1"));
70+
assertThatThrownBy(() -> this.vectorStore.add(List.of(document1))).isInstanceOf(Exception.class)
71+
.hasMessageContaining("Partition key 'country' not found in document metadata.");
72+
73+
Document document2 = new Document(UUID.randomUUID().toString(), "Sample content1", Map.of("country", "UK"));
74+
this.vectorStore.add(List.of(document2));
75+
76+
// Perform a similarity search
77+
List<Document> results = this.vectorStore
78+
.similaritySearch(SearchRequest.builder().query("Sample content1").topK(1).build());
79+
80+
// Verify the search results
81+
assertThat(results).isNotEmpty();
82+
assertThat(results.get(0).getId()).isEqualTo(document2.getId());
83+
84+
// Remove the documents from the vector store
85+
this.vectorStore.delete(List.of(document2.getId()));
86+
87+
// Perform a similarity search again
88+
List<Document> results2 = this.vectorStore
89+
.similaritySearch(SearchRequest.builder().query("Sample content").topK(1).build());
90+
91+
// Verify the search results
92+
assertThat(results2).isEmpty();
93+
94+
}
95+
96+
@Test
97+
void testSimilaritySearchWithFilter() {
98+
99+
// Insert documents using vectorStore.add
100+
Map<String, Object> metadata1;
101+
metadata1 = new HashMap<>();
102+
metadata1.put("country", "UK");
103+
metadata1.put("year", 2021);
104+
metadata1.put("city", "London");
105+
106+
Map<String, Object> metadata2;
107+
metadata2 = new HashMap<>();
108+
metadata2.put("country", "NL");
109+
metadata2.put("year", 2022);
110+
metadata2.put("city", "Amsterdam");
111+
112+
Map<String, Object> metadata3;
113+
metadata3 = new HashMap<>();
114+
metadata3.put("country", "US");
115+
metadata3.put("year", 2019);
116+
metadata3.put("city", "Sofia");
117+
118+
Map<String, Object> metadata4;
119+
metadata4 = new HashMap<>();
120+
metadata4.put("country", "US");
121+
metadata4.put("year", 2020);
122+
metadata4.put("city", "Sofia");
123+
124+
Document document1 = new Document("1", "A document about the UK", metadata1);
125+
Document document2 = new Document("2", "A document about the Netherlands", metadata2);
126+
Document document3 = new Document("3", "A document about the US", metadata3);
127+
Document document4 = new Document("4", "A document about the US", metadata4);
128+
129+
this.vectorStore.add(List.of(document1, document2, document3, document4));
130+
FilterExpressionBuilder b = new FilterExpressionBuilder();
131+
List<Document> results = this.vectorStore.similaritySearch(SearchRequest.builder()
132+
.query("The World")
133+
.topK(10)
134+
.filterExpression((b.in("country", "UK", "NL")).build())
135+
.build());
136+
137+
assertThat(results).hasSize(2);
138+
assertThat(results).extracting(Document::getId).containsExactlyInAnyOrder("1", "2");
139+
for (Document doc : results) {
140+
assertThat(doc.getMetadata().get("country")).isIn("UK", "NL");
141+
assertThat(doc.getMetadata().get("year")).isIn(2021, 2022);
142+
assertThat(doc.getMetadata().get("city")).isIn("London", "Amsterdam").isNotEqualTo("Sofia");
143+
}
144+
145+
List<Document> results2 = this.vectorStore.similaritySearch(SearchRequest.builder()
146+
.query("The World")
147+
.topK(10)
148+
.filterExpression(
149+
b.and(b.or(b.gte("year", 2021), b.eq("country", "NL")), b.ne("city", "Amsterdam")).build())
150+
.build());
151+
152+
assertThat(results2).hasSize(1);
153+
assertThat(results2).extracting(Document::getId).containsExactlyInAnyOrder("1");
154+
155+
List<Document> results3 = this.vectorStore.similaritySearch(SearchRequest.builder()
156+
.query("The World")
157+
.topK(10)
158+
.filterExpression(b.and(b.eq("country", "US"), b.eq("year", 2020)).build())
159+
.build());
160+
161+
assertThat(results3).hasSize(1);
162+
assertThat(results3).extracting(Document::getId).containsExactlyInAnyOrder("4");
163+
164+
this.vectorStore.delete(List.of(document1.getId(), document2.getId(), document3.getId(), document4.getId()));
165+
166+
// Perform a similarity search again
167+
List<Document> results4 = this.vectorStore
168+
.similaritySearch(SearchRequest.builder().query("The World").topK(1).build());
169+
170+
// Verify the search results
171+
assertThat(results4).isEmpty();
172+
}
173+
174+
@Test
175+
void getNativeClientTest() {
176+
this.contextRunner.run(context -> {
177+
CosmosDBVectorStore vectorStore = context.getBean(CosmosDBVectorStore.class);
178+
Optional<CosmosAsyncContainer> nativeClient = vectorStore.getNativeClient();
179+
assertThat(nativeClient).isPresent();
180+
});
181+
}
182+
183+
@SpringBootConfiguration
184+
@EnableAutoConfiguration
185+
public static class TestApplication {
186+
187+
@Bean
188+
public VectorStore vectorStore(CosmosAsyncClient cosmosClient, EmbeddingModel embeddingModel,
189+
VectorStoreObservationConvention convention) {
190+
return CosmosDBVectorStore.builder(cosmosClient, embeddingModel)
191+
.databaseName("test-database")
192+
.containerName("test-container-metadata-partition-key")
193+
.metadataFields(List.of("country", "year", "city"))
194+
.partitionKeyPath("/metadata/country")
195+
.vectorStoreThroughput(1000)
196+
.customObservationConvention(convention)
197+
.build();
198+
}
199+
200+
@Bean
201+
public CosmosAsyncClient cosmosClient() {
202+
return new CosmosClientBuilder().endpoint(System.getenv("AZURE_COSMOSDB_ENDPOINT"))
203+
.credential(new DefaultAzureCredentialBuilder().build())
204+
.userAgentSuffix("SpringAI-CDBNoSQL-VectorStore")
205+
.gatewayMode()
206+
.buildAsyncClient();
207+
}
208+
209+
@Bean
210+
public EmbeddingModel embeddingModel() {
211+
return new TransformersEmbeddingModel();
212+
}
213+
214+
@Bean
215+
public VectorStoreObservationConvention observationConvention() {
216+
// Replace with an actual observation convention or a mock if needed
217+
return new VectorStoreObservationConvention() {
218+
219+
};
220+
}
221+
222+
}
223+
224+
}

0 commit comments

Comments
 (0)