From b512d8febad8c40d506582afc9c900ba591563c3 Mon Sep 17 00:00:00 2001
From: Ophiuchus
Date: Sun, 10 Nov 2024 18:28:00 -0600
Subject: [PATCH 1/2] seed.sql added id field for participants

---
 packages/adapter-postgres/seed.sql | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/packages/adapter-postgres/seed.sql b/packages/adapter-postgres/seed.sql
index 063c5fbe532..1cf1c088fc3 100644
--- a/packages/adapter-postgres/seed.sql
+++ b/packages/adapter-postgres/seed.sql
@@ -1,3 +1,9 @@
-INSERT INTO public.accounts (id, name, email, avatarUrl, details) VALUES ('00000000-0000-0000-0000-000000000000', 'Default Agent', 'default@agent.com', '', '{}');
-INSERT INTO public.rooms (id) VALUES ('00000000-0000-0000-0000-000000000000');
-INSERT INTO public.participants (userId, roomId) VALUES ('00000000-0000-0000-0000-000000000000', '00000000-0000-0000-0000-000000000000');
+
+INSERT INTO public.accounts (id, name, email, "avatarUrl", details)
+VALUES ('00000000-0000-0000-0000-000000000000', 'Default Agent', 'default@agent.com', '', '{}'::jsonb);
+
+INSERT INTO public.rooms (id)
+VALUES ('00000000-0000-0000-0000-000000000000');
+
+INSERT INTO public.participants (id, "userId", "roomId")
+VALUES ('00000000-0000-0000-0000-000000000001', '00000000-0000-0000-0000-000000000000', '00000000-0000-0000-0000-000000000000');
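The seed changes above matter because Postgres folds unquoted identifiers to lowercase, so the old avatarUrl, userId, and roomId references resolved to avatarurl, userid, and roomid and missed the quoted camelCase columns; the participants row also gained an explicit id. A quick sanity check is to run the file through node-postgres and read a row back. This is a minimal sketch, not part of the patch: the applySeed helper, connection string, and database name are illustrative, and it assumes the adapter's tables already exist.

import fs from "node:fs";
import { Client } from "pg";

// Hypothetical helper, for illustration only: applies seed.sql and
// verifies the seeded participant row is readable.
async function applySeed(connectionString: string, seedPath: string): Promise<void> {
    const client = new Client({ connectionString });
    await client.connect();
    try {
        // node-postgres runs a multi-statement string in one round trip
        // as long as no parameter placeholders are used.
        await client.query(fs.readFileSync(seedPath, "utf8"));
        // The camelCase columns must stay quoted here too, or Postgres
        // would look for "userid"/"roomid" instead.
        const { rows } = await client.query(
            'SELECT id, "userId", "roomId" FROM public.participants'
        );
        console.log("participants seeded:", rows.length);
    } finally {
        await client.end();
    }
}

applySeed("postgres://localhost:5432/eliza", "packages/adapter-postgres/seed.sql")
    .catch(console.error);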
From 21a1fb4de2602007e35fafe84d6b87576ef4f0e3 Mon Sep 17 00:00:00 2001
From: Ophiuchus
Date: Sun, 10 Nov 2024 19:32:16 -0600
Subject: [PATCH 2/2] refactor embeddings to decouple getRemote getLocal for
 calls regardless of runtime

---
 packages/core/src/embedding.ts | 142 +++++++++++++++++----------------
 1 file changed, 72 insertions(+), 70 deletions(-)

diff --git a/packages/core/src/embedding.ts b/packages/core/src/embedding.ts
index 50356a40bcc..9d322c5124f 100644
--- a/packages/core/src/embedding.ts
+++ b/packages/core/src/embedding.ts
@@ -22,98 +22,50 @@ function getRootPath() {
     return path.resolve(__dirname, "..");
 }
 
-/**
- * Send a message to the OpenAI API for embedding.
- * @param input The input to be embedded.
- * @returns The embedding of the input.
- */
-export async function embed(runtime: IAgentRuntime, input: string) {
-    // get the charcter, and handle by model type
-    const modelProvider = models[runtime.character.modelProvider];
-    const embeddingModel = modelProvider.model.embedding;
-
-    if (
-        runtime.character.modelProvider !== ModelProviderName.OPENAI &&
-        runtime.character.modelProvider !== ModelProviderName.OLLAMA &&
-        !settings.USE_OPENAI_EMBEDDING
-    ) {
-
-        // make sure to trim tokens to 8192
-        const cacheDir = getRootPath() + "/cache/";
-
-        // if the cache directory doesn't exist, create it
-        if (!fs.existsSync(cacheDir)) {
-            fs.mkdirSync(cacheDir, { recursive: true });
-        }
-
-        const embeddingModel = await FlagEmbedding.init({
-            cacheDir: cacheDir
-        });
-
-        const trimmedInput = trimTokens(input, 8000, "gpt-4o-mini");
-
-        const embedding: number[] = await embeddingModel.queryEmbed(trimmedInput);
-        console.log("Embedding dimensions: ", embedding.length);
-        return embedding;
-
-        // commented out the text generation service that uses llama
-        // const service = runtime.getService(
-        //     ServiceType.TEXT_GENERATION
-        // );
-
-        // const instance = service?.getInstance();
-
-        // if (instance) {
-        //     return await instance.getEmbeddingResponse(input);
-        // }
-    }
-
-    // TODO: Fix retrieveCachedEmbedding
-    // Check if we already have the embedding in the lore
-    const cachedEmbedding = await retrieveCachedEmbedding(runtime, input);
-    if (cachedEmbedding) {
-        return cachedEmbedding;
-    }
+interface EmbeddingOptions {
+    model: string;
+    endpoint: string;
+    apiKey?: string;
+    length?: number;
+    isOllama?: boolean;
+}
 
+async function getRemoteEmbedding(input: string, options: EmbeddingOptions): Promise<number[]> {
     const requestOptions = {
         method: "POST",
         headers: {
             "Content-Type": "application/json",
-            // TODO: make this not hardcoded
-            ...((runtime.modelProvider !== ModelProviderName.OLLAMA || settings.USE_OPENAI_EMBEDDING) ? {
-                Authorization: `Bearer ${runtime.token}`,
-            } : {}),
+            ...(options.apiKey ? {
+                Authorization: `Bearer ${options.apiKey}`,
+            } : {}),
         },
         body: JSON.stringify({
             input,
-            model: embeddingModel,
-            length: 384, // we are squashing dimensions to 768 for openai, even thought the model supports 1536
-            // -- this is ok for matryoshka embeddings but longterm, we might want to support 1536
+            model: options.model,
+            length: options.length || 384,
         }),
     };
+
     try {
         const response = await fetch(
-            // TODO: make this not hardcoded
-            `${runtime.character.modelEndpointOverride || modelProvider.endpoint}${(runtime.character.modelProvider === ModelProviderName.OLLAMA && !settings.USE_OPENAI_EMBEDDING) ? "/v1" : ""}/embeddings`,
+            `${options.endpoint}${options.isOllama ? "/v1" : ""}/embeddings`,
             requestOptions
         );
 
         if (!response.ok) {
             throw new Error(
-                "OpenAI API Error: " +
-                    response.status +
-                    " " +
-                    response.statusText
+                "Embedding API Error: " +
+                    response.status +
+                    " " +
+                    response.statusText
             );
         }
 
-        interface OpenAIEmbeddingResponse {
+        interface EmbeddingResponse {
             data: Array<{ embedding: number[] }>;
         }
 
-        const data: OpenAIEmbeddingResponse = await response.json();
-
+        const data: EmbeddingResponse = await response.json();
         return data?.data?.[0].embedding;
     } catch (e) {
         console.error(e);
@@ -121,6 +73,55 @@ export async function embed(runtime: IAgentRuntime, input: string) {
     }
 }
 
+/**
+ * Send a message to the OpenAI API for embedding.
+ * @param input The input to be embedded.
+ * @returns The embedding of the input.
+ */
+export async function embed(runtime: IAgentRuntime, input: string) {
+    const modelProvider = models[runtime.character.modelProvider];
+    const embeddingModel = modelProvider.model.embedding;
+
+    // Try local embedding first
+    if (
+        runtime.character.modelProvider !== ModelProviderName.OPENAI &&
+        runtime.character.modelProvider !== ModelProviderName.OLLAMA &&
+        !settings.USE_OPENAI_EMBEDDING
+    ) {
+        return await getLocalEmbedding(input);
+    }
+
+    // Check cache
+    const cachedEmbedding = await retrieveCachedEmbedding(runtime, input);
+    if (cachedEmbedding) {
+        return cachedEmbedding;
+    }
+
+    // Get remote embedding
+    return await getRemoteEmbedding(input, {
+        model: embeddingModel,
+        endpoint: runtime.character.modelEndpointOverride || modelProvider.endpoint,
+        apiKey: runtime.token,
+        isOllama: runtime.character.modelProvider === ModelProviderName.OLLAMA && !settings.USE_OPENAI_EMBEDDING
+    });
+}
+
+async function getLocalEmbedding(input: string): Promise<number[]> {
+    const cacheDir = getRootPath() + "/cache/";
+    if (!fs.existsSync(cacheDir)) {
+        fs.mkdirSync(cacheDir, { recursive: true });
+    }
+
+    const embeddingModel = await FlagEmbedding.init({
+        cacheDir: cacheDir
+    });
+
+    const trimmedInput = trimTokens(input, 8000, "gpt-4o-mini");
+    const embedding = await embeddingModel.queryEmbed(trimmedInput);
+    console.log("Embedding dimensions: ", embedding.length);
+    return embedding;
+}
+
 export async function retrieveCachedEmbedding(
     runtime: IAgentRuntime,
     input: string
@@ -129,7 +130,7 @@ export async function retrieveCachedEmbedding(
         console.log("No input to retrieve cached embedding for");
         return null;
     }
-    
+
     const similaritySearchResult =
         await runtime.messageManager.getCachedEmbeddings(input);
     if (similaritySearchResult.length > 0) {
@@ -137,3 +138,4 @@ export async function retrieveCachedEmbedding(
         return similaritySearchResult[0].embedding;
     }
     return null;
 }
+
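The refactor above reduces every remote path to one OpenAI-compatible POST to /embeddings, with Ollama differing only by its /v1 prefix. Since getRemoteEmbedding stays module-private, the sketch below mirrors its request shape as a standalone function rather than importing it; fetchEmbedding is a hypothetical name, and the endpoint and model in the usage lines are placeholders, not project defaults. It runs on Node 18+, where fetch is global.

// Standalone mirror of the request shape getRemoteEmbedding builds in the
// patch, for probing an OpenAI-compatible /embeddings endpoint in isolation.
interface EmbeddingOptions {
    model: string;
    endpoint: string;
    apiKey?: string;
    length?: number;
    isOllama?: boolean;
}

async function fetchEmbedding(input: string, options: EmbeddingOptions): Promise<number[]> {
    const response = await fetch(
        // Ollama exposes its OpenAI-compatible routes under /v1.
        `${options.endpoint}${options.isOllama ? "/v1" : ""}/embeddings`,
        {
            method: "POST",
            headers: {
                "Content-Type": "application/json",
                ...(options.apiKey ? { Authorization: `Bearer ${options.apiKey}` } : {}),
            },
            body: JSON.stringify({
                input,
                model: options.model,
                length: options.length || 384, // same dimension squashing as the patch
            }),
        }
    );
    if (!response.ok) {
        throw new Error(`Embedding API Error: ${response.status} ${response.statusText}`);
    }
    const data: { data: Array<{ embedding: number[] }> } = await response.json();
    return data.data[0].embedding;
}

// Usage against a local Ollama instance; the model name is a placeholder.
fetchEmbedding("hello world", {
    model: "mxbai-embed-large",
    endpoint: "http://localhost:11434",
    isOllama: true,
}).then((embedding) => console.log("dimensions:", embedding.length));

Inside the package the same call is reached through embed(runtime, input), which tries the local fastembed path first, then the cache, and only then this remote request.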