Skip to content

Commit

Permalink
Use messages API and system prompt
Browse files Browse the repository at this point in the history
Thanks to @9j7axvsLuF for some sample code in the original issue tomviner#6
  • Loading branch information
bderenzi authored Feb 7, 2024
1 parent b80e6fe commit 9fd017e
Showing 1 changed file with 45 additions and 35 deletions.
80 changes: 45 additions & 35 deletions llm_claude/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,66 +2,76 @@

import click
import llm
from anthropic import AI_PROMPT, HUMAN_PROMPT, Anthropic
from anthropic import Anthropic
from pydantic import Field, field_validator


@llm.hookimpl
def register_models(register):
    """Register the current Claude models with the llm plugin system.

    Model naming reference:
    https://docs.anthropic.com/claude/reference/selecting-a-model
    """
    # Register only the latest full versions; the aliases give users the
    # short stable names ("claude-instant" / "claude").  Registering the
    # superseded IDs (claude-instant-1, claude-2) under the same aliases
    # would conflict, so they are intentionally dropped.
    register(Claude("claude-instant-1.2"), aliases=("claude-instant",))
    register(Claude("claude-2.1"), aliases=("claude",))

class Claude(llm.Model):
    """llm plugin model backed by Anthropic's Messages API."""

    needs_key = "claude"
    key_env_var = "ANTHROPIC_API_KEY"
    can_stream = True

    class Options(llm.Options):
        # claude-2.1 accepts at most 4096 output tokens, so that is both
        # the default and the validated upper bound.
        max_tokens: Optional[int] = Field(
            description="The maximum number of tokens for the model to generate",
            default=4096,
        )

        @field_validator("max_tokens")
        def validate_max_tokens(cls, max_tokens):
            """Reject values outside the 1-4096 range supported by claude-2.1."""
            if not (0 < max_tokens <= 4096):
                raise ValueError("max_tokens must be in range 1-4096 for claude-2.1")
            return max_tokens

    def __init__(self, model_id):
        # model_id is the full Anthropic model name, e.g. "claude-2.1".
        self.model_id = model_id

    def generate_prompt_messages(self, prompt, conversation):
        """Build the Messages-API payload from the conversation history.

        Returns a ``(messages, system)`` tuple: ``messages`` is a list of
        ``{"role": ..., "content": ...}`` dicts alternating user/assistant
        turns and ending with the new user prompt; ``system`` is the most
        recently seen system prompt, or ``None``.  Only the latest system
        prompt is returned — earlier, different system prompts in the
        history are tracked but not re-sent.
        """
        messages = []
        current_system = None
        if conversation is not None:
            for prev_response in conversation.responses:
                if (
                    prev_response.prompt.system
                    and prev_response.prompt.system != current_system
                ):
                    current_system = prev_response.prompt.system
                messages.append(
                    {"role": "user", "content": prev_response.prompt.prompt}
                )
                messages.append(
                    {"role": "assistant", "content": prev_response.text()}
                )
        if prompt.system and prompt.system != current_system:
            current_system = prompt.system
        messages.append({"role": "user", "content": prompt.prompt})
        return messages, current_system

    def execute(self, prompt, stream, response, conversation):
        """Send the prompt to Anthropic, yielding response text chunks.

        Yields one text delta per chunk when streaming, or a single joined
        string for a non-streaming call.
        """
        anthropic = Anthropic(api_key=self.get_key())

        messages, system_prompt = self.generate_prompt_messages(prompt, conversation)

        # NOTE(review): system_prompt may be None here; the SDK's ``system``
        # argument expects a string when supplied — confirm None is accepted,
        # otherwise omit the kwarg when no system prompt is set.
        if stream:
            # Streaming path: yield text deltas as they arrive.
            with anthropic.beta.messages.stream(
                max_tokens=prompt.options.max_tokens,
                messages=messages,
                system=system_prompt,
                model=self.model_id,
            ) as stream_response:
                for text in stream_response.text_stream:
                    yield text
        else:
            message_response = anthropic.beta.messages.create(
                model=self.model_id,
                max_tokens=prompt.options.max_tokens,
                messages=messages,
                system=system_prompt,
            )
            # The SDK returns content blocks as objects, not dicts, so use
            # attribute access — subscripting (``content_block['text']``)
            # raises TypeError on a pydantic model.
            yield "".join(
                content_block.text for content_block in message_response.content
            )

    def __str__(self):
        return "Anthropic: {}".format(self.model_id)

0 comments on commit 9fd017e

Please sign in to comment.