From 000354e8db6d69c5792c76ba41d5bd95cc0a2821 Mon Sep 17 00:00:00 2001 From: Ido Berkovich Date: Wed, 15 Jan 2025 09:23:50 +0200 Subject: [PATCH] OPIK-744 ignore default --- .../opik/domain/llmproviders/LlmProviderFactory.java | 10 ++++++---- .../domain/llmproviders/LlmProviderFactoryTest.java | 1 + 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java index 9ef99fd364..5d392ef607 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java @@ -14,6 +14,7 @@ import lombok.RequiredArgsConstructor; import org.apache.commons.lang3.EnumUtils; +import java.util.Set; import java.util.function.Function; @Singleton @@ -50,13 +51,13 @@ public ChatLanguageModel getLanguageModel(@NonNull String workspaceId, * The agreed requirement is to resolve the LLM provider and its API key based on the model. */ private LlmProvider getLlmProvider(String model) { - if (isModelBelongToProvider(model, ModelPrice.class, ModelPrice::getName)) { + if (isModelBelongToProvider(model, ModelPrice.class, ModelPrice::getName, Set.of(ModelPrice.DEFAULT))) { return LlmProvider.OPEN_AI; } - if (isModelBelongToProvider(model, AnthropicModelName.class, AnthropicModelName::toString)) { + if (isModelBelongToProvider(model, AnthropicModelName.class, AnthropicModelName::toString, Set.of())) { return LlmProvider.ANTHROPIC; } - if (isModelBelongToProvider(model, GeminiModelName.class, GeminiModelName::toString)) { + if (isModelBelongToProvider(model, GeminiModelName.class, GeminiModelName::toString, Set.of())) { return LlmProvider.GEMINI; } @@ -77,8 +78,9 @@ private String getEncryptedApiKey(String workspaceId, LlmProvider llmProvider) { } private static > boolean isModelBelongToProvider( - String model, Class enumClass, Function valueGetter) { + String model, Class enumClass, Function valueGetter, Set exclude) { return EnumUtils.getEnumList(enumClass).stream() + .filter(value -> !exclude.contains(value)) .map(valueGetter) .anyMatch(model::equals); } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java index ffe44752a1..89d3d1e710 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java @@ -79,6 +79,7 @@ void testGetService(String model, LlmProvider llmProvider, Class testGetService() { var openAiModels = EnumUtils.getEnumList(ModelPrice.class).stream() + .filter(value -> value != ModelPrice.DEFAULT) .map(model -> arguments(model.getName(), LlmProvider.OPEN_AI, LlmProviderOpenAi.class)); var anthropicModels = EnumUtils.getEnumList(AnthropicModelName.class).stream() .map(model -> arguments(model.toString(), LlmProvider.ANTHROPIC, LlmProviderAnthropic.class));