Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: function calling in litellm #93

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions bolna/agent_manager/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,9 +987,9 @@ async def __filler_classification_task(self, message):
should_bypass_synth = 'bypass_synth' in meta_info and meta_info['bypass_synth'] == True
filler = random.choice((FILLER_DICT[filler_class]))
await self._handle_llm_output(next_step, filler, should_bypass_synth, new_meta_info, is_filler = True)

async def __execute_function_call(self, url, method, param, api_token, model_args, meta_info, next_step, called_fun, **resp):
self.check_if_user_online = False
async def __execute_function_call(self, tool_call_id, is_lite_llm, url, method, param, api_token, model_args, meta_info, next_step, called_fun, **resp):
self.check_if_user_online = False

if called_fun.startswith("transfer_call"):
logger.info(f"Transfer call function called param {param}. First sleeping for 2 seconds to make sure we're done speaking the filler")
Expand Down Expand Up @@ -1047,6 +1047,8 @@ async def __execute_function_call(self, url, method, param, api_token, model_arg

content = FUNCTION_CALL_PROMPT.format(called_fun, method, set_response_prompt)
model_args["messages"].append({"role": "system","content": content})
if(is_lite_llm):
model_args["messages"].append({"role": "tool", "tool_call_id": tool_call_id, "content": function_response})
logger.info(f"Logging function call parameters ")
convert_to_request_log(function_response, meta_info , None, "function_call", direction = "response", is_cached= False, run_id = self.run_id)

Expand Down Expand Up @@ -1105,11 +1107,11 @@ async def __do_llm_generation(self, messages, meta_info, next_step, should_bypas

async for llm_message in self.tools['llm_agent'].generate(messages, synthesize=synthesize, meta_info=meta_info):
logger.info(f"llm_message {llm_message}")
data, end_of_llm_stream, latency, trigger_function_call = llm_message
data, end_of_llm_stream, latency, trigger_function_call, tool_call_id, is_lite_llm = llm_message

if trigger_function_call:
logger.info(f"Triggering function call for {data}")
self.llm_task = asyncio.create_task(self.__execute_function_call(next_step = next_step, **data))
self.llm_task = asyncio.create_task(self.__execute_function_call(tool_call_id = tool_call_id, is_lite_llm = is_lite_llm, next_step = next_step, **data))
return

if latency and (len(self.llm_latencies) == 0 or self.llm_latencies[-1] != latency):
Expand Down
63 changes: 51 additions & 12 deletions bolna/helpers/function_calling_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,41 +7,80 @@

logger = configure_logger(__name__)

async def trigger_api(url, method, param, api_token, meta_info, run_id, **kwargs):
async def trigger_api(url, method, param, api_token, meta_info, run_id, header=None, **kwargs):
try:
request_body, api_params = None, None

# Replace placeholders in the URL dynamically
if "%(" in url and ")s" in url:
try:
url = url % kwargs
logger.info(f"Processed URL with dynamic parameters: {url}")
except KeyError as e:
message = f"ERROR: Missing URL parameter: {e}"
logger.error(message)
return message

# Handle request parameters dynamically
if param:
code = compile(param % kwargs, "<string>", "exec")
exec(code, globals(), kwargs)
request_body = param % kwargs
api_params = json.loads(request_body)
try:
code = compile(param % kwargs, "<string>", "exec")
exec(code, globals(), kwargs)
request_body = param % kwargs
api_params = json.loads(request_body)

logger.info(f"Params {param % kwargs} \n {type(request_body)} \n {param} \n {kwargs} \n\n {request_body}")
logger.info(f"Params {param % kwargs} \n {type(request_body)} \n {param} \n {kwargs} \n\n {request_body}")
except Exception as e:
logger.error(f"Error processing request parameters: {e}")
return f"ERROR: Invalid parameters: {e}"
else:
logger.info(f"Params {param} \n {type(request_body)} \n {param} \n {kwargs} \n\n {request_body}")

headers = {'Content-Type': 'application/json'}
# Default headers setup
if(header):
headers = json.loads(header)
else:
headers = {'Content-Type': 'application/json'}
if api_token:
headers['Authorization'] = api_token

convert_to_request_log(request_body, meta_info , None, "function_call", direction="request", is_cached=False, run_id=run_id)
convert_to_request_log(request_body, meta_info, None, "function_call", direction="request", is_cached=False, run_id=run_id)

logger.info(f"Sleeping for 700 ms to make sure that we do not send the same message multiple times")
logger.info("Sleeping for 700 ms to make sure that we do not send the same message multiple times")
await asyncio.sleep(0.7)

async with aiohttp.ClientSession() as session:
# Handle different HTTP methods
if method.lower() == "get":
logger.info(f"Sending request {request_body}, {url}, {headers}")
logger.info(f"Sending GET request: {url}, Params: {api_params}, Headers: {headers}")
async with session.get(url, params=api_params, headers=headers) as response:
response_text = await response.text()
logger.info(f"Response from the server: {response_text}")

elif method.lower() == "post":
logger.info(f"Sending request {api_params}, {url}, {headers}")
logger.info(f"Sending POST request: {url}, Data: {api_params}, Headers: {headers}")
async with session.post(url, json=api_params, headers=headers) as response:
response_text = await response.text()
logger.info(f"Response from the server: {response_text}")


elif method.lower() == "put":
logger.info(f"Sending PUT request: {url}, Data: {api_params}, Headers: {headers}")
async with session.put(url, json=api_params, headers=headers) as response:
response_text = await response.text()
logger.info(f"Response from the server: {response_text}")

elif method.lower() == "delete":
logger.info(f"Sending DELETE request: {url}, Data: {api_params}, Headers: {headers}")
async with session.delete(url, json=api_params, headers=headers) as response:
response_text = await response.text()
logger.info(f"Response from the server: {response_text}")

else:
logger.error(f"Unsupported HTTP method: {method}")
response_text = f"Unsupported HTTP method: {method}"

return response_text

except Exception as e:
message = f"ERROR CALLING API: There was an error calling the API: {e}"
logger.error(message)
Expand Down
137 changes: 121 additions & 16 deletions bolna/llms/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import litellm
from dotenv import load_dotenv
from .llm import BaseLLM
from bolna.constants import DEFAULT_LANGUAGE_CODE
from bolna.helpers.utils import json_to_pydantic_schema
from bolna.constants import DEFAULT_LANGUAGE_CODE, PRE_FUNCTION_CALL_MESSAGE, TRANSFERING_CALL_FILLER
from bolna.helpers.utils import json_to_pydantic_schema, convert_to_request_log
from bolna.helpers.logger_config import configure_logger
import time
import json

logger = configure_logger(__name__)
load_dotenv()
Expand All @@ -20,12 +21,32 @@ def __init__(self, model, max_tokens=30, buffer_size=40, temperature=0.0, langua
self.model = kwargs['azure_model']

self.started_streaming = False

self.language = language

# Function calling setup
self.custom_tools = kwargs.get("api_tools", None)
if self.custom_tools is not None:
self.trigger_function_call = True
self.api_params = self.custom_tools['tools_params']
logger.info(f"Function dict {self.api_params}")
# Convert tools to LiteLLM format
self.tools = [
{
"type": "function",
"function": tool
} for tool in self.custom_tools['tools']
]
else:
self.trigger_function_call = False

self.gave_out_prefunction_call_message = False
self.run_id = kwargs.get("run_id", None)

self.model_args = {"max_tokens": max_tokens, "temperature": temperature, "model": self.model}
self.api_key = kwargs.get("llm_key", os.getenv('LITELLM_MODEL_API_KEY'))
self.api_base = kwargs.get("base_url", os.getenv('LITELLM_MODEL_API_BASE'))
self.api_version = kwargs.get("api_version", os.getenv('LITELLM_MODEL_API_VERSION'))

if self.api_key:
self.model_args["api_key"] = self.api_key
if self.api_base:
Expand All @@ -41,43 +62,122 @@ def __init__(self, model, max_tokens=30, buffer_size=40, temperature=0.0, langua
if "api_version" in kwargs:
self.model_args["api_version"] = kwargs["api_version"]

async def generate_stream(self, messages, synthesize=True, request_json=False, meta_info=None):
    """Stream a LiteLLM chat completion, emitting text chunks and, when the
    model requests one, a function-call payload.

    Yields 6-tuples:
        (data, end_of_stream, latency, trigger_function_call, tool_call_id, is_lite_llm)
    where `data` is a text chunk to synthesize, or — when
    `trigger_function_call` is True — a dict describing the API call for the
    task manager to execute.

    Args:
        messages: chat history, passed straight through to litellm.acompletion.
        synthesize: when True, text is flushed in `self.buffer_size`-sized
            pieces so TTS can start early; when False, the full answer is
            yielded once at the end.
        request_json: accepted for signature parity with `generate`; unused here.
        meta_info: optional per-turn metadata; only 'turn_id' is read, to tag
            the request's `user` field.
    """
    answer, buffer, resp, called_fun = "", "", "", ""
    tool_idx = 0
    model_args = self.model_args.copy()
    model_args["messages"] = messages
    model_args["stream"] = True
    model_args["stop"] = ["User:"]
    # NOTE(review): assumes meta_info always carries 'turn_id' when provided — confirm with callers.
    model_args["user"] = f"{self.run_id}#{meta_info['turn_id']}" if meta_info else None

    tools = []
    if self.trigger_function_call:
        # self.tools may arrive either as a JSON string or an already-parsed list.
        tools = json.loads(self.tools) if type(self.tools) is str else self.tools
        model_args["tools"] = tools
        model_args["tool_choice"] = "auto"

    logger.info(f"request to model: {self.model}: {messages} and model args {model_args}")
    latency = False
    start_time = time.time()
    textual_response = False

    async for chunk in await litellm.acompletion(**model_args):
        if not self.started_streaming:
            first_chunk_time = time.time()
            latency = first_chunk_time - start_time  # time to first token
            logger.info(f"LLM Latency: {latency:.2f} s")
            self.started_streaming = True

        delta = chunk['choices'][0]['delta']

        if self.trigger_function_call and hasattr(delta, 'tool_calls') and delta.tool_calls:
            tool_call = delta.tool_calls[0]
            if hasattr(tool_call, 'function'):
                function_data = tool_call.function
                logger.info(f"function_data: {function_data}")

                # The function name arrives on the first tool-call delta;
                # later deltas stream only the JSON argument fragments.
                if hasattr(function_data, 'name') and function_data.name:
                    logger.info(f"Should do a function call {function_data.name}")
                    called_fun = str(function_data.name)
                    tool_idx = next(idx for idx, tool in enumerate(tools)
                                    if tool["function"]["name"] == called_fun)

                    if not self.gave_out_prefunction_call_message and not textual_response:
                        # Speak a filler while the external API call is being made.
                        filler = PRE_FUNCTION_CALL_MESSAGE if not called_fun.startswith("transfer_call") else TRANSFERING_CALL_FILLER.get(self.language, DEFAULT_LANGUAGE_CODE)
                        yield filler, True, latency, False, None, True
                        self.gave_out_prefunction_call_message = True
                        if len(buffer) > 0:
                            yield buffer, True, latency, False, None, True
                            buffer = ''
                        logger.info(f"Response from LLM {resp}")
                    if buffer != '':
                        # Flush any text produced before the model switched to a tool call.
                        yield buffer, False, latency, False, None, True
                        buffer = ''

                if hasattr(function_data, 'arguments') and function_data.arguments:
                    resp += function_data.arguments  # accumulate streamed JSON arguments

        elif hasattr(delta, 'content') and delta.content:
            text_chunk = delta.content
            textual_response = True
            answer += text_chunk
            buffer += text_chunk

            if len(buffer) >= self.buffer_size and synthesize:
                buffer_words = buffer.split(" ")
                text = ' '.join(buffer_words[:-1])
                if not self.started_streaming:
                    self.started_streaming = True
                yield text, False, latency, False, None, True
                buffer = buffer_words[-1]  # carry the (possibly partial) last word forward

    if self.trigger_function_call and called_fun and called_fun in self.api_params:
        func_dict = self.api_params[called_fun]
        logger.info(f"Payload to send {resp} func_dict {func_dict}")
        self.gave_out_prefunction_call_message = False

        url = func_dict['url']
        method = func_dict['method']
        param = func_dict['param']
        api_token = func_dict['api_token']
        header = func_dict.get('header') or None  # tolerate configs without a custom header
        api_call_return = {
            "url": url,
            "method": None if method is None else method.lower(),
            "param": param,
            "api_token": api_token,
            "header": header,
            "model_args": model_args,
            "meta_info": meta_info,
            "called_fun": called_fun,
        }

        tool_params = tools[tool_idx]["function"]["parameters"]
        # was: tool_params["properties"].keys() and tool_params.get("required", []) —
        # the `and` made the left operand a no-op; only the "required" list matters.
        all_required_keys = tool_params.get("required", [])

        # NOTE(review): this is a substring check on the raw JSON text, not a key
        # check on the parsed object — a required name appearing inside a value
        # would also satisfy it; confirm that is acceptable.
        if tool_params is not None and all(key in resp for key in all_required_keys):
            logger.info(f"Function call parameters: {resp}")
            convert_to_request_log(resp, meta_info, self.model, "llm", direction="response", is_cached=False, run_id=self.run_id)
            resp = json.loads(resp)
            api_call_return = {**api_call_return, **resp}
        else:
            api_call_return['resp'] = None
        logger.info(f"api call return: {api_call_return}")
        # NOTE(review): tool_call.id comes from the last streamed delta; providers
        # typically set the id only on the first delta — verify it is non-None downstream.
        yield api_call_return, False, latency, True, tool_call.id, True

    if synthesize:
        if buffer != "":
            yield buffer, True, latency, False, None, True
    else:
        yield answer, True, latency, False, None, True
    self.started_streaming = False
    logger.info(f"Time to generate response {time.time() - start_time} {answer}")

async def generate(self, messages, stream=False, request_json=False, meta_info = None):
async def generate(self, messages, stream=False, request_json=False, meta_info=None):
text = ""
model_args = self.model_args.copy()
model_args["model"] = self.model
Expand All @@ -89,6 +189,11 @@ async def generate(self, messages, stream=False, request_json=False, meta_info =
"type": "json_object",
"schema": json_to_pydantic_schema('{"classification_label": "classification label goes here"}')
}

if self.trigger_function_call:
model_args["tools"] = self.tools
model_args["tool_choice"] = "auto"

logger.info(f'Request to litellm {model_args}')
try:
completion = await litellm.acompletion(**model_args)
Expand Down
Loading