diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/pom.xml b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/pom.xml
index 61980a98f46..067b8ccb507 100644
--- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/pom.xml
+++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/pom.xml
@@ -45,6 +45,11 @@
${project.parent.version}
true
+
+ com.azure
+ azure-identity
+ 1.15.4
+
org.springframework.boot
spring-boot-starter
diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfiguration.java b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfiguration.java
index 5fdc1686b33..bbf3500802b 100644
--- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfiguration.java
+++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfiguration.java
@@ -18,6 +18,7 @@
import com.azure.cosmos.CosmosAsyncClient;
import com.azure.cosmos.CosmosClientBuilder;
+import com.azure.identity.DefaultAzureCredentialBuilder;
import io.micrometer.observation.ObservationRegistry;
import org.springframework.ai.embedding.BatchingStrategy;
@@ -33,6 +34,7 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
+
import java.util.List;
/**
@@ -42,6 +44,7 @@
* @author Soby Chacko
* @since 1.0.0
*/
+
@AutoConfiguration
@ConditionalOnClass({ CosmosDBVectorStore.class, EmbeddingModel.class, CosmosAsyncClient.class })
@EnableConfigurationProperties(CosmosDBVectorStoreProperties.class)
@@ -49,16 +52,29 @@
matchIfMissing = true)
public class CosmosDBVectorStoreAutoConfiguration {
- String endpoint;
-
- String key;
+ private final String agentSuffix = "SpringAI-CDBNoSQL-VectorStore";
@Bean
public CosmosAsyncClient cosmosClient(CosmosDBVectorStoreProperties properties) {
- return new CosmosClientBuilder().endpoint(properties.getEndpoint())
- .userAgentSuffix("SpringAI-CDBNoSQL-VectorStore")
- .key(properties.getKey())
- .gatewayMode()
+ String mode = properties.getConnectionMode();
+ if (mode == null) {
+ properties.setConnectionMode("gateway");
+ }
+ else if (!mode.equals("direct") && !mode.equals("gateway")) {
+ throw new IllegalArgumentException("Connection mode must be either 'direct' or 'gateway'");
+ }
+
+ CosmosClientBuilder builder = new CosmosClientBuilder().endpoint(properties.getEndpoint())
+ .userAgentSuffix(agentSuffix);
+
+ if (properties.getKey() == null || properties.getKey().isEmpty()) {
+ builder.credential(new DefaultAzureCredentialBuilder().build());
+ }
+ else {
+ builder.key(properties.getKey());
+ }
+
+ return ("direct".equals(properties.getConnectionMode()) ? builder.directMode() : builder.gatewayMode())
.buildAsyncClient();
}
@@ -78,12 +94,11 @@ public CosmosDBVectorStore cosmosDBVectorStore(ObservationRegistry observationRe
return CosmosDBVectorStore.builder(cosmosAsyncClient, embeddingModel)
.databaseName(properties.getDatabaseName())
.containerName(properties.getContainerName())
- .metadataFields(List.of(properties.getMetadataFields()))
+ .metadataFields(properties.getMetadataFieldList())
.vectorStoreThroughput(properties.getVectorStoreThroughput())
.vectorDimensions(properties.getVectorDimensions())
.partitionKeyPath(properties.getPartitionKeyPath())
.build();
-
}
}
diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreProperties.java b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreProperties.java
index 644df9322a1..ce84aa6eb23 100644
--- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreProperties.java
+++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreProperties.java
@@ -19,13 +19,15 @@
import org.springframework.ai.vectorstore.properties.CommonVectorStoreProperties;
import org.springframework.boot.context.properties.ConfigurationProperties;
+import java.util.Arrays;
+import java.util.List;
+
/**
* Configuration properties for CosmosDB Vector Store.
*
* @author Theo van Kraay
* @since 1.0.0
*/
-
@ConfigurationProperties(CosmosDBVectorStoreProperties.CONFIG_PREFIX)
public class CosmosDBVectorStoreProperties extends CommonVectorStoreProperties {
@@ -47,6 +49,8 @@ public class CosmosDBVectorStoreProperties extends CommonVectorStoreProperties {
private String key;
+ private String connectionMode;
+
public int getVectorStoreThroughput() {
return this.vectorStoreThroughput;
}
@@ -63,6 +67,12 @@ public void setMetadataFields(String metadataFields) {
this.metadataFields = metadataFields;
}
+ /**
+  * Returns the comma-separated {@code metadataFields} value as a list.
+  * <p>
+  * Each entry is trimmed and blank entries are dropped.
+  * @return the individual field names, or an empty list when no metadata fields are
+  * configured; never {@code null}
+  */
+ public List<String> getMetadataFieldList() {
+     if (this.metadataFields == null) {
+         return List.of();
+     }
+     return Arrays.stream(this.metadataFields.split(","))
+         .map(String::trim)
+         .filter(field -> !field.isEmpty())
+         .toList();
+ }
+
public String getEndpoint() {
return this.endpoint;
}
@@ -79,6 +89,14 @@ public void setKey(String key) {
this.key = key;
}
+ /**
+  * Stores the Cosmos DB connection mode.
+  * @param connectionMode the mode to store; this setter performs no validation
+  */
+ public void setConnectionMode(String connectionMode) {
+ this.connectionMode = connectionMode;
+ }
+
+ /**
+  * Returns the stored Cosmos DB connection mode.
+  * @return the connection mode, or {@code null} if none has been set
+  */
+ public String getConnectionMode() {
+ return this.connectionMode;
+ }
+
public String getDatabaseName() {
return this.databaseName;
}
diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/test/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfigurationIT.java b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/test/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfigurationIT.java
index 69a0175f2f8..27e866f8f7a 100644
--- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/test/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfigurationIT.java
+++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/test/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfigurationIT.java
@@ -44,22 +44,35 @@
* @author Theo van Kraay
* @since 1.0.0
*/
-
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_ENDPOINT", matches = ".+")
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_KEY", matches = ".+")
public class CosmosDBVectorStoreAutoConfigurationIT {
- private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
- .withConfiguration(AutoConfigurations.of(CosmosDBVectorStoreAutoConfiguration.class))
- .withPropertyValues("spring.ai.vectorstore.cosmosdb.databaseName=test-database")
- .withPropertyValues("spring.ai.vectorstore.cosmosdb.containerName=test-container")
- .withPropertyValues("spring.ai.vectorstore.cosmosdb.partitionKeyPath=/id")
- .withPropertyValues("spring.ai.vectorstore.cosmosdb.metadataFields=country,year,city")
- .withPropertyValues("spring.ai.vectorstore.cosmosdb.vectorStoreThroughput=1000")
- .withPropertyValues("spring.ai.vectorstore.cosmosdb.vectorDimensions=384")
- .withPropertyValues("spring.ai.vectorstore.cosmosdb.endpoint=" + System.getenv("AZURE_COSMOSDB_ENDPOINT"))
- .withPropertyValues("spring.ai.vectorstore.cosmosdb.key=" + System.getenv("AZURE_COSMOSDB_KEY"))
- .withUserConfiguration(Config.class);
+ private final ApplicationContextRunner contextRunner;
+
+ /**
+  * Builds the shared context runner. The endpoint and key properties are only set
+  * when the corresponding environment variable is present and is not the literal
+  * text "null".
+  */
+ public CosmosDBVectorStoreAutoConfigurationIT() {
+     ApplicationContextRunner runner = new ApplicationContextRunner()
+         .withConfiguration(AutoConfigurations.of(CosmosDBVectorStoreAutoConfiguration.class))
+         .withPropertyValues("spring.ai.vectorstore.cosmosdb.databaseName=test-database",
+                 "spring.ai.vectorstore.cosmosdb.containerName=test-container",
+                 "spring.ai.vectorstore.cosmosdb.partitionKeyPath=/id",
+                 "spring.ai.vectorstore.cosmosdb.metadataFields=country,year,city",
+                 "spring.ai.vectorstore.cosmosdb.vectorStoreThroughput=1000",
+                 "spring.ai.vectorstore.cosmosdb.vectorDimensions=384");
+
+     String endpoint = System.getenv("AZURE_COSMOSDB_ENDPOINT");
+     boolean endpointMissing = endpoint == null || "null".equalsIgnoreCase(endpoint);
+     if (!endpointMissing) {
+         runner = runner.withPropertyValues("spring.ai.vectorstore.cosmosdb.endpoint=" + endpoint);
+     }
+
+     String key = System.getenv("AZURE_COSMOSDB_KEY");
+     boolean keyMissing = key == null || "null".equalsIgnoreCase(key);
+     if (!keyMissing) {
+         runner = runner.withPropertyValues("spring.ai.vectorstore.cosmosdb.key=" + key);
+     }
+
+     this.contextRunner = runner.withUserConfiguration(Config.class);
+ }
private VectorStore vectorStore;
@@ -124,14 +137,15 @@ void testSimilaritySearchWithFilter() {
metadata4.put("country", "US");
metadata4.put("year", 2020);
metadata4.put("city", "Sofia");
-
Document document1 = new Document("1", "A document about the UK", metadata1);
Document document2 = new Document("2", "A document about the Netherlands", metadata2);
Document document3 = new Document("3", "A document about the US", metadata3);
Document document4 = new Document("4", "A document about the US", metadata4);
this.vectorStore.add(List.of(document1, document2, document3, document4));
+
FilterExpressionBuilder b = new FilterExpressionBuilder();
+
List results = this.vectorStore.similaritySearch(SearchRequest.builder()
.query("The World")
.topK(10)
@@ -190,7 +204,7 @@ public void autoConfigurationEnabledByDefault() {
@Test
public void autoConfigurationEnabledWhenTypeIsAzureCosmosDB() {
- this.contextRunner.withPropertyValues("spring.ai.vectorstore.type=azure-cosmmos-db").run(context -> {
+ this.contextRunner.withPropertyValues("spring.ai.vectorstore.type=azure-cosmos-db").run(context -> {
assertThat(context.getBeansOfType(CosmosDBVectorStoreProperties.class)).isNotEmpty();
assertThat(context.getBeansOfType(VectorStore.class)).isNotEmpty();
assertThat(context.getBean(VectorStore.class)).isInstanceOf(CosmosDBVectorStore.class);
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/azure-cosmos-db.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/azure-cosmos-db.adoc
index 9676fa46ded..dcefc8be71e 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/azure-cosmos-db.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/azure-cosmos-db.adoc
@@ -112,7 +112,7 @@ The following configuration properties are available for the Cosmos DB vector st
| spring.ai.vectorstore.cosmosdb.vectorStoreThroughput | The throughput for the vector store.
| spring.ai.vectorstore.cosmosdb.vectorDimensions | The number of dimensions for the vectors.
| spring.ai.vectorstore.cosmosdb.endpoint | The endpoint for the Cosmos DB.
-| spring.ai.vectorstore.cosmosdb.key | The key for the Cosmos DB.
+| spring.ai.vectorstore.cosmosdb.key | The key for the Cosmos DB (if no key is present, https://learn.microsoft.com/azure/developer/java/sdk/authentication/credential-chains#defaultazurecredential-overview[DefaultAzureCredential] will be used).
|===
@@ -146,7 +146,7 @@ List results = vectorStore.similaritySearch(SearchRequest.builder().qu
== Setting up Azure Cosmos DB Vector Store without Auto Configuration
-The following code demonstrates how to set up the `CosmosDBVectorStore` without relying on auto-configuration:
+The following code demonstrates how to set up the `CosmosDBVectorStore` without relying on auto-configuration. https://learn.microsoft.com/azure/developer/java/sdk/authentication/credential-chains#defaultazurecredential-overview[DefaultAzureCredential] is recommended for authentication to Azure Cosmos DB.
[source,java]
----
@@ -155,7 +155,7 @@ public VectorStore vectorStore(ObservationRegistry observationRegistry) {
// Create the Cosmos DB client
CosmosAsyncClient cosmosClient = new CosmosClientBuilder()
.endpoint(System.getenv("COSMOSDB_AI_ENDPOINT"))
- .key(System.getenv("COSMOSDB_AI_KEY"))
+ .credential(new DefaultAzureCredentialBuilder().build())
.userAgentSuffix("SpringAI-CDBNoSQL-VectorStore")
.gatewayMode()
.buildAsyncClient();
diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml b/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml
index ca93a9ece8b..441855a6527 100644
--- a/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml
+++ b/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml
@@ -47,6 +47,11 @@
azure-spring-data-cosmos
${azure-cosmos.version}
+
+ com.azure
+ azure-identity
+ 1.15.4
+
org.springframework.ai
spring-ai-core
diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java
index 3a2f3edc3e6..6d476d869b6 100644
--- a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java
+++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java
@@ -18,7 +18,10 @@
import java.util.ArrayList;
import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
import java.util.List;
+import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -30,6 +33,7 @@
import com.azure.cosmos.models.CosmosBulkOperations;
import com.azure.cosmos.models.CosmosContainerProperties;
import com.azure.cosmos.models.CosmosItemOperation;
+import com.azure.cosmos.models.CosmosItemResponse;
import com.azure.cosmos.models.CosmosQueryRequestOptions;
import com.azure.cosmos.models.CosmosVectorDataType;
import com.azure.cosmos.models.CosmosVectorDistanceFunction;
@@ -38,6 +42,7 @@
import com.azure.cosmos.models.CosmosVectorIndexSpec;
import com.azure.cosmos.models.CosmosVectorIndexType;
import com.azure.cosmos.models.ExcludedPath;
+import com.azure.cosmos.models.FeedResponse;
import com.azure.cosmos.models.IncludedPath;
import com.azure.cosmos.models.IndexingMode;
import com.azure.cosmos.models.IndexingPolicy;
@@ -48,6 +53,7 @@
import com.azure.cosmos.models.SqlQuerySpec;
import com.azure.cosmos.models.ThroughputProperties;
import com.azure.cosmos.util.CosmosPagedFlux;
+
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
@@ -117,7 +123,16 @@ protected CosmosDBVectorStore(Builder builder) {
this.vectorDimensions = builder.vectorDimensions;
this.metadataFieldsList = builder.metadataFieldsList;
- this.cosmosClient.createDatabaseIfNotExists(this.databaseName).block();
+ try {
+     this.cosmosClient.createDatabaseIfNotExists(this.databaseName).block();
+ }
+ catch (Exception e) {
+     // Best effort: creation most likely failed because of RBAC restrictions, so the
+     // database is assumed to exist already. If it does not, a later operation will
+     // fail. The exception is logged in full so the real cause is not lost.
+     logger.error("Error creating database: {}", e.getMessage(), e);
+ }
+
initializeContainer(this.containerName, this.databaseName, this.vectorStoreThroughput, this.vectorDimensions,
this.partitionKeyPath);
}
@@ -223,10 +238,30 @@ public void doAdd(List documents) {
// Create a list to hold both the CosmosItemOperation and the corresponding
// document ID
List> itemOperationsWithIds = documents.stream().map(doc -> {
+ String partitionKeyValue;
+
+ if ("/id".equals(this.partitionKeyPath)) {
+ partitionKeyValue = doc.getId();
+ }
+ else if (this.partitionKeyPath.startsWith("/metadata/")) {
+ // Extract the key, e.g. "/metadata/country" -> "country"
+ String metadataKey = this.partitionKeyPath.substring("/metadata/".length());
+ Object value = doc.getMetadata() != null ? doc.getMetadata().get(metadataKey) : null;
+ if (value == null) {
+ throw new IllegalArgumentException(
+ "Partition key '" + metadataKey + "' not found in document metadata.");
+ }
+ partitionKeyValue = value.toString();
+ }
+ else {
+ throw new IllegalArgumentException("Unsupported partition key path: " + this.partitionKeyPath);
+ }
+
CosmosItemOperation operation = CosmosBulkOperations.getCreateItemOperation(
- mapCosmosDocument(doc, embeddings.get(documents.indexOf(doc))), new PartitionKey(doc.getId()));
- return new ImmutablePair<>(doc.getId(), operation); // Pair the document ID
+ mapCosmosDocument(doc, embeddings.get(documents.indexOf(doc))),
+ new PartitionKey(partitionKeyValue)); // Pair the document ID
// with the operation
+ return new ImmutablePair<>(doc.getId(), operation);
}).toList();
try {
@@ -273,9 +308,48 @@ public void doAdd(List documents) {
public void doDelete(List idList) {
try {
// Convert the list of IDs into bulk delete operations
- List itemOperations = idList.stream()
- .map(id -> CosmosBulkOperations.getDeleteItemOperation(id, new PartitionKey(id)))
- .collect(Collectors.toList());
+ List<CosmosItemOperation> itemOperations = idList.stream().map(id -> {
+     String partitionKeyValue;
+
+     if ("/id".equals(this.partitionKeyPath)) {
+         partitionKeyValue = id;
+     }
+     else if (this.partitionKeyPath.startsWith("/metadata/")) {
+         // Only ids are passed in, so the partition key value has to be read back
+         // from the stored document: one lookup query per id. This is inefficient
+         // for large id lists; ideally the caller would supply the partition key.
+         String metadataKey = this.partitionKeyPath.substring("/metadata/".length());
+
+         // Bind the id as a query parameter instead of formatting it into the SQL
+         // text, so ids containing quotes cannot break or alter the query.
+         SqlQuerySpec lookup = new SqlQuerySpec("SELECT * FROM c WHERE c.id = @id",
+                 new com.azure.cosmos.models.SqlParameter("@id", id));
+         CosmosPagedFlux<JsonNode> queryFlux = this.container.queryItems(lookup,
+                 new CosmosQueryRequestOptions(), JsonNode.class);
+
+         // Take the first matching item rather than the first page: a page can be
+         // empty even when a later page holds the document. blockFirst() returns
+         // null when nothing matched.
+         JsonNode document = queryFlux.blockFirst();
+         if (document == null) {
+             throw new IllegalArgumentException("No document found for id: " + id);
+         }
+
+         JsonNode metadataNode = document.get("metadata");
+         JsonNode partitionKeyNode = (metadataNode != null) ? metadataNode.get(metadataKey) : null;
+         if (partitionKeyNode == null) {
+             throw new IllegalArgumentException("Partition key '" + metadataKey
+                     + "' not found in metadata for document with id: " + id);
+         }
+
+         partitionKeyValue = partitionKeyNode.asText();
+     }
+     else {
+         throw new IllegalArgumentException("Unsupported partition key path: " + this.partitionKeyPath);
+     }
+
+     return CosmosBulkOperations.getDeleteItemOperation(id, new PartitionKey(partitionKeyValue));
+ }).collect(Collectors.toList());
// Execute bulk delete operations synchronously by using blockLast() on the
// Flux
@@ -283,10 +357,11 @@ public void doDelete(List idList) {
.doOnNext(response -> logger.info("Document deleted with status: {}",
response.getResponse().getStatusCode()))
.doOnError(error -> logger.error("Error deleting document: {}", error.getMessage()))
- .blockLast(); // This will block until all operations have finished
+ .blockLast();
}
catch (Exception e) {
- logger.error("Exception while deleting documents: {}", e.getMessage());
+ logger.error("Exception while deleting documents: {}", e.getMessage(), e);
+ throw e;
}
}
@@ -348,9 +423,26 @@ public List doSimilaritySearch(SearchRequest request) {
.flatMap(page -> Flux.fromIterable(page.getResults()))
.collectList()
.block();
+
+ // Each result gets a metadata map built from its own "metadata" node, so results
+ // never share or overwrite one another's values.
// Convert JsonNode to Document
List docs = documents.stream()
- .map(doc -> Document.builder().id(doc.get("id").asText()).text(doc.get("content").asText()).build())
+ .map(doc -> {
+     Map<String, Object> docFields = new HashMap<>();
+     JsonNode metadata = doc.get("metadata");
+     // A stored item without a metadata node yields empty metadata
+     if (metadata != null) {
+         metadata.fieldNames().forEachRemaining(field -> {
+             JsonNode value = metadata.get(field);
+             // Keep text, number and boolean values typed; anything else is
+             // kept as its JSON string form
+             Object parsedValue = value.isTextual() ? value.asText() : value.isNumber() ? value.numberValue()
+                     : value.isBoolean() ? value.booleanValue() : value.toString();
+             docFields.put(field, parsedValue);
+         });
+     }
+     return Document.builder()
+         .id(doc.get("id").asText())
+         .text(doc.get("content").asText())
+         .metadata(docFields)
+         .build();
+ })
.collect(Collectors.toList());
return docs != null ? docs : List.of();
@@ -475,7 +567,7 @@ public Builder vectorDimensions(long vectorDimensions) {
* @return the builder instance
*/
public Builder metadataFields(List metadataFieldsList) {
- this.metadataFieldsList = metadataFieldsList != null ? new ArrayList<>(this.metadataFieldsList)
+ this.metadataFieldsList = metadataFieldsList != null ? new ArrayList<>(metadataFieldsList)
: new ArrayList<>();
return this;
}
diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStoreIT.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStoreIT.java
index 39729c6806e..6a46dd342db 100644
--- a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStoreIT.java
+++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStoreIT.java
@@ -25,6 +25,7 @@
import com.azure.cosmos.CosmosAsyncClient;
import com.azure.cosmos.CosmosAsyncContainer;
import com.azure.cosmos.CosmosClientBuilder;
+import com.azure.identity.DefaultAzureCredentialBuilder;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
@@ -50,7 +51,6 @@
* @since 1.0.0
*/
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_ENDPOINT", matches = ".+")
-@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_KEY", matches = ".+")
public class CosmosDBVectorStoreIT {
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
@@ -141,6 +141,11 @@ void testSimilaritySearchWithFilter() {
assertThat(results).hasSize(2);
assertThat(results).extracting(Document::getId).containsExactlyInAnyOrder("1", "2");
+ for (Document doc : results) {
+ assertThat(doc.getMetadata().get("country")).isIn("UK", "NL");
+ assertThat(doc.getMetadata().get("year")).isIn(2021, 2022);
+ assertThat(doc.getMetadata().get("city")).isIn("London", "Amsterdam").isNotEqualTo("Sofia");
+ }
List results2 = this.vectorStore.similaritySearch(SearchRequest.builder()
.query("The World")
@@ -191,6 +196,7 @@ public VectorStore vectorStore(CosmosAsyncClient cosmosClient, EmbeddingModel em
.databaseName("test-database")
.containerName("test-container")
.metadataFields(List.of("country", "year", "city"))
+ .partitionKeyPath("/id")
.vectorStoreThroughput(1000)
.customObservationConvention(convention)
.build();
@@ -199,7 +205,7 @@ public VectorStore vectorStore(CosmosAsyncClient cosmosClient, EmbeddingModel em
@Bean
public CosmosAsyncClient cosmosClient() {
return new CosmosClientBuilder().endpoint(System.getenv("AZURE_COSMOSDB_ENDPOINT"))
- .key(System.getenv("AZURE_COSMOSDB_KEY"))
+ .credential(new DefaultAzureCredentialBuilder().build())
.userAgentSuffix("SpringAI-CDBNoSQL-VectorStore")
.gatewayMode()
.buildAsyncClient();
diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStoreWithMetadataPartitionKeyIT.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStoreWithMetadataPartitionKeyIT.java
new file mode 100644
index 00000000000..5ec456871c7
--- /dev/null
+++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStoreWithMetadataPartitionKeyIT.java
@@ -0,0 +1,224 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.vectorstore.cosmosdb;
+
+import com.azure.cosmos.CosmosAsyncClient;
+import com.azure.cosmos.CosmosAsyncContainer;
+import com.azure.cosmos.CosmosClientBuilder;
+import com.azure.identity.DefaultAzureCredentialBuilder;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.ai.transformers.TransformersEmbeddingModel;
+import org.springframework.ai.vectorstore.SearchRequest;
+import org.springframework.ai.vectorstore.VectorStore;
+import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
+import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
+import org.springframework.boot.SpringBootConfiguration;
+import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+import org.springframework.context.annotation.Bean;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.UUID;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * @author Theo van Kraay
+ * @author Thomas Vitale
+ * @since 1.0.0
+ */
+@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_ENDPOINT", matches = ".+")
+public class CosmosDBVectorStoreWithMetadataPartitionKeyIT {
+
+ private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
+ .withUserConfiguration(TestApplication.class);
+
+ private VectorStore vectorStore;
+
+ /**
+  * Runs the application context before each test and keeps the {@link VectorStore}
+  * bean it provides in the {@code vectorStore} field.
+  */
+ @BeforeEach
+ public void setup() {
+ this.contextRunner.run(context -> this.vectorStore = context.getBean(VectorStore.class));
+ }
+
+ @Test
+ public void testAddSearchAndDeleteDocuments() {
+     // The store is partitioned on /metadata/country, so a document without a
+     // "country" entry must be rejected
+     Document withoutCountry = new Document(UUID.randomUUID().toString(), "Sample content1",
+             Map.of("key1", "value1"));
+     assertThatThrownBy(() -> this.vectorStore.add(List.of(withoutCountry))).isInstanceOf(Exception.class)
+         .hasMessageContaining("Partition key 'country' not found in document metadata.");
+
+     // A document that carries the partition key can be stored
+     Document withCountry = new Document(UUID.randomUUID().toString(), "Sample content1",
+             Map.of("country", "UK"));
+     this.vectorStore.add(List.of(withCountry));
+
+     // The stored document is the best match for its own content
+     SearchRequest firstSearch = SearchRequest.builder().query("Sample content1").topK(1).build();
+     List<Document> found = this.vectorStore.similaritySearch(firstSearch);
+     assertThat(found).isNotEmpty();
+     assertThat(found.get(0).getId()).isEqualTo(withCountry.getId());
+
+     // After deletion the same kind of search returns nothing
+     this.vectorStore.delete(List.of(withCountry.getId()));
+     SearchRequest secondSearch = SearchRequest.builder().query("Sample content").topK(1).build();
+     List<Document> afterDelete = this.vectorStore.similaritySearch(secondSearch);
+     assertThat(afterDelete).isEmpty();
+ }
+
+ @Test
+ void testSimilaritySearchWithFilter() {
+     // Seed four documents whose metadata differs in country, year and city
+     Map<String, Object> ukMetadata = new HashMap<>(Map.of("country", "UK", "year", 2021, "city", "London"));
+     Map<String, Object> nlMetadata = new HashMap<>(Map.of("country", "NL", "year", 2022, "city", "Amsterdam"));
+     Map<String, Object> us2019Metadata = new HashMap<>(Map.of("country", "US", "year", 2019, "city", "Sofia"));
+     Map<String, Object> us2020Metadata = new HashMap<>(Map.of("country", "US", "year", 2020, "city", "Sofia"));
+
+     List<Document> seeded = List.of(new Document("1", "A document about the UK", ukMetadata),
+             new Document("2", "A document about the Netherlands", nlMetadata),
+             new Document("3", "A document about the US", us2019Metadata),
+             new Document("4", "A document about the US", us2020Metadata));
+     this.vectorStore.add(seeded);
+
+     FilterExpressionBuilder filter = new FilterExpressionBuilder();
+
+     // country IN (UK, NL) -> documents 1 and 2, each with its own metadata
+     SearchRequest byCountry = SearchRequest.builder()
+         .query("The World")
+         .topK(10)
+         .filterExpression((filter.in("country", "UK", "NL")).build())
+         .build();
+     List<Document> countryMatches = this.vectorStore.similaritySearch(byCountry);
+
+     assertThat(countryMatches).hasSize(2);
+     assertThat(countryMatches).extracting(Document::getId).containsExactlyInAnyOrder("1", "2");
+     for (Document match : countryMatches) {
+         assertThat(match.getMetadata().get("country")).isIn("UK", "NL");
+         assertThat(match.getMetadata().get("year")).isIn(2021, 2022);
+         assertThat(match.getMetadata().get("city")).isIn("London", "Amsterdam").isNotEqualTo("Sofia");
+     }
+
+     // (year >= 2021 OR country == NL) AND city != Amsterdam -> document 1 only
+     SearchRequest combined = SearchRequest.builder()
+         .query("The World")
+         .topK(10)
+         .filterExpression(filter
+             .and(filter.or(filter.gte("year", 2021), filter.eq("country", "NL")), filter.ne("city", "Amsterdam"))
+             .build())
+         .build();
+     List<Document> combinedMatches = this.vectorStore.similaritySearch(combined);
+
+     assertThat(combinedMatches).hasSize(1);
+     assertThat(combinedMatches).extracting(Document::getId).containsExactlyInAnyOrder("1");
+
+     // country == US AND year == 2020 -> document 4 only
+     SearchRequest usIn2020 = SearchRequest.builder()
+         .query("The World")
+         .topK(10)
+         .filterExpression(filter.and(filter.eq("country", "US"), filter.eq("year", 2020)).build())
+         .build();
+     List<Document> usMatches = this.vectorStore.similaritySearch(usIn2020);
+
+     assertThat(usMatches).hasSize(1);
+     assertThat(usMatches).extracting(Document::getId).containsExactlyInAnyOrder("4");
+
+     // Remove everything that was seeded and confirm nothing is left to find
+     this.vectorStore.delete(seeded.stream().map(Document::getId).toList());
+
+     List<Document> afterDelete = this.vectorStore
+         .similaritySearch(SearchRequest.builder().query("The World").topK(1).build());
+     assertThat(afterDelete).isEmpty();
+ }
+
+ /**
+  * Verifies that the store exposes its native Cosmos container.
+  */
+ @Test
+ void getNativeClientTest() {
+     this.contextRunner.run(context -> {
+         // Named "store" so it does not shadow the vectorStore field
+         CosmosDBVectorStore store = context.getBean(CosmosDBVectorStore.class);
+         Optional<CosmosAsyncContainer> nativeClient = store.getNativeClient();
+         assertThat(nativeClient).isPresent();
+     });
+ }
+
+ /**
+  * Test configuration that wires a {@link CosmosDBVectorStore} partitioned on
+  * {@code /metadata/country}, together with the client, embedding model and
+  * observation convention it needs.
+  */
+ @SpringBootConfiguration
+ @EnableAutoConfiguration
+ public static class TestApplication {
+
+ // Vector store under test: dedicated container, partition key taken from the
+ // "country" metadata field
+ @Bean
+ public VectorStore vectorStore(CosmosAsyncClient cosmosClient, EmbeddingModel embeddingModel,
+ VectorStoreObservationConvention convention) {
+ return CosmosDBVectorStore.builder(cosmosClient, embeddingModel)
+ .databaseName("test-database")
+ .containerName("test-container-metadata-partition-key")
+ .metadataFields(List.of("country", "year", "city"))
+ .partitionKeyPath("/metadata/country")
+ .vectorStoreThroughput(1000)
+ .customObservationConvention(convention)
+ .build();
+ }
+
+ // Gateway-mode client for the endpoint in AZURE_COSMOSDB_ENDPOINT, authenticated
+ // with a credential from DefaultAzureCredentialBuilder (no account key)
+ @Bean
+ public CosmosAsyncClient cosmosClient() {
+ return new CosmosClientBuilder().endpoint(System.getenv("AZURE_COSMOSDB_ENDPOINT"))
+ .credential(new DefaultAzureCredentialBuilder().build())
+ .userAgentSuffix("SpringAI-CDBNoSQL-VectorStore")
+ .gatewayMode()
+ .buildAsyncClient();
+ }
+
+ // Embedding model used to embed documents and queries in these tests
+ @Bean
+ public EmbeddingModel embeddingModel() {
+ return new TransformersEmbeddingModel();
+ }
+
+ // Anonymous convention with no overrides
+ @Bean
+ public VectorStoreObservationConvention observationConvention() {
+ // Replace with an actual observation convention or a mock if needed
+ return new VectorStoreObservationConvention() {
+
+ };
+ }
+
+ }
+
+}