diff --git a/adala/runtimes/_litellm.py b/adala/runtimes/_litellm.py index c5975f16..535861c7 100644 --- a/adala/runtimes/_litellm.py +++ b/adala/runtimes/_litellm.py @@ -134,6 +134,7 @@ class LiteLLMChatRuntime(Runtime): with the provider of your specified model. base_url (Optional[str]): Base URL, optional. If provided, will be used to talk to an OpenAI-compatible API provider besides OpenAI. api_version (Optional[str]): API version, optional except for Azure. + extra_headers (Optional[Dict[str, str]]): Extra headers to be sent with the request. timeout: Timeout in seconds. """ @@ -303,6 +304,7 @@ class AsyncLiteLLMChatRuntime(AsyncRuntime): with the provider of your specified model. base_url (Optional[str]): Base URL, optional. If provided, will be used to talk to an OpenAI-compatible API provider besides OpenAI. api_version (Optional[str]): API version, optional except for Azure. + extra_headers (Optional[Dict[str, str]]): Extra headers to be sent with the request. timeout: Timeout in seconds. """ diff --git a/tests/manual_test_scripts/auth_proxy_server.py b/tests/manual_test_scripts/auth_proxy_server.py new file mode 100644 index 00000000..efec5414 --- /dev/null +++ b/tests/manual_test_scripts/auth_proxy_server.py @@ -0,0 +1,75 @@ +import os +import httpx +from loguru import logger +from fastapi import FastAPI, Request, HTTPException +from fastapi.responses import StreamingResponse + +""" +This script is a simple HTTP proxy server that forwards requests to a target URL. +It requires the TARGET_URL environment variable to be set to the target URL. +It also requires the EXPECTED_HEADER environment variable to be set to the expected Authorization header value. 
+ +To install the dependencies, run the following command: +``` +pip install fastapi httpx loguru +``` + +To run the server: +``` +TARGET_URL=https://example.com EXPECTED_HEADER=secret uvicorn auth_proxy_server:app +``` + +This will forward all requests to `https://example.com` and check for the `Authorization` header to be equal to `secret`. +""" +app = FastAPI() + +TARGET_URL = os.getenv('TARGET_URL') +EXPECTED_HEADER = os.getenv('EXPECTED_HEADER') + + +async def proxy_request(request: Request): + # Check for authentication header + auth_header = request.headers.get("Authorization") + if not auth_header: + raise HTTPException(status_code=401, detail="Authorization header missing") + if EXPECTED_HEADER and auth_header != EXPECTED_HEADER: + raise HTTPException(status_code=403, detail=f"Invalid Authorization header." + f" Provided: {auth_header}. Required: {EXPECTED_HEADER}") + + # Prepare the URL for the proxied request + path = request.url.path + if request.url.query: + path += f"?{request.url.query}" + url = f"{TARGET_URL}{path}" + + # Prepare headers + headers = dict(request.headers) + headers["host"] = TARGET_URL.split("://")[1] + + logger.info(f"Forwarding request to {url}, headers: {headers}") + + # Create httpx client + async with httpx.AsyncClient(timeout=60) as client: + # Forward the request + response = await client.request( + method=request.method, + url=url, + headers=headers, + content=await request.body() + ) + + # Stream the response back to the client + return StreamingResponse( + response.aiter_bytes(), + status_code=response.status_code, + headers=response.headers + ) + + +@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH", "TRACE"]) +async def catch_all(request: Request, path: str): + return await proxy_request(request) + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8010) \ No newline at end of file diff --git a/tests/test_agent_custom_model.py 
import pytest
import pandas as pd
import responses
import asyncio


@responses.activate
def test_agent_with_custom_base_url():
    """Smoke-test an Agent whose runtime points at a custom OpenAI-compatible
    base_url (e.g. a local Ollama server) and carries an auth_token.

    NOTE(review): ``@responses.activate`` mocks the ``requests`` library only,
    and no mock responses are registered here; the async litellm runtime does
    not go through ``requests``, so this test actually talks to the endpoint
    at ``base_url`` — it is a manual/integration test, not a unit test.
    TODO: confirm and either register litellm-level mocks or mark it as such.
    """
    from adala.agents import Agent  # type: ignore

    agent_json = {
        "skills": [
            {
                "type": "ClassificationSkill",
                "name": "ClassificationResult",
                "instructions": "",
                "input_template": "Classify sentiment of the input text: {input}",
                "field_schema": {
                    "output": {
                        "type": "string",
                        "enum": ["positive", "negative", "neutral"],
                    }
                },
            }
        ],
        "runtimes": {
            "default": {
                "type": "AsyncLiteLLMChatRuntime",
                "api_version": "v1",
                "max_tokens": 4096,
                "model": "openai/llama3.1",
                "temperature": 0,
                "batch_size": 100,
                "timeout": 120,
                "verbose": False,
                "base_url": "http://localhost:11434/v1/",
                "api_key": "ollama",
                "auth_token": "SECRET-TEST-TOKEN",
            }
        },
    }
    agent = Agent(**agent_json)

    df = pd.DataFrame(
        [["I'm happy"], ["I'm sad"], ["I'm neutral"]], columns=["input"]
    )

    results = asyncio.run(agent.arun(input=df))
    # BUG FIX: the original only printed results, so the test could never
    # fail on wrong output — assert one row of output per row of input.
    assert len(results) == len(df)