diff --git a/lib/ruby_llm/provider.rb b/lib/ruby_llm/provider.rb index 5540e907..67e40a6b 100644 --- a/lib/ruby_llm/provider.rb +++ b/lib/ruby_llm/provider.rb @@ -7,7 +7,8 @@ module RubyLLM module Provider # Common functionality for all LLM providers. Implements the core provider # interface so specific providers only need to implement a few key methods. - module Methods # rubocop:disable Metrics/ModuleLength + # rubocop:disable Metrics/ModuleLength + module Methods extend Streaming def complete(messages, tools:, temperature:, model:, &block) # rubocop:disable Metrics/MethodLength @@ -40,14 +41,14 @@ def list_models def embed(text, model:) payload = render_embedding_payload text, model: model - response = post embedding_url, payload + response = post embedding_url, payload, model_id: model parse_embedding_response response end def paint(prompt, model:, size:) payload = render_image_payload(prompt, model:, size:) - response = post(images_url, payload) + response = post(images_url, payload, model_id: model) parse_image_response(response) end @@ -78,13 +79,21 @@ def ensure_configured! end def sync_response(payload) - response = post completion_url, payload + model_id = payload[:model] + response = post completion_url, payload, model_id: model_id parse_completion_response response end - def post(url, payload) + def post(url, payload, model_id: nil) + request_headers = headers + + if model_id && capabilities.respond_to?(:additional_headers_for_model) + additional_headers = capabilities.additional_headers_for_model(model_id) + request_headers = request_headers.merge(additional_headers) unless additional_headers.empty? + end + connection.post url, payload do |req| - req.headers.merge! headers + req.headers.merge! request_headers yield req if block_given? end end @@ -190,5 +199,6 @@ def configured_providers providers.select { |_name, provider| provider.configured? }.values end end + # rubocop:enable Metrics/ModuleLength end end diff --git a/lib/ruby_llm/providers/anthropic/capabilities.rb b/lib/ruby_llm/providers/anthropic/capabilities.rb index 4e07afec..a835b4ad 100644 --- a/lib/ruby_llm/providers/anthropic/capabilities.rb +++ b/lib/ruby_llm/providers/anthropic/capabilities.rb @@ -68,6 +68,18 @@ def supports_extended_thinking?(model_id) model_id.match?(/claude-3-7-sonnet/) end + # Returns additional request headers for a specific model + # @param model_id [String] the model identifier + # @return [Hash] additional headers to include in the request + def additional_headers_for_model(model_id) + case model_id + when 'claude-3-7-sonnet-20250219' + { 'anthropic-beta' => 'output-128k-2025-02-19' } + else + {} + end + end + # Determines the model family for a given model ID # @param model_id [String] the model identifier # @return [Symbol] the model family identifier diff --git a/lib/ruby_llm/providers/bedrock.rb b/lib/ruby_llm/providers/bedrock.rb index 7db44a9a..aff0bdcc 100644 --- a/lib/ruby_llm/providers/bedrock.rb +++ b/lib/ruby_llm/providers/bedrock.rb @@ -24,7 +24,7 @@ def api_base @api_base ||= "https://bedrock-runtime.#{RubyLLM.config.bedrock_region}.amazonaws.com" end - def post(url, payload) + def post(url, payload, model_id: nil) # rubocop:disable Lint/UnusedMethodArgument signature = sign_request("#{connection.url_prefix}#{url}", payload:) connection.post url, payload do |req| req.headers.merge! build_headers(signature.headers, streaming: block_given?) diff --git a/lib/ruby_llm/streaming.rb b/lib/ruby_llm/streaming.rb index 442aa014..43cf7e44 100644 --- a/lib/ruby_llm/streaming.rb +++ b/lib/ruby_llm/streaming.rb @@ -10,8 +10,9 @@ module Streaming def stream_response(payload, &block) accumulator = StreamAccumulator.new + model_id = payload[:model] - post stream_url, payload do |req| + post stream_url, payload, model_id: model_id do |req| req.options.on_data = handle_stream do |chunk| accumulator.add chunk block.call chunk diff --git a/spec/ruby_llm/providers/anthropic/capabilities_spec.rb b/spec/ruby_llm/providers/anthropic/capabilities_spec.rb new file mode 100644 index 00000000..c5d239b6 --- /dev/null +++ b/spec/ruby_llm/providers/anthropic/capabilities_spec.rb @@ -0,0 +1,21 @@ +# frozen_string_literal: true + +require 'spec_helper' + +RSpec.describe RubyLLM::Providers::Anthropic::Capabilities do + describe '.additional_headers_for_model' do + it 'returns the beta header for claude-3-7-sonnet-20250219' do + result = described_class.additional_headers_for_model('claude-3-7-sonnet-20250219') + expect(result).to eq('anthropic-beta' => 'output-128k-2025-02-19') + end + + it 'returns an empty hash for other models' do + other_models = ['claude-3-5-sonnet-20241022', 'claude-3-haiku', 'claude-2'] + + other_models.each do |model| + result = described_class.additional_headers_for_model(model) + expect(result).to eq({}) + end + end + end +end diff --git a/spec/ruby_llm/providers/anthropic/headers_spec.rb b/spec/ruby_llm/providers/anthropic/headers_spec.rb new file mode 100644 index 00000000..13366178 --- /dev/null +++ b/spec/ruby_llm/providers/anthropic/headers_spec.rb @@ -0,0 +1,116 @@ +# frozen_string_literal: true + +require 'spec_helper' + +RSpec.describe 'Anthropic API Request Headers' do # rubocop:disable RSpec/DescribeClass + include_context 'with configured RubyLLM' + + before do + WebMock.disable_net_connect!(allow_localhost: true) + end + + after do + WebMock.allow_net_connect! + end + + it 'includes the beta header for claude-3-7-sonnet-20250219' do # rubocop:disable RSpec/ExampleLength + # Setup the expected request with the beta header + stub_request(:post, 'https://api.anthropic.com/v1/messages') + .with( + headers: { + 'Anthropic-Beta' => 'output-128k-2025-02-19' + } + ) + .to_return( + status: 200, + body: { + id: 'msg_123', + model: 'claude-3-7-sonnet-20250219', + type: 'message', + role: 'assistant', + content: [{ type: 'text', text: 'Hello!' }], + usage: { input_tokens: 10, output_tokens: 20 } + }.to_json, + headers: { 'Content-Type' => 'application/json' } + ) + + # Make a request with the specific model + chat = RubyLLM.chat(model: 'claude-3-7-sonnet-20250219') + chat.ask('Hello') + + # Verify that the request was made with the expected headers + expect( + a_request(:post, 'https://api.anthropic.com/v1/messages') + .with(headers: { 'Anthropic-Beta' => 'output-128k-2025-02-19' }) + ).to have_been_made + end + + it 'does not include the beta header for other Claude models' do # rubocop:disable RSpec/ExampleLength + # Setup the expected request without the beta header + stub_request(:post, 'https://api.anthropic.com/v1/messages') + .with { |request| !request.headers.key?('Anthropic-Beta') } + .to_return( + status: 200, + body: { + id: 'msg_456', + model: 'claude-3-5-sonnet-20241022', + type: 'message', + role: 'assistant', + content: [{ type: 'text', text: 'Hello!' }], + usage: { input_tokens: 10, output_tokens: 20 } + }.to_json, + headers: { 'Content-Type' => 'application/json' } + ) + + # Make a request with a different model + chat = RubyLLM.chat(model: 'claude-3-5-sonnet-20241022') + chat.ask('Hello') + + # Verify that the request was made without the beta header + expect( + a_request(:post, 'https://api.anthropic.com/v1/messages') + .with { |request| !request.headers.key?('Anthropic-Beta') } + ).to have_been_made + end + + it 'includes the beta header in streaming responses for claude-3-7-sonnet-20250219' do # rubocop:disable RSpec/ExampleLength + streaming_body = <<~STREAM_DATA + event: content_block_delta + data: {"type":"content_block_delta","delta":{"type":"text","text":"Hello"}} + + event: content_block_delta + data: {"type":"content_block_delta","delta":{"type":"text","text":"!"}} + + event: message_stop + data: {} + STREAM_DATA + + # Setup the expected streaming request with the beta header + stub_request(:post, 'https://api.anthropic.com/v1/messages') + .with( + headers: { + 'Anthropic-Beta' => 'output-128k-2025-02-19' + }, + body: hash_including('stream' => true) + ) + .to_return( + status: 200, + body: streaming_body, + headers: { 'Content-Type' => 'text/event-stream' } + ) + + # Make a streaming request with the specific model + chat = RubyLLM.chat(model: 'claude-3-7-sonnet-20250219') + chunks = [] + chat.ask('Hello') { |chunk| chunks << chunk } + + # Verify that the streaming request was made with the expected headers + expect( + a_request(:post, 'https://api.anthropic.com/v1/messages') + .with( + headers: { 'Anthropic-Beta' => 'output-128k-2025-02-19' }, + body: hash_including('stream' => true) + ) + ).to have_been_made + end +end