From 4256a6bd2167c57b61e775a350414ae0e1d53e79 Mon Sep 17 00:00:00 2001 From: yihong0618 Date: Fri, 20 Dec 2024 09:22:20 +0800 Subject: [PATCH] fix: better gard nan value from numpy for issue #11827 Signed-off-by: yihong0618 --- .../azure_openai/text_embedding/text_embedding.py | 5 ++++- .../model_providers/cohere/text_embedding/text_embedding.py | 5 ++++- .../model_providers/openai/text_embedding/text_embedding.py | 5 ++++- .../model_providers/upstage/text_embedding/text_embedding.py | 5 ++++- api/core/rag/embedding/cached_embedding.py | 2 ++ 5 files changed, 18 insertions(+), 4 deletions(-) diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index c45ce87ea76838..69d2cfaded453f 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -92,7 +92,10 @@ def _invoke( average = embeddings_batch[0] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embedding = (average / np.linalg.norm(average)).tolist() + if np.isnan(embedding).any(): + raise ValueError("Normalized embedding is nan please try again") + embeddings[i] = embedding # calc usage usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index 5fd4d637be7643..9e4df2706080f9 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -88,7 +88,10 @@ def _invoke( average = embeddings_batch[0] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embedding = (average / np.linalg.norm(average)).tolist() + if np.isnan(embedding).any(): + raise ValueError("Normalized embedding is nan please try again") + embeddings[i] = embedding # calc usage usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index bec01fe6797f52..9c8c8d5882ee4e 100644 --- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -97,7 +97,10 @@ def _invoke( average = embeddings_batch[0] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embedding = (average / np.linalg.norm(average)).tolist() + if np.isnan(embedding).any(): + raise ValueError("Normalized embedding is nan please try again") + embeddings[i] = embedding # calc usage usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) diff --git a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py index 7dd495b55ef4e6..5b340e53bbc543 100644 --- a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py @@ -100,7 +100,10 @@ def _invoke( average = embeddings_batch[0] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embedding = (average / np.linalg.norm(average)).tolist() + if np.isnan(embedding).any(): + raise ValueError("Normalized embedding is nan please try again") + embeddings[i] = embedding usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 652f7e145fd94d..8ddda7e9832d97 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -116,6 +116,8 @@ def embed_query(self, text: str) -> list[float]: embedding_results = embedding_result.embeddings[0] embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() + if np.isnan(embedding_results).any(): + raise ValueError("Normalized embedding is nan please try again") except Exception as ex: if dify_config.DEBUG: logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'")