Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add OpenAI's new structured output API #180

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions lib/chat_models/chat_open_ai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do
# How many chat completion choices to generate for each input message.
field :n, :integer, default: 1
field :json_response, :boolean, default: false
field :json_schema, :map, default: nil
field :stream, :boolean, default: false
field :max_tokens, :integer, default: nil
# Options for streaming response. Only set this when you set `stream: true`
Expand Down Expand Up @@ -153,6 +154,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do
:stream,
:receive_timeout,
:json_response,
:json_schema,
:max_tokens,
:stream_options,
:user,
Expand Down Expand Up @@ -263,11 +265,20 @@ defmodule LangChain.ChatModels.ChatOpenAI do
%{"include_usage" => Map.get(data, :include_usage, Map.get(data, "include_usage"))}
end

defp set_response_format(%ChatOpenAI{json_response: true}),
do: %{"type" => "json_object"}
# JSON mode with a schema present: request OpenAI's structured-output
# format, forwarding the caller-supplied schema map unchanged.
defp set_response_format(%ChatOpenAI{json_response: true, json_schema: schema})
     when not is_nil(schema),
     do: %{"type" => "json_schema", "json_schema" => schema}

defp set_response_format(%ChatOpenAI{json_response: false}),
do: %{"type" => "text"}
# JSON mode without a schema: ask for a free-form JSON object.
defp set_response_format(%ChatOpenAI{json_response: true}),
  do: %{"type" => "json_object"}

# JSON mode off: plain text responses.
defp set_response_format(%ChatOpenAI{json_response: false}),
  do: %{"type" => "text"}

@doc """
Convert a LangChain structure to the expected map of data for the OpenAI API.
Expand Down Expand Up @@ -908,6 +919,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do
:seed,
:n,
:json_response,
:json_schema,
:stream,
:max_tokens,
:stream_options
Expand Down
30 changes: 16 additions & 14 deletions test/chains/data_extraction_chain_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule LangChain.Chains.DataExtractionChainTest do
FunctionParam.new!(%{name: "person_name", type: :string}),
FunctionParam.new!(%{name: "person_age", type: :number}),
FunctionParam.new!(%{name: "person_hair_color", type: :string}),
FunctionParam.new!(%{name: "dog_name", type: :string}),
FunctionParam.new!(%{name: "dog_breed", type: :string})
FunctionParam.new!(%{name: "pet_dog_name", type: :string}),
FunctionParam.new!(%{name: "pet_dog_breed", type: :string})
]
|> FunctionParam.to_parameters_schema()

Expand All @@ -31,8 +31,8 @@ defmodule LangChain.Chains.DataExtractionChainTest do
items: %{
"type" => "object",
"properties" => %{
"dog_breed" => %{"type" => "string"},
"dog_name" => %{"type" => "string"},
"pet_dog_breed" => %{"type" => "string"},
"pet_dog_name" => %{"type" => "string"},
"person_age" => %{"type" => "number"},
"person_hair_color" => %{"type" => "string"},
"person_name" => %{"type" => "string"}
Expand All @@ -55,32 +55,34 @@ defmodule LangChain.Chains.DataExtractionChainTest do
FunctionParam.new!(%{name: "person_name", type: :string}),
FunctionParam.new!(%{name: "person_age", type: :number}),
FunctionParam.new!(%{name: "person_hair_color", type: :string}),
FunctionParam.new!(%{name: "dog_name", type: :string}),
FunctionParam.new!(%{name: "dog_breed", type: :string})
FunctionParam.new!(%{name: "pet_dog_name", type: :string}),
FunctionParam.new!(%{name: "pet_dog_breed", type: :string})
]
|> FunctionParam.to_parameters_schema()

# Model setup - specify the model and seed
{:ok, chat} = ChatOpenAI.new(%{model: "gpt-4o", temperature: 0, seed: 0, stream: false})
{:ok, chat} = ChatOpenAI.new(%{model: "gpt-4o-mini-2024-07-18", temperature: 0, seed: 0, stream: false})

# run the chain, chain.run(prompt to extract data from)
data_prompt =
"Alex is 5 feet tall. Claudia is 4 feet taller than Alex and jumps higher than him.
Claudia is a brunette and Alex is blonde. Alex's dog Frosty is a labrador and likes to play hide and seek. Identify each person and their relevant information."
data_prompt = """
Alex is 5 feet tall. Claudia is 4 feet taller than Alex and jumps higher than him.
Claudia is a brunette and Alex is blonde.
Alex's dog Frosty is a labrador and likes to play hide and seek. Identify each person and their relevant information.
"""

{:ok, result} = DataExtractionChain.run(chat, schema_parameters, data_prompt, verbose: true)

assert result == [
%{
"dog_breed" => "labrador",
"dog_name" => "Frosty",
"pet_dog_breed" => "labrador",
"pet_dog_name" => "Frosty",
"person_age" => nil,
"person_hair_color" => "blonde",
"person_name" => "Alex"
},
%{
"dog_breed" => nil,
"dog_name" => nil,
"pet_dog_breed" => nil,
"pet_dog_name" => nil,
"person_age" => nil,
"person_hair_color" => "brunette",
"person_name" => "Claudia"
Expand Down
100 changes: 96 additions & 4 deletions test/chat_models/chat_open_ai_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ defmodule LangChain.ChatModels.ChatOpenAITest do
alias LangChain.Message.ToolCall
alias LangChain.Message.ToolResult

@test_model "gpt-3.5-turbo"
@test_model "gpt-4o-mini-2024-07-18"
@gpt4 "gpt-4-1106-preview"

defp hello_world(_args, _context) do
Expand Down Expand Up @@ -73,6 +73,25 @@ defmodule LangChain.ChatModels.ChatOpenAITest do

assert model.endpoint == override_url
end

test "supports setting json_response and json_schema" do
  # A plain map schema should be stored on the struct untouched.
  schema = %{
    "type" => "object",
    "properties" => %{
      "age" => %{"type" => "integer"},
      "name" => %{"type" => "string"}
    }
  }

  {:ok, chat} =
    ChatOpenAI.new(%{
      "model" => @test_model,
      "json_response" => true,
      "json_schema" => schema
    })

  assert chat.json_response == true
  assert chat.json_schema == schema
end
end

describe "for_api/3" do
Expand Down Expand Up @@ -108,6 +127,34 @@ defmodule LangChain.ChatModels.ChatOpenAITest do
assert data.response_format == %{"type" => "json_object"}
end

test "generates a map for an API call with JSON response and schema" do
  schema = %{
    "type" => "object",
    "properties" => %{
      "age" => %{"type" => "integer"},
      "name" => %{"type" => "string"}
    }
  }

  {:ok, chat} =
    ChatOpenAI.new(%{
      "model" => @test_model,
      "temperature" => 1,
      "frequency_penalty" => 0.5,
      "json_response" => true,
      "json_schema" => schema
    })

  api_data = ChatOpenAI.for_api(chat, [], [])

  # Scalar options pass straight through to the API payload.
  assert api_data.model == @test_model
  assert api_data.temperature == 1
  assert api_data.frequency_penalty == 0.5

  # With a schema set, the structured-output format is requested.
  assert api_data.response_format ==
           %{"type" => "json_schema", "json_schema" => schema}
end

test "generates a map for an API call with max_tokens set" do
{:ok, openai} =
ChatOpenAI.new(%{
Expand Down Expand Up @@ -419,7 +466,7 @@ defmodule LangChain.ChatModels.ChatOpenAITest do
"description" => nil,
"enum" => ["yellow", "red", "green"],
"type" => "string"
}
}
},
"required" => ["p1"]
}
Expand Down Expand Up @@ -789,7 +836,7 @@ defmodule LangChain.ChatModels.ChatOpenAITest do
@tag live_call: true, live_open_ai: true
test "handles when request is too large" do
{:ok, chat} =
ChatOpenAI.new(%{model: "gpt-3.5-turbo-0301", seed: 0, stream: false, temperature: 1})
ChatOpenAI.new(%{model: "gpt-4-0613", seed: 0, stream: false, temperature: 1})

{:error, reason} = ChatOpenAI.call(chat, [too_large_user_request()])
assert reason =~ "maximum context length"
Expand Down Expand Up @@ -1330,7 +1377,7 @@ defmodule LangChain.ChatModels.ChatOpenAITest do
@tag live_call: true, live_open_ai: true
test "supports multi-modal user message with image prompt" do
# https://platform.openai.com/docs/guides/vision
{:ok, chat} = ChatOpenAI.new(%{model: "gpt-4-vision-preview", seed: 0})
{:ok, chat} = ChatOpenAI.new(%{model: "gpt-4o-2024-08-06", seed: 0})

url =
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
Expand Down Expand Up @@ -1891,8 +1938,53 @@ defmodule LangChain.ChatModels.ChatOpenAITest do
"stream_options" => %{"include_usage" => true},
"temperature" => 0.0,
"version" => 1,
"json_schema" => nil,
"module" => "Elixir.LangChain.ChatModels.ChatOpenAI"
}
end
end

describe "set_response_format/1" do
  test "generates a map for an API call with text format when json_response is false" do
    {:ok, chat} = ChatOpenAI.new(%{model: @test_model, json_response: false})

    assert ChatOpenAI.for_api(chat, [], []).response_format == %{"type" => "text"}
  end

  test "generates a map for an API call with json_object format when json_response is true and no schema" do
    {:ok, chat} = ChatOpenAI.new(%{model: @test_model, json_response: true})

    assert ChatOpenAI.for_api(chat, [], []).response_format == %{"type" => "json_object"}
  end

  test "generates a map for an API call with json_schema format when json_response is true and schema is provided" do
    schema = %{
      "type" => "object",
      "properties" => %{
        "age" => %{"type" => "integer"},
        "name" => %{"type" => "string"}
      }
    }

    {:ok, chat} =
      ChatOpenAI.new(%{model: @test_model, json_response: true, json_schema: schema})

    # The caller-supplied schema map is embedded verbatim.
    assert ChatOpenAI.for_api(chat, [], []).response_format ==
             %{"type" => "json_schema", "json_schema" => schema}
  end
end
end
2 changes: 1 addition & 1 deletion test/message_delta_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ defmodule LangChain.MessageDeltaTest do
status: :incomplete,
type: :function,
call_id: "toolu_123",
name: "get_codeget_codeget_codeget_codeget_code",
name: "get_code",
arguments: "{\"code\": \"def my_function(x):\n return x + 1\"}",
index: 1
}
Expand Down
2 changes: 1 addition & 1 deletion test/support/fixtures.ex
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ defmodule LangChain.Fixtures do
end

def too_large_user_request() do
Message.new_user!("Analyze the following text: \n\n" <> text_chunks(8))
Message.new_user!("Analyze the following text: \n\n" <> text_chunks(16))
end

def results_in_too_long_response() do
Expand Down