feat(openai): Properly pass through max_completion_tokens (#7683)
jacoblee93 authored Feb 11, 2025
1 parent: 037dd91 · commit: 92edca1
Showing 3 changed files with 39 additions and 3 deletions.
libs/langchain-openai/src/chat_models.ts (12 additions & 3 deletions)
```diff
@@ -153,7 +153,7 @@ export function _convertMessagesToOpenAIParams(
   // TODO: Function messages do not support array content, fix cast
   return messages.flatMap((message) => {
     let role = messageToOpenAIRole(message);
-    if (role === "system" && model?.startsWith("o1")) {
+    if (role === "system" && isReasoningModel(model)) {
       role = "developer";
     }
     // eslint-disable-next-line @typescript-eslint/no-explicit-any
```
```diff
@@ -224,6 +224,10 @@ function _convertChatOpenAIToolTypeToOpenAITool(
   return _convertToOpenAITool(tool, fields);
 }
 
+function isReasoningModel(model?: string) {
+  return model?.startsWith("o1") || model?.startsWith("o3");
+}
+
 // TODO: Use the base structured output options param in next breaking release.
 export interface ChatOpenAIStructuredOutputMethodOptions<
   IncludeRaw extends boolean
```
```diff
@@ -1027,7 +1031,6 @@ export class ChatOpenAI<
     this.topP = fields?.topP ?? this.topP;
     this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty;
     this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty;
-    this.maxTokens = fields?.maxTokens;
     this.logprobs = fields?.logprobs;
     this.topLogprobs = fields?.topLogprobs;
     this.n = fields?.n ?? this.n;
```
```diff
@@ -1039,6 +1042,7 @@
     this.audio = fields?.audio;
     this.modalities = fields?.modalities;
     this.reasoningEffort = fields?.reasoningEffort;
+    this.maxTokens = fields?.maxCompletionTokens ?? fields?.maxTokens;
 
     if (this.model === "o1") {
       this.disableStreaming = true;
```
```diff
@@ -1151,7 +1155,6 @@
       top_p: this.topP,
       frequency_penalty: this.frequencyPenalty,
       presence_penalty: this.presencePenalty,
-      max_tokens: this.maxTokens === -1 ? undefined : this.maxTokens,
       logprobs: this.logprobs,
       top_logprobs: this.topLogprobs,
       n: this.n,
```
```diff
@@ -1187,6 +1190,12 @@
     if (reasoningEffort !== undefined) {
       params.reasoning_effort = reasoningEffort;
     }
+    if (isReasoningModel(params.model)) {
+      params.max_completion_tokens =
+        this.maxTokens === -1 ? undefined : this.maxTokens;
+    } else {
+      params.max_tokens = this.maxTokens === -1 ? undefined : this.maxTokens;
+    }
     return params;
   }
```

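Taken together, the chat_models.ts changes route the configured token cap to whichever request field the target model accepts: o1/o3-family ("reasoning") models do not accept `max_tokens` and get `max_completion_tokens` instead, while all other models keep `max_tokens`. A minimal sketch of the resulting behavior, assuming `invocationParams()` remains publicly callable and with model names chosen purely for illustration:

```ts
import { ChatOpenAI } from "@langchain/openai";

// Reasoning models (o1/o3 prefixes, per isReasoningModel) get the cap as
// `max_completion_tokens`; system messages are likewise remapped to the
// "developer" role for these models.
const reasoning = new ChatOpenAI({
  model: "o3-mini",
  maxCompletionTokens: 256,
});
console.log(reasoning.invocationParams());
// expected to include: max_completion_tokens: 256 (and no max_tokens)

// Non-reasoning models keep the classic `max_tokens` request field.
const standard = new ChatOpenAI({
  model: "gpt-4o-mini",
  maxTokens: 256,
});
console.log(standard.invocationParams());
// expected to include: max_tokens: 256
```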
libs/langchain-openai/src/tests/chat_models.int.test.ts (20 additions & 0 deletions)
```diff
@@ -1233,6 +1233,26 @@ test("Allows developer messages with o1", async () => {
   expect(res.content).toEqual("testing");
 });
 
+test("Works with maxCompletionTokens with o3", async () => {
+  const model = new ChatOpenAI({
+    model: "o3-mini",
+    reasoningEffort: "low",
+    maxCompletionTokens: 100,
+  });
+  const res = await model.invoke([
+    {
+      role: "system",
+      content: `Always respond only with the word "testing"`,
+    },
+    {
+      role: "user",
+      content: "hi",
+    },
+  ]);
+  console.log(res);
+  expect(res.content).toEqual("testing");
+});
+
 test.skip("Allow overriding", async () => {
   class ChatDeepSeek extends ChatOpenAI {
     protected override _convertOpenAIDeltaToBaseMessageChunk(
```
libs/langchain-openai/src/types.ts (7 additions & 0 deletions)
```diff
@@ -23,6 +23,13 @@ export declare interface OpenAIBaseInput {
    */
   maxTokens?: number;
 
+  /**
+   * Maximum number of tokens to generate in the completion. -1 returns as many
+   * tokens as possible given the prompt and the model's maximum context size.
+   * Alias for `maxTokens` for reasoning models.
+   */
+  maxCompletionTokens?: number;
+
   /** Total probability mass of tokens to consider at each step */
   topP: number;
```

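Because the constructor resolves `fields?.maxCompletionTokens ?? fields?.maxTokens`, the new option takes precedence when both are supplied. A small sketch of that precedence, assuming `maxTokens` remains a readable public field on the instance (values and model name are illustrative):

```ts
import { ChatOpenAI } from "@langchain/openai";

// `maxCompletionTokens` wins over `maxTokens` via the `??` fallback in the
// constructor, so the effective cap here is 100, not 512.
const model = new ChatOpenAI({
  model: "o3-mini",
  maxTokens: 512,
  maxCompletionTokens: 100,
});
console.log(model.maxTokens); // expected: 100
```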
