Skip to content

fix: adding cache for embedding models in EmbeddingModelFactory #601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.time.Duration;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel;
Expand Down Expand Up @@ -46,6 +47,8 @@
public class EmbeddingModelFactory {
private final AIRedisOMProperties properties;
private final SpringAiProperties springAiProperties;
private final Map<String, Object> modelCache = new ConcurrentHashMap<>();

private final RestClient.Builder restClientBuilder;
private final WebClient.Builder webClientBuilder;
private final ResponseErrorHandler responseErrorHandler;
Expand All @@ -62,7 +65,60 @@ public EmbeddingModelFactory(AIRedisOMProperties properties, SpringAiProperties
this.observationRegistry = observationRegistry;
}

/**
 * Builds the cache lookup key for a model from its type and identifying parameters.
 * The key is the model type followed by each parameter, all colon-delimited.
 *
 * @param modelType the kind of embedding model (e.g. "openai", "transformers")
 * @param params    configuration values that distinguish one model instance from another
 * @return the colon-delimited cache key
 */
private String generateCacheKey(String modelType, String... params) {
  if (params.length == 0) {
    return modelType;
  }
  // String.join renders null elements as "null", matching StringBuilder.append(null).
  return modelType + ":" + String.join(":", params);
}

/**
 * Evicts every cached embedding model, so the next request for any model
 * constructs a fresh instance. Useful after a configuration change or to
 * release the memory held by cached models.
 */
public void clearCache() {
  modelCache.clear();
}

/**
 * Evicts a single cached model identified by its type and the parameters it
 * was created with.
 *
 * @param modelType the kind of model (e.g. "openai", "transformers")
 * @param params    the parameter values originally used to create the model
 * @return true when a cached entry existed and was removed, false otherwise
 */
public boolean removeFromCache(String modelType, String... params) {
  Object evicted = modelCache.remove(generateCacheKey(modelType, params));
  return evicted != null;
}

/**
 * Reports how many embedding models are currently held in the cache.
 *
 * @return the number of cached models
 */
public int getCacheSize() {
  return modelCache.size();
}

public TransformersEmbeddingModel createTransformersEmbeddingModel(Vectorize vectorize) {
String cacheKey = generateCacheKey("transformers", vectorize.transformersModel(), vectorize.transformersTokenizer(),
vectorize.transformersResourceCacheConfiguration(), String.join(",", vectorize.transformersTokenizerOptions()));

TransformersEmbeddingModel cachedModel = (TransformersEmbeddingModel) modelCache.get(cacheKey);

if (cachedModel != null) {
return cachedModel;
}

TransformersEmbeddingModel embeddingModel = new TransformersEmbeddingModel();

if (!vectorize.transformersModel().isEmpty()) {
Expand All @@ -89,6 +145,8 @@ public TransformersEmbeddingModel createTransformersEmbeddingModel(Vectorize vec
throw new RuntimeException("Error initializing TransformersEmbeddingModel", e);
}

modelCache.put(cacheKey, embeddingModel);

return embeddingModel;
}

Expand All @@ -97,6 +155,13 @@ public OpenAiEmbeddingModel createOpenAiEmbeddingModel(EmbeddingModel model) {
}

public OpenAiEmbeddingModel createOpenAiEmbeddingModel(String model) {
String cacheKey = generateCacheKey("openai", model, properties.getOpenAi().getApiKey());
OpenAiEmbeddingModel cachedModel = (OpenAiEmbeddingModel) modelCache.get(cacheKey);

if (cachedModel != null) {
return cachedModel;
}

String apiKey = properties.getOpenAi().getApiKey();
if (!StringUtils.hasText(apiKey)) {
apiKey = springAiProperties.getOpenai().getApiKey();
Expand All @@ -109,8 +174,11 @@ public OpenAiEmbeddingModel createOpenAiEmbeddingModel(String model) {
OpenAiApi openAiApi = OpenAiApi.builder().apiKey(properties.getOpenAi().getApiKey()).restClientBuilder(RestClient
.builder().requestFactory(factory)).build();

return new OpenAiEmbeddingModel(openAiApi, MetadataMode.EMBED, OpenAiEmbeddingOptions.builder().model(model)
.build(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
OpenAiEmbeddingModel embeddingModel = new OpenAiEmbeddingModel(openAiApi, MetadataMode.EMBED, OpenAiEmbeddingOptions
.builder().model(model).build(), RetryUtils.DEFAULT_RETRY_TEMPLATE);

modelCache.put(cacheKey, embeddingModel);
return embeddingModel;
}

private OpenAIClient getOpenAIClient() {
Expand All @@ -126,6 +194,16 @@ private OpenAIClient getOpenAIClient() {
}

public AzureOpenAiEmbeddingModel createAzureOpenAiEmbeddingModel(String deploymentName) {
String cacheKey = generateCacheKey("azure-openai", deploymentName, properties.getAzure().getOpenAi().getApiKey(),
properties.getAzure().getOpenAi().getEndpoint(), String.valueOf(properties.getAzure().getEntraId()
.isEnabled()));

AzureOpenAiEmbeddingModel cachedModel = (AzureOpenAiEmbeddingModel) modelCache.get(cacheKey);

if (cachedModel != null) {
return cachedModel;
}

String apiKey = properties.getAzure().getOpenAi().getApiKey();
if (!StringUtils.hasText(apiKey)) {
apiKey = springAiProperties.getAzure().getApiKey(); // Fallback to Spring AI property
Expand All @@ -142,10 +220,23 @@ public AzureOpenAiEmbeddingModel createAzureOpenAiEmbeddingModel(String deployme

AzureOpenAiEmbeddingOptions options = AzureOpenAiEmbeddingOptions.builder().deploymentName(deploymentName).build();

return new AzureOpenAiEmbeddingModel(openAIClient, MetadataMode.EMBED, options);
AzureOpenAiEmbeddingModel embeddingModel = new AzureOpenAiEmbeddingModel(openAIClient, MetadataMode.EMBED, options);

modelCache.put(cacheKey, embeddingModel);

return embeddingModel;
}

public VertexAiTextEmbeddingModel createVertexAiTextEmbeddingModel(String model) {
String cacheKey = generateCacheKey("vertex-ai", model, properties.getVertexAi().getApiKey(), properties
.getVertexAi().getEndpoint(), properties.getVertexAi().getProjectId(), properties.getVertexAi().getLocation());

VertexAiTextEmbeddingModel cachedModel = (VertexAiTextEmbeddingModel) modelCache.get(cacheKey);

if (cachedModel != null) {
return cachedModel;
}

String apiKey = properties.getVertexAi().getApiKey();
if (!StringUtils.hasText(apiKey)) {
apiKey = springAiProperties.getVertexAi().getApiKey(); // Fallback to Spring AI property
Expand Down Expand Up @@ -183,16 +274,32 @@ public VertexAiTextEmbeddingModel createVertexAiTextEmbeddingModel(String model)

VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder().model(model).build();

return new VertexAiTextEmbeddingModel(connectionDetails, options);
VertexAiTextEmbeddingModel embeddingModel = new VertexAiTextEmbeddingModel(connectionDetails, options);

modelCache.put(cacheKey, embeddingModel);

return embeddingModel;
}

/**
 * Returns an {@link OllamaEmbeddingModel} for the given model name, reusing a
 * cached instance when one was previously built against the same Ollama base URL.
 *
 * @param model the Ollama model name to embed with
 * @return a cached or newly constructed OllamaEmbeddingModel
 */
public OllamaEmbeddingModel createOllamaEmbeddingModel(String model) {
// Cache key covers the model name and base URL — the two inputs that change the client.
String cacheKey = generateCacheKey("ollama", model, properties.getOllama().getBaseUrl());

OllamaEmbeddingModel cachedModel = (OllamaEmbeddingModel) modelCache.get(cacheKey);

if (cachedModel != null) {
return cachedModel;
}

// Build the low-level API client from the injected Spring builders and error handler.
OllamaApi api = OllamaApi.builder().baseUrl(properties.getOllama().getBaseUrl()).restClientBuilder(
restClientBuilder).webClientBuilder(webClientBuilder).responseErrorHandler(responseErrorHandler).build();

// truncate(false): fail on over-long input rather than silently truncating it.
OllamaOptions options = OllamaOptions.builder().model(model).truncate(false).build();

OllamaEmbeddingModel embeddingModel = OllamaEmbeddingModel.builder().ollamaApi(api).defaultOptions(options).build();

// NOTE(review): check-then-put is not atomic; concurrent callers may briefly build
// duplicate models, with the last put winning — acceptable since models are stateless.
modelCache.put(cacheKey, embeddingModel);

return embeddingModel;
}

private AwsCredentials getAwsCredentials() {
Expand All @@ -218,6 +325,16 @@ private AwsCredentials getAwsCredentials() {
}

public BedrockCohereEmbeddingModel createCohereEmbeddingModel(String model) {
String cacheKey = generateCacheKey("bedrock-cohere", model, properties.getAws().getAccessKey(), properties.getAws()
.getSecretKey(), properties.getAws().getRegion(), String.valueOf(properties.getAws().getBedrockCohere()
.getResponseTimeOut()));

BedrockCohereEmbeddingModel cachedModel = (BedrockCohereEmbeddingModel) modelCache.get(cacheKey);

if (cachedModel != null) {
return cachedModel;
}

String region = properties.getAws().getRegion();
if (!StringUtils.hasText(region)) {
region = springAiProperties.getBedrock().getAws().getRegion(); // Fallback to Spring AI property
Expand All @@ -228,10 +345,24 @@ public BedrockCohereEmbeddingModel createCohereEmbeddingModel(String model) {
properties.getAws().getRegion(), ModelOptionsUtils.OBJECT_MAPPER, Duration.ofMinutes(properties.getAws()
.getBedrockCohere().getResponseTimeOut()));

return new BedrockCohereEmbeddingModel(cohereEmbeddingApi);
BedrockCohereEmbeddingModel embeddingModel = new BedrockCohereEmbeddingModel(cohereEmbeddingApi);

modelCache.put(cacheKey, embeddingModel);

return embeddingModel;
}

public BedrockTitanEmbeddingModel createTitanEmbeddingModel(String model) {
String cacheKey = generateCacheKey("bedrock-titan", model, properties.getAws().getAccessKey(), properties.getAws()
.getSecretKey(), properties.getAws().getRegion(), String.valueOf(properties.getAws().getBedrockTitan()
.getResponseTimeOut()));

BedrockTitanEmbeddingModel cachedModel = (BedrockTitanEmbeddingModel) modelCache.get(cacheKey);

if (cachedModel != null) {
return cachedModel;
}

String region = properties.getAws().getRegion();
if (!StringUtils.hasText(region)) {
region = springAiProperties.getBedrock().getAws().getRegion(); // Fallback to Spring AI property
Expand All @@ -242,6 +373,10 @@ public BedrockTitanEmbeddingModel createTitanEmbeddingModel(String model) {
properties.getAws().getRegion(), ModelOptionsUtils.OBJECT_MAPPER, Duration.ofMinutes(properties.getAws()
.getBedrockTitan().getResponseTimeOut()));

return new BedrockTitanEmbeddingModel(titanEmbeddingApi, observationRegistry);
BedrockTitanEmbeddingModel embeddingModel = new BedrockTitanEmbeddingModel(titanEmbeddingApi, observationRegistry);

modelCache.put(cacheKey, embeddingModel);

return embeddingModel;
}
}
2 changes: 1 addition & 1 deletion tests/src/test/resources/vss_on.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ redis:
om:
spring:
ai:
\enabled: true
enabled: true