diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/pom.xml b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/pom.xml index 61980a98f46..067b8ccb507 100644 --- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/pom.xml +++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/pom.xml @@ -45,6 +45,11 @@ ${project.parent.version} true + + com.azure + azure-identity + 1.15.4 + org.springframework.boot spring-boot-starter diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfiguration.java b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfiguration.java index 5fdc1686b33..bbf3500802b 100644 --- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfiguration.java +++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfiguration.java @@ -18,6 +18,7 @@ import com.azure.cosmos.CosmosAsyncClient; import com.azure.cosmos.CosmosClientBuilder; +import com.azure.identity.DefaultAzureCredentialBuilder; import io.micrometer.observation.ObservationRegistry; import org.springframework.ai.embedding.BatchingStrategy; @@ -33,6 +34,7 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; + import java.util.List; /** @@ -42,6 +44,7 @@ * @author Soby Chacko * @since 1.0.0 */ + @AutoConfiguration @ConditionalOnClass({ CosmosDBVectorStore.class, EmbeddingModel.class, CosmosAsyncClient.class }) @EnableConfigurationProperties(CosmosDBVectorStoreProperties.class) @@ -49,16 +52,29 @@ matchIfMissing = true) public class CosmosDBVectorStoreAutoConfiguration { - String endpoint; - - String key; + private final String agentSuffix = "SpringAI-CDBNoSQL-VectorStore"; @Bean public CosmosAsyncClient cosmosClient(CosmosDBVectorStoreProperties properties) { - return new CosmosClientBuilder().endpoint(properties.getEndpoint()) - .userAgentSuffix("SpringAI-CDBNoSQL-VectorStore") - .key(properties.getKey()) - .gatewayMode() + String mode = properties.getConnectionMode(); + if (mode == null) { + properties.setConnectionMode("gateway"); + } + else if (!mode.equals("direct") && !mode.equals("gateway")) { + throw new IllegalArgumentException("Connection mode must be either 'direct' or 'gateway'"); + } + + CosmosClientBuilder builder = new CosmosClientBuilder().endpoint(properties.getEndpoint()) + .userAgentSuffix(agentSuffix); + + if (properties.getKey() == null || properties.getKey().isEmpty()) { + builder.credential(new DefaultAzureCredentialBuilder().build()); + } + else { + builder.key(properties.getKey()); + } + + return ("direct".equals(properties.getConnectionMode()) ? builder.directMode() : builder.gatewayMode()) .buildAsyncClient(); } @@ -78,12 +94,11 @@ public CosmosDBVectorStore cosmosDBVectorStore(ObservationRegistry observationRe return CosmosDBVectorStore.builder(cosmosAsyncClient, embeddingModel) .databaseName(properties.getDatabaseName()) .containerName(properties.getContainerName()) - .metadataFields(List.of(properties.getMetadataFields())) + .metadataFields(properties.getMetadataFieldList()) .vectorStoreThroughput(properties.getVectorStoreThroughput()) .vectorDimensions(properties.getVectorDimensions()) .partitionKeyPath(properties.getPartitionKeyPath()) .build(); - } } diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreProperties.java b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreProperties.java index 644df9322a1..ce84aa6eb23 100644 --- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreProperties.java +++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreProperties.java @@ -19,13 +19,15 @@ import org.springframework.ai.vectorstore.properties.CommonVectorStoreProperties; import org.springframework.boot.context.properties.ConfigurationProperties; +import java.util.Arrays; +import java.util.List; + /** * Configuration properties for CosmosDB Vector Store. * * @author Theo van Kraay * @since 1.0.0 */ - @ConfigurationProperties(CosmosDBVectorStoreProperties.CONFIG_PREFIX) public class CosmosDBVectorStoreProperties extends CommonVectorStoreProperties { @@ -47,6 +49,8 @@ public class CosmosDBVectorStoreProperties extends CommonVectorStoreProperties { private String key; + private String connectionMode; + public int getVectorStoreThroughput() { return this.vectorStoreThroughput; } @@ -63,6 +67,12 @@ public void setMetadataFields(String metadataFields) { this.metadataFields = metadataFields; } + public List getMetadataFieldList() { + return this.metadataFields != null + ? Arrays.stream(this.metadataFields.split(",")).map(String::trim).filter(s -> !s.isEmpty()).toList() + : List.of(); + } + public String getEndpoint() { return this.endpoint; } @@ -79,6 +89,14 @@ public void setKey(String key) { this.key = key; } + public void setConnectionMode(String connectionMode) { + this.connectionMode = connectionMode; + } + + public String getConnectionMode() { + return this.connectionMode; + } + public String getDatabaseName() { return this.databaseName; } diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/test/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfigurationIT.java b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/test/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfigurationIT.java index 69a0175f2f8..27e866f8f7a 100644 --- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/test/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfigurationIT.java +++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/test/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfigurationIT.java @@ -44,22 +44,35 @@ * @author Theo van Kraay * @since 1.0.0 */ - @EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_ENDPOINT", matches = ".+") @EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_KEY", matches = ".+") public class CosmosDBVectorStoreAutoConfigurationIT { - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(CosmosDBVectorStoreAutoConfiguration.class)) - .withPropertyValues("spring.ai.vectorstore.cosmosdb.databaseName=test-database") - .withPropertyValues("spring.ai.vectorstore.cosmosdb.containerName=test-container") - .withPropertyValues("spring.ai.vectorstore.cosmosdb.partitionKeyPath=/id") - .withPropertyValues("spring.ai.vectorstore.cosmosdb.metadataFields=country,year,city") - .withPropertyValues("spring.ai.vectorstore.cosmosdb.vectorStoreThroughput=1000") - .withPropertyValues("spring.ai.vectorstore.cosmosdb.vectorDimensions=384") - .withPropertyValues("spring.ai.vectorstore.cosmosdb.endpoint=" + System.getenv("AZURE_COSMOSDB_ENDPOINT")) - .withPropertyValues("spring.ai.vectorstore.cosmosdb.key=" + System.getenv("AZURE_COSMOSDB_KEY")) - .withUserConfiguration(Config.class); + private final ApplicationContextRunner contextRunner; + + public CosmosDBVectorStoreAutoConfigurationIT() { + String endpoint = System.getenv("AZURE_COSMOSDB_ENDPOINT"); + String key = System.getenv("AZURE_COSMOSDB_KEY"); + + ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(CosmosDBVectorStoreAutoConfiguration.class)) + .withPropertyValues("spring.ai.vectorstore.cosmosdb.databaseName=test-database") + .withPropertyValues("spring.ai.vectorstore.cosmosdb.containerName=test-container") + .withPropertyValues("spring.ai.vectorstore.cosmosdb.partitionKeyPath=/id") + .withPropertyValues("spring.ai.vectorstore.cosmosdb.metadataFields=country,year,city") + .withPropertyValues("spring.ai.vectorstore.cosmosdb.vectorStoreThroughput=1000") + .withPropertyValues("spring.ai.vectorstore.cosmosdb.vectorDimensions=384"); + + if (endpoint != null && !"null".equalsIgnoreCase(endpoint)) { + contextRunner = contextRunner.withPropertyValues("spring.ai.vectorstore.cosmosdb.endpoint=" + endpoint); + } + + if (key != null && !"null".equalsIgnoreCase(key)) { + contextRunner = contextRunner.withPropertyValues("spring.ai.vectorstore.cosmosdb.key=" + key); + } + + this.contextRunner = contextRunner.withUserConfiguration(Config.class); + } private VectorStore vectorStore; @@ -124,14 +137,15 @@ void testSimilaritySearchWithFilter() { metadata4.put("country", "US"); metadata4.put("year", 2020); metadata4.put("city", "Sofia"); - Document document1 = new Document("1", "A document about the UK", metadata1); Document document2 = new Document("2", "A document about the Netherlands", metadata2); Document document3 = new Document("3", "A document about the US", metadata3); Document document4 = new Document("4", "A document about the US", metadata4); this.vectorStore.add(List.of(document1, document2, document3, document4)); + FilterExpressionBuilder b = new FilterExpressionBuilder(); + List results = this.vectorStore.similaritySearch(SearchRequest.builder() .query("The World") .topK(10) @@ -190,7 +204,7 @@ public void autoConfigurationEnabledByDefault() { @Test public void autoConfigurationEnabledWhenTypeIsAzureCosmosDB() { - this.contextRunner.withPropertyValues("spring.ai.vectorstore.type=azure-cosmmos-db").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vectorstore.type=azure-cosmos-db").run(context -> { assertThat(context.getBeansOfType(CosmosDBVectorStoreProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VectorStore.class)).isNotEmpty(); assertThat(context.getBean(VectorStore.class)).isInstanceOf(CosmosDBVectorStore.class); diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/azure-cosmos-db.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/azure-cosmos-db.adoc index 9676fa46ded..dcefc8be71e 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/azure-cosmos-db.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/azure-cosmos-db.adoc @@ -112,7 +112,7 @@ The following configuration properties are available for the Cosmos DB vector st | spring.ai.vectorstore.cosmosdb.vectorStoreThroughput | The throughput for the vector store. | spring.ai.vectorstore.cosmosdb.vectorDimensions | The number of dimensions for the vectors. | spring.ai.vectorstore.cosmosdb.endpoint | The endpoint for the Cosmos DB. -| spring.ai.vectorstore.cosmosdb.key | The key for the Cosmos DB. +| spring.ai.vectorstore.cosmosdb.key | The key for the Cosmos DB (if key is not present, [DefaultAzureCredential](https://learn.microsoft.com/azure/developer/java/sdk/authentication/credential-chains#defaultazurecredential-overview) will be used). |=== @@ -146,7 +146,7 @@ List results = vectorStore.similaritySearch(SearchRequest.builder().qu == Setting up Azure Cosmos DB Vector Store without Auto Configuration -The following code demonstrates how to set up the `CosmosDBVectorStore` without relying on auto-configuration: +The following code demonstrates how to set up the `CosmosDBVectorStore` without relying on auto-configuration. [DefaultAzureCredential](https://learn.microsoft.com/azure/developer/java/sdk/authentication/credential-chains#defaultazurecredential-overview) is recommended for authentication to Azure Cosmos DB. [source,java] ---- @@ -155,7 +155,7 @@ public VectorStore vectorStore(ObservationRegistry observationRegistry) { // Create the Cosmos DB client CosmosAsyncClient cosmosClient = new CosmosClientBuilder() .endpoint(System.getenv("COSMOSDB_AI_ENDPOINT")) - .key(System.getenv("COSMOSDB_AI_KEY")) + .credential(new DefaultAzureCredentialBuilder().build()) .userAgentSuffix("SpringAI-CDBNoSQL-VectorStore") .gatewayMode() .buildAsyncClient(); diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml b/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml index ca93a9ece8b..441855a6527 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml +++ b/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml @@ -47,6 +47,11 @@ azure-spring-data-cosmos ${azure-cosmos.version} + + com.azure + azure-identity + 1.15.4 + org.springframework.ai spring-ai-core diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java index 3a2f3edc3e6..6d476d869b6 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java @@ -18,7 +18,10 @@ import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -30,6 +33,7 @@ import com.azure.cosmos.models.CosmosBulkOperations; import com.azure.cosmos.models.CosmosContainerProperties; import com.azure.cosmos.models.CosmosItemOperation; +import com.azure.cosmos.models.CosmosItemResponse; import com.azure.cosmos.models.CosmosQueryRequestOptions; import com.azure.cosmos.models.CosmosVectorDataType; import com.azure.cosmos.models.CosmosVectorDistanceFunction; @@ -38,6 +42,7 @@ import com.azure.cosmos.models.CosmosVectorIndexSpec; import com.azure.cosmos.models.CosmosVectorIndexType; import com.azure.cosmos.models.ExcludedPath; +import com.azure.cosmos.models.FeedResponse; import com.azure.cosmos.models.IncludedPath; import com.azure.cosmos.models.IndexingMode; import com.azure.cosmos.models.IndexingPolicy; @@ -48,6 +53,7 @@ import com.azure.cosmos.models.SqlQuerySpec; import com.azure.cosmos.models.ThroughputProperties; import com.azure.cosmos.util.CosmosPagedFlux; + import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; @@ -117,7 +123,16 @@ protected CosmosDBVectorStore(Builder builder) { this.vectorDimensions = builder.vectorDimensions; this.metadataFieldsList = builder.metadataFieldsList; - this.cosmosClient.createDatabaseIfNotExists(this.databaseName).block(); + try { + this.cosmosClient.createDatabaseIfNotExists(this.databaseName).block(); + } + catch (Exception e) { + // likely failed due to RBAC, so database is assumed to be already created + // (and + // if not, it will fail later) + logger.error("Error creating database: {}", e.getMessage()); + } + initializeContainer(this.containerName, this.databaseName, this.vectorStoreThroughput, this.vectorDimensions, this.partitionKeyPath); } @@ -223,10 +238,30 @@ public void doAdd(List documents) { // Create a list to hold both the CosmosItemOperation and the corresponding // document ID List> itemOperationsWithIds = documents.stream().map(doc -> { + String partitionKeyValue; + + if ("/id".equals(this.partitionKeyPath)) { + partitionKeyValue = doc.getId(); + } + else if (this.partitionKeyPath.startsWith("/metadata/")) { + // Extract the key, e.g. "/metadata/country" -> "country" + String metadataKey = this.partitionKeyPath.substring("/metadata/".length()); + Object value = doc.getMetadata() != null ? doc.getMetadata().get(metadataKey) : null; + if (value == null) { + throw new IllegalArgumentException( + "Partition key '" + metadataKey + "' not found in document metadata."); + } + partitionKeyValue = value.toString(); + } + else { + throw new IllegalArgumentException("Unsupported partition key path: " + this.partitionKeyPath); + } + CosmosItemOperation operation = CosmosBulkOperations.getCreateItemOperation( - mapCosmosDocument(doc, embeddings.get(documents.indexOf(doc))), new PartitionKey(doc.getId())); - return new ImmutablePair<>(doc.getId(), operation); // Pair the document ID + mapCosmosDocument(doc, embeddings.get(documents.indexOf(doc))), + new PartitionKey(partitionKeyValue)); // Pair the document ID // with the operation + return new ImmutablePair<>(doc.getId(), operation); }).toList(); try { @@ -273,9 +308,48 @@ public void doAdd(List documents) { public void doDelete(List idList) { try { // Convert the list of IDs into bulk delete operations - List itemOperations = idList.stream() - .map(id -> CosmosBulkOperations.getDeleteItemOperation(id, new PartitionKey(id))) - .collect(Collectors.toList()); + List itemOperations = idList.stream().map(id -> { + String partitionKeyValue; + + if ("/id".equals(this.partitionKeyPath)) { + partitionKeyValue = id; + } + + else if (this.partitionKeyPath.startsWith("/metadata/")) { + // Will be inefficient for large numbers of documents but there is no + // other way to get the partition key value + // with current method signature. Ideally, we should be able to pass + // the partition key value directly. + String metadataKey = this.partitionKeyPath.substring("/metadata/".length()); + + // Run a reactive query to fetch the document by ID + String query = String.format("SELECT * FROM c WHERE c.id = '%s'", id); + CosmosPagedFlux queryFlux = this.container.queryItems(query, + new CosmosQueryRequestOptions(), JsonNode.class); + + // Block to retrieve the first page synchronously + List documents = queryFlux.byPage(1).blockFirst().getResults(); + + if (documents == null || documents.isEmpty()) { + throw new IllegalArgumentException("No document found for id: " + id); + } + + JsonNode document = documents.get(0); + JsonNode metadataNode = document.get("metadata"); + + if (metadataNode == null || metadataNode.get(metadataKey) == null) { + throw new IllegalArgumentException("Partition key '" + metadataKey + + "' not found in metadata for document with id: " + id); + } + + partitionKeyValue = metadataNode.get(metadataKey).asText(); + } + else { + throw new IllegalArgumentException("Unsupported partition key path: " + this.partitionKeyPath); + } + + return CosmosBulkOperations.getDeleteItemOperation(id, new PartitionKey(partitionKeyValue)); + }).collect(Collectors.toList()); // Execute bulk delete operations synchronously by using blockLast() on the // Flux @@ -283,10 +357,11 @@ public void doDelete(List idList) { .doOnNext(response -> logger.info("Document deleted with status: {}", response.getResponse().getStatusCode())) .doOnError(error -> logger.error("Error deleting document: {}", error.getMessage())) - .blockLast(); // This will block until all operations have finished + .blockLast(); } catch (Exception e) { - logger.error("Exception while deleting documents: {}", e.getMessage()); + logger.error("Exception while deleting documents: {}", e.getMessage(), e); + throw e; } } @@ -348,9 +423,26 @@ public List doSimilaritySearch(SearchRequest request) { .flatMap(page -> Flux.fromIterable(page.getResults())) .collectList() .block(); + + // Collect metadata fields from the documents + Map docFields = new HashMap<>(); + for (var doc : documents) { + JsonNode metadata = doc.get("metadata"); + metadata.fieldNames().forEachRemaining(field -> { + JsonNode value = metadata.get(field); + Object parsedValue = value.isTextual() ? value.asText() : value.isNumber() ? value.numberValue() + : value.isBoolean() ? value.booleanValue() : value.toString(); + docFields.put(field, parsedValue); + }); + } + // Convert JsonNode to Document List docs = documents.stream() - .map(doc -> Document.builder().id(doc.get("id").asText()).text(doc.get("content").asText()).build()) + .map(doc -> Document.builder() + .id(doc.get("id").asText()) + .text(doc.get("content").asText()) + .metadata(docFields) + .build()) .collect(Collectors.toList()); return docs != null ? docs : List.of(); @@ -475,7 +567,7 @@ public Builder vectorDimensions(long vectorDimensions) { * @return the builder instance */ public Builder metadataFields(List metadataFieldsList) { - this.metadataFieldsList = metadataFieldsList != null ? new ArrayList<>(this.metadataFieldsList) + this.metadataFieldsList = metadataFieldsList != null ? new ArrayList<>(metadataFieldsList) : new ArrayList<>(); return this; } diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStoreIT.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStoreIT.java index 39729c6806e..6a46dd342db 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStoreIT.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStoreIT.java @@ -25,6 +25,7 @@ import com.azure.cosmos.CosmosAsyncClient; import com.azure.cosmos.CosmosAsyncContainer; import com.azure.cosmos.CosmosClientBuilder; +import com.azure.identity.DefaultAzureCredentialBuilder; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; @@ -50,7 +51,6 @@ * @since 1.0.0 */ @EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_ENDPOINT", matches = ".+") -@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_KEY", matches = ".+") public class CosmosDBVectorStoreIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() @@ -141,6 +141,11 @@ void testSimilaritySearchWithFilter() { assertThat(results).hasSize(2); assertThat(results).extracting(Document::getId).containsExactlyInAnyOrder("1", "2"); + for (Document doc : results) { + assertThat(doc.getMetadata().get("country")).isIn("UK", "NL"); + assertThat(doc.getMetadata().get("year")).isIn(2021, 2022); + assertThat(doc.getMetadata().get("city")).isIn("London", "Amsterdam").isNotEqualTo("Sofia"); + } List results2 = this.vectorStore.similaritySearch(SearchRequest.builder() .query("The World") @@ -191,6 +196,7 @@ public VectorStore vectorStore(CosmosAsyncClient cosmosClient, EmbeddingModel em .databaseName("test-database") .containerName("test-container") .metadataFields(List.of("country", "year", "city")) + .partitionKeyPath("/id") .vectorStoreThroughput(1000) .customObservationConvention(convention) .build(); @@ -199,7 +205,7 @@ public VectorStore vectorStore(CosmosAsyncClient cosmosClient, EmbeddingModel em @Bean public CosmosAsyncClient cosmosClient() { return new CosmosClientBuilder().endpoint(System.getenv("AZURE_COSMOSDB_ENDPOINT")) - .key(System.getenv("AZURE_COSMOSDB_KEY")) + .credential(new DefaultAzureCredentialBuilder().build()) .userAgentSuffix("SpringAI-CDBNoSQL-VectorStore") .gatewayMode() .buildAsyncClient(); diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStoreWithMetadataPartitionKeyIT.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStoreWithMetadataPartitionKeyIT.java new file mode 100644 index 00000000000..5ec456871c7 --- /dev/null +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStoreWithMetadataPartitionKeyIT.java @@ -0,0 +1,224 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore.cosmosdb; + +import com.azure.cosmos.CosmosAsyncClient; +import com.azure.cosmos.CosmosAsyncContainer; +import com.azure.cosmos.CosmosClientBuilder; +import com.azure.identity.DefaultAzureCredentialBuilder; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.transformers.TransformersEmbeddingModel; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * @author Theo van Kraay + * @author Thomas Vitale + * @since 1.0.0 + */ +@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_ENDPOINT", matches = ".+") +public class CosmosDBVectorStoreWithMetadataPartitionKeyIT { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + + private VectorStore vectorStore; + + @BeforeEach + public void setup() { + this.contextRunner.run(context -> this.vectorStore = context.getBean(VectorStore.class)); + } + + @Test + public void testAddSearchAndDeleteDocuments() { + + // Create a sample document + Document document1 = new Document(UUID.randomUUID().toString(), "Sample content1", Map.of("key1", "value1")); + assertThatThrownBy(() -> this.vectorStore.add(List.of(document1))).isInstanceOf(Exception.class) + .hasMessageContaining("Partition key 'country' not found in document metadata."); + + Document document2 = new Document(UUID.randomUUID().toString(), "Sample content1", Map.of("country", "UK")); + this.vectorStore.add(List.of(document2)); + + // Perform a similarity search + List results = this.vectorStore + .similaritySearch(SearchRequest.builder().query("Sample content1").topK(1).build()); + + // Verify the search results + assertThat(results).isNotEmpty(); + assertThat(results.get(0).getId()).isEqualTo(document2.getId()); + + // Remove the documents from the vector store + this.vectorStore.delete(List.of(document2.getId())); + + // Perform a similarity search again + List results2 = this.vectorStore + .similaritySearch(SearchRequest.builder().query("Sample content").topK(1).build()); + + // Verify the search results + assertThat(results2).isEmpty(); + + } + + @Test + void testSimilaritySearchWithFilter() { + + // Insert documents using vectorStore.add + Map metadata1; + metadata1 = new HashMap<>(); + metadata1.put("country", "UK"); + metadata1.put("year", 2021); + metadata1.put("city", "London"); + + Map metadata2; + metadata2 = new HashMap<>(); + metadata2.put("country", "NL"); + metadata2.put("year", 2022); + metadata2.put("city", "Amsterdam"); + + Map metadata3; + metadata3 = new HashMap<>(); + metadata3.put("country", "US"); + metadata3.put("year", 2019); + metadata3.put("city", "Sofia"); + + Map metadata4; + metadata4 = new HashMap<>(); + metadata4.put("country", "US"); + metadata4.put("year", 2020); + metadata4.put("city", "Sofia"); + + Document document1 = new Document("1", "A document about the UK", metadata1); + Document document2 = new Document("2", "A document about the Netherlands", metadata2); + Document document3 = new Document("3", "A document about the US", metadata3); + Document document4 = new Document("4", "A document about the US", metadata4); + + this.vectorStore.add(List.of(document1, document2, document3, document4)); + FilterExpressionBuilder b = new FilterExpressionBuilder(); + List results = this.vectorStore.similaritySearch(SearchRequest.builder() + .query("The World") + .topK(10) + .filterExpression((b.in("country", "UK", "NL")).build()) + .build()); + + assertThat(results).hasSize(2); + assertThat(results).extracting(Document::getId).containsExactlyInAnyOrder("1", "2"); + for (Document doc : results) { + assertThat(doc.getMetadata().get("country")).isIn("UK", "NL"); + assertThat(doc.getMetadata().get("year")).isIn(2021, 2022); + assertThat(doc.getMetadata().get("city")).isIn("London", "Amsterdam").isNotEqualTo("Sofia"); + } + + List results2 = this.vectorStore.similaritySearch(SearchRequest.builder() + .query("The World") + .topK(10) + .filterExpression( + b.and(b.or(b.gte("year", 2021), b.eq("country", "NL")), b.ne("city", "Amsterdam")).build()) + .build()); + + assertThat(results2).hasSize(1); + assertThat(results2).extracting(Document::getId).containsExactlyInAnyOrder("1"); + + List results3 = this.vectorStore.similaritySearch(SearchRequest.builder() + .query("The World") + .topK(10) + .filterExpression(b.and(b.eq("country", "US"), b.eq("year", 2020)).build()) + .build()); + + assertThat(results3).hasSize(1); + assertThat(results3).extracting(Document::getId).containsExactlyInAnyOrder("4"); + + this.vectorStore.delete(List.of(document1.getId(), document2.getId(), document3.getId(), document4.getId())); + + // Perform a similarity search again + List results4 = this.vectorStore + .similaritySearch(SearchRequest.builder().query("The World").topK(1).build()); + + // Verify the search results + assertThat(results4).isEmpty(); + } + + @Test + void getNativeClientTest() { + this.contextRunner.run(context -> { + CosmosDBVectorStore vectorStore = context.getBean(CosmosDBVectorStore.class); + Optional nativeClient = vectorStore.getNativeClient(); + assertThat(nativeClient).isPresent(); + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration + public static class TestApplication { + + @Bean + public VectorStore vectorStore(CosmosAsyncClient cosmosClient, EmbeddingModel embeddingModel, + VectorStoreObservationConvention convention) { + return CosmosDBVectorStore.builder(cosmosClient, embeddingModel) + .databaseName("test-database") + .containerName("test-container-metadata-partition-key") + .metadataFields(List.of("country", "year", "city")) + .partitionKeyPath("/metadata/country") + .vectorStoreThroughput(1000) + .customObservationConvention(convention) + .build(); + } + + @Bean + public CosmosAsyncClient cosmosClient() { + return new CosmosClientBuilder().endpoint(System.getenv("AZURE_COSMOSDB_ENDPOINT")) + .credential(new DefaultAzureCredentialBuilder().build()) + .userAgentSuffix("SpringAI-CDBNoSQL-VectorStore") + .gatewayMode() + .buildAsyncClient(); + } + + @Bean + public EmbeddingModel embeddingModel() { + return new TransformersEmbeddingModel(); + } + + @Bean + public VectorStoreObservationConvention observationConvention() { + // Replace with an actual observation convention or a mock if needed + return new VectorStoreObservationConvention() { + + }; + } + + } + +}