From 10c26de8ed6dd0e9d6849783af9aacb8903ab0a3 Mon Sep 17 00:00:00 2001
From: jacoblee93
Date: Wed, 28 Aug 2024 09:32:09 -0700
Subject: [PATCH 1/3] Treat OpenAI message format as a message like

---
 .../language_models/tests/chat_models.test.ts | 58 +++++++++++++++++
 langchain-core/src/messages/base.ts           | 13 ++++
 .../src/messages/tests/base_message.test.ts   | 62 +++++++++++++++++++
 langchain-core/src/messages/utils.ts          | 48 +++++++++++++-
 4 files changed, 178 insertions(+), 3 deletions(-)

diff --git a/langchain-core/src/language_models/tests/chat_models.test.ts b/langchain-core/src/language_models/tests/chat_models.test.ts
index 940cc50802b2..70ff187243e8 100644
--- a/langchain-core/src/language_models/tests/chat_models.test.ts
+++ b/langchain-core/src/language_models/tests/chat_models.test.ts
@@ -27,6 +27,64 @@ test("Test ChatModel accepts object shorthand for messages", async () => {
   expect(response.content).toEqual("Hello there!");
 });
 
+test("Test ChatModel accepts object with role for messages", async () => {
+  const model = new FakeChatModel({});
+  const response = await model.invoke([
+    {
+      role: "human",
+      content: "Hello there!!",
+      example: true,
+    },
+  ]);
+  expect(response.content).toEqual("Hello there!!");
+});
+
+test("Test ChatModel accepts several messages as objects with role", async () => {
+  const model = new FakeChatModel({});
+  const response = await model.invoke([
+    {
+      role: "system",
+      content: "You are an assistant.",
+    },
+    {
+      role: "human",
+      content: [{ type: "text", text: "What is the weather in SF?" }],
+      example: true,
+    },
+    {
+      role: "assistant",
+      content: "",
+      tool_calls: [
+        {
+          id: "call_123",
+          function: {
+            name: "get_weather",
+            arguments: JSON.stringify({ location: "sf" }),
+          },
+          type: "function",
+        },
+      ],
+    },
+    {
+      role: "tool",
+      content: "Pretty nice right now!",
+      tool_call_id: "call_123",
+    },
+  ]);
+  expect(response.content).toEqual(
+    [
+      "You are an assistant.",
+      JSON.stringify(
+        [{ type: "text", text: "What is the weather in SF?" }],
+        null,
+        2
+      ),
+      "",
+      "Pretty nice right now!",
+    ].join("\n")
+  );
+});
+
 test("Test ChatModel uses callbacks", async () => {
   const model = new FakeChatModel({});
   let acc = "";
diff --git a/langchain-core/src/messages/base.ts b/langchain-core/src/messages/base.ts
index 8b74c7f158e1..eafbb9fe1fb4 100644
--- a/langchain-core/src/messages/base.ts
+++ b/langchain-core/src/messages/base.ts
@@ -440,12 +440,25 @@ export abstract class BaseMessageChunk extends BaseMessage {
   abstract concat(chunk: BaseMessageChunk): BaseMessageChunk;
 }
 
+export type MessageFieldWithRole = {
+  role: MessageType | "user" | "assistant" | string;
+  content: MessageContent;
+  name?: string;
+} & Record<string, unknown>;
+
+export function _isMessageFieldWithRole(
+  x: BaseMessageLike
+): x is MessageFieldWithRole {
+  return typeof (x as MessageFieldWithRole).role === "string";
+}
+
 export type BaseMessageLike =
   | BaseMessage
   | ({
       type: MessageType | "user" | "assistant" | "placeholder";
     } & BaseMessageFields &
      Record<string, unknown>)
+  | MessageFieldWithRole
   | [
      StringWithAutocomplete<
        MessageType | "user" | "assistant" | "placeholder"
diff --git a/langchain-core/src/messages/tests/base_message.test.ts b/langchain-core/src/messages/tests/base_message.test.ts
index cc6926a1e6b6..21c1f81026f8 100644
--- a/langchain-core/src/messages/tests/base_message.test.ts
+++ b/langchain-core/src/messages/tests/base_message.test.ts
@@ -6,6 +6,8 @@ import {
   ToolMessage,
   ToolMessageChunk,
   AIMessageChunk,
+  coerceMessageLikeToMessage,
+  SystemMessage,
 } from "../index.js";
 import { load } from "../../load/index.js";
 
@@ -334,3 +336,63 @@ describe("Complex AIMessageChunk concat", () => {
     );
   });
 });
+
+describe("Message like coercion", () => {
+  it("Should convert OpenAI format messages", async () => {
+    const messages = [
+      {
+        id: "foobar",
+        role: "system",
+        content: "You are an assistant.",
+      },
+      {
+        role: "user",
+        content: [{ type: "text", text: "What is the weather in SF?" }],
+      },
+      {
+        role: "assistant",
+        content: "",
+        tool_calls: [
+          {
+            id: "call_123",
+            function: {
+              name: "get_weather",
+              arguments: JSON.stringify({ location: "sf" }),
+            },
+            type: "function",
+          },
+        ],
+      },
+      {
+        role: "tool",
+        content: "Pretty nice right now!",
+        tool_call_id: "call_123",
+      },
+    ].map(coerceMessageLikeToMessage);
+    expect(messages).toEqual([
+      new SystemMessage({
+        id: "foobar",
+        content: "You are an assistant.",
+      }),
+      new HumanMessage({
+        content: [{ type: "text", text: "What is the weather in SF?" }],
+      }),
+      new AIMessage({
+        content: "",
+        tool_calls: [
+          {
+            id: "call_123",
+            name: "get_weather",
+            args: { location: "sf" },
+            type: "tool_call",
+          },
+        ],
+      }),
+      new ToolMessage({
+        name: undefined,
+        content: "Pretty nice right now!",
+        tool_call_id: "call_123",
+      }),
+    ]);
+  });
+});
diff --git a/langchain-core/src/messages/utils.ts b/langchain-core/src/messages/utils.ts
index 4a34e03984ee..b9df034ee923 100644
--- a/langchain-core/src/messages/utils.ts
+++ b/langchain-core/src/messages/utils.ts
@@ -1,3 +1,4 @@
+import { _isToolCall } from "../tools/utils.js";
 import { AIMessage, AIMessageChunk, AIMessageChunkFields } from "./ai.js";
 import {
   BaseMessageLike,
@@ -6,6 +7,7 @@ import {
   StoredMessage,
   StoredMessageV1,
   BaseMessageFields,
+  _isMessageFieldWithRole,
 } from "./base.js";
 import {
   ChatMessage,
@@ -19,16 +21,53 @@ import {
 } from "./function.js";
 import { HumanMessage, HumanMessageChunk } from "./human.js";
 import { SystemMessage, SystemMessageChunk } from "./system.js";
-import { ToolMessage, ToolMessageFieldsWithToolCallId } from "./tool.js";
+import {
+  ToolCall,
+  ToolMessage,
+  ToolMessageFieldsWithToolCallId,
+} from "./tool.js";
+
+function _coerceToolCall(
+  toolCall: ToolCall | Record<string, unknown>
+): ToolCall {
+  if (_isToolCall(toolCall)) {
+    return toolCall;
+  } else if (
+    typeof toolCall.id === "string" &&
+    toolCall.type === "function" &&
+    typeof toolCall.function === "object" &&
+    toolCall.function !== null &&
+    "arguments" in toolCall.function &&
+    typeof toolCall.function.arguments === "string" &&
+    "name" in toolCall.function &&
+    typeof toolCall.function.name === "string"
+  ) {
+    // Handle OpenAI tool call format
+    return {
+      id: toolCall.id,
+      args: JSON.parse(toolCall.function.arguments),
+      name: toolCall.function.name,
+      type: "tool_call",
+    };
+  } else {
+    // TODO: Throw an error?
+    return toolCall as ToolCall;
+  }
+}
 
 function _constructMessageFromParams(
-  params: BaseMessageFields & { type: string }
+  params: BaseMessageFields & { type: string } & Record<string, unknown>
 ) {
   const { type, ...rest } = params;
   if (type === "human" || type === "user") {
     return new HumanMessage(rest);
   } else if (type === "ai" || type === "assistant") {
-    return new AIMessage(rest);
+    const { tool_calls: rawToolCalls, ...other } = rest;
+    if (!Array.isArray(rawToolCalls)) {
+      return new AIMessage(rest);
+    }
+    const tool_calls = rawToolCalls.map(_coerceToolCall);
+    return new AIMessage({ ...other, tool_calls });
   } else if (type === "system") {
     return new SystemMessage(rest);
   } else if (type === "tool" && "tool_call_id" in rest) {
@@ -56,6 +95,9 @@ export function coerceMessageLikeToMessage(
   if (Array.isArray(messageLike)) {
     const [type, content] = messageLike;
     return _constructMessageFromParams({ type, content });
+  } else if (_isMessageFieldWithRole(messageLike)) {
+    const { role: type, ...rest } = messageLike;
+    return _constructMessageFromParams({ ...rest, type });
   } else {
     return _constructMessageFromParams(messageLike);
   }

From 28f77d5165a642291acc73c099d098ba8a8aadfb Mon Sep 17 00:00:00 2001
From: jacoblee93
Date: Wed, 28 Aug 2024 09:37:31 -0700
Subject: [PATCH 2/3] Use string with autocomplete

---
 langchain-core/src/messages/base.ts | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/langchain-core/src/messages/base.ts b/langchain-core/src/messages/base.ts
index eafbb9fe1fb4..3cbf5646f8ae 100644
--- a/langchain-core/src/messages/base.ts
+++ b/langchain-core/src/messages/base.ts
@@ -441,7 +441,7 @@ export abstract class BaseMessageChunk extends BaseMessage {
 }
 
 export type MessageFieldWithRole = {
-  role: MessageType | "user" | "assistant" | string;
+  role: StringWithAutocomplete<"user" | "assistant" | MessageType>;
   content: MessageContent;
   name?: string;
 } & Record<string, unknown>;

From 54846969ff4afed7f58a73ef39a83f602be87877 Mon Sep 17 00:00:00 2001
From: jacoblee93
Date: Wed, 28 Aug 2024 13:47:19 -0700
Subject: [PATCH 3/3] Update test to be in line with Python

---
 .../src/messages/tests/base_message.test.ts | 32 ++++++++--------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/langchain-core/src/messages/tests/base_message.test.ts b/langchain-core/src/messages/tests/base_message.test.ts
index 21c1f81026f8..0e6883c89dc0 100644
--- a/langchain-core/src/messages/tests/base_message.test.ts
+++ b/langchain-core/src/messages/tests/base_message.test.ts
@@ -343,21 +343,21 @@ describe("Message like coercion", () => {
       {
         id: "foobar",
         role: "system",
-        content: "You are an assistant.",
+        content: "6",
       },
       {
         role: "user",
-        content: [{ type: "text", text: "What is the weather in SF?" }],
+        content: [{ type: "image_url", image_url: { url: "7.1" } }],
       },
       {
         role: "assistant",
-        content: "",
+        content: [{ type: "text", text: "8.1" }],
         tool_calls: [
           {
-            id: "call_123",
+            id: "8.5",
             function: {
-              name: "get_weather",
-              arguments: JSON.stringify({ location: "sf" }),
+              name: "8.4",
+              arguments: JSON.stringify({ "8.2": "8.3" }),
             },
             type: "function",
           },
@@ -365,33 +365,33 @@ describe("Message like coercion", () => {
       },
       {
         role: "tool",
-        content: "Pretty nice right now!",
-        tool_call_id: "call_123",
+        content: "10.2",
+        tool_call_id: "10.2",
       },
     ].map(coerceMessageLikeToMessage);
     expect(messages).toEqual([
       new SystemMessage({
         id: "foobar",
-        content: "You are an assistant.",
+        content: "6",
       }),
       new HumanMessage({
-        content: [{ type: "text", text: "What is the weather in SF?" }],
+        content: [{ type: "image_url", image_url: { url: "7.1" } }],
       }),
       new AIMessage({
-        content: "",
+        content: [{ type: "text", text: "8.1" }],
         tool_calls: [
           {
-            id: "call_123",
-            name: "get_weather",
-            args: { location: "sf" },
+            id: "8.5",
+            name: "8.4",
+            args: { "8.2": "8.3" },
             type: "tool_call",
           },
         ],
       }),
       new ToolMessage({
         name: undefined,
-        content: "Pretty nice right now!",
-        tool_call_id: "call_123",
+        content: "10.2",
+        tool_call_id: "10.2",
       }),
     ]);
   });
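
Below is a minimal usage sketch (not part of the patches themselves) of the coercion path these commits introduce. It assumes coerceMessageLikeToMessage is re-exported from the "@langchain/core/messages" entrypoint, as the tests' import from "../index.js" suggests, and the tool call id is a made-up placeholder.

// Usage sketch only; import path and id below are assumptions.
import {
  coerceMessageLikeToMessage,
  AIMessage,
} from "@langchain/core/messages";

// An OpenAI-format assistant turn carrying a function-style tool call.
const openAIStyleMessage = {
  role: "assistant",
  content: "",
  tool_calls: [
    {
      id: "call_abc", // hypothetical id
      type: "function",
      function: {
        name: "get_weather",
        arguments: JSON.stringify({ location: "sf" }),
      },
    },
  ],
};

// _isMessageFieldWithRole matches on the string `role`, so the object is
// routed through _constructMessageFromParams; _coerceToolCall then parses
// the OpenAI tool call into { id, name, args, type: "tool_call" }.
const coerced = coerceMessageLikeToMessage(openAIStyleMessage) as AIMessage;
console.log(coerced.tool_calls);

Because coerceMessageLikeToMessage backs chat model input handling, the same plain role/content objects can be passed straight to a model's invoke, which is what the new chat_models.test.ts cases above exercise.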