Skip to content

refactor: add fixtures for test loggers and update database schema #10546

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-litellm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ on:
jobs:
test:
runs-on: ubuntu-latest
timeout-minutes: 5
timeout-minutes: 10

steps:
- uses: actions/checkout@v4
Expand Down
7 changes: 7 additions & 0 deletions tests/litellm/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@

import pytest

# Register custom markers to avoid warnings
def pytest_configure(config):
"""
Register custom markers to avoid warnings.
"""
config.addinivalue_line("markers", "flaky: mark test as flaky, will be rerun if it fails")

sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
Expand Down
58 changes: 32 additions & 26 deletions tests/litellm/integrations/arize/test_arize_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,34 +178,40 @@ def test_arize_set_attributes():
span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 40)


class TestArizeLogger(CustomLogger):
@pytest.fixture
def arize_test_logger():
"""
Custom logger implementation to capture standard_callback_dynamic_params.
Used to verify that dynamic config keys are being passed to callbacks.
Fixture to provide a test logger for Arize tests.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.standard_callback_dynamic_params: Optional[
StandardCallbackDynamicParams
] = None

async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
# Capture dynamic params and print them for verification
print("logged kwargs", json.dumps(kwargs, indent=4, default=str))
self.standard_callback_dynamic_params = kwargs.get(
"standard_callback_dynamic_params"
)
class TestArizeLogger(CustomLogger):
"""
Custom logger implementation to capture standard_callback_dynamic_params.
Used to verify that dynamic config keys are being passed to callbacks.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.standard_callback_dynamic_params: Optional[
StandardCallbackDynamicParams
] = None

async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
# Capture dynamic params and print them for verification
print("logged kwargs", json.dumps(kwargs, indent=4, default=str))
self.standard_callback_dynamic_params = kwargs.get(
"standard_callback_dynamic_params"
)

return TestArizeLogger()


@pytest.mark.asyncio
async def test_arize_dynamic_params():
"""
Test to ensure that dynamic Arize keys (API key and space key)
are received inside the callback logger at runtime.
"""
test_arize_logger = TestArizeLogger()
litellm.callbacks = [test_arize_logger]
async def test_arize_dynamic_params(arize_test_logger):
"""Test to ensure that dynamic Arize keys (API key and space key)
are received inside the callback logger at runtime."""
# Use the fixture instead of creating a new instance
test_logger = arize_test_logger
litellm.callbacks = [test_logger]

# Perform a mocked async completion call to trigger logging
await litellm.acompletion(
Expand All @@ -220,12 +226,12 @@ async def test_arize_dynamic_params():
await asyncio.sleep(2)

# Assert dynamic parameters were received in the callback
assert test_arize_logger.standard_callback_dynamic_params is not None
assert test_logger.standard_callback_dynamic_params is not None
assert (
test_arize_logger.standard_callback_dynamic_params.get("arize_api_key")
test_logger.standard_callback_dynamic_params.get("arize_api_key")
== "test_api_key_dynamic"
)
assert (
test_arize_logger.standard_callback_dynamic_params.get("arize_space_key")
test_logger.standard_callback_dynamic_params.get("arize_space_key")
== "test_space_key_dynamic"
)
76 changes: 39 additions & 37 deletions tests/litellm/integrations/test_custom_prompt_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,40 +24,44 @@ def setup_anthropic_api_key(monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-some-key")


class TestCustomPromptManagement(CustomPromptManagement):
def get_chat_completion_prompt(
self,
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[str, List[AllMessageValues], dict]:
print(
"TestCustomPromptManagement: running get_chat_completion_prompt for prompt_id: ",
prompt_id,
)
if prompt_id == "test_prompt_id":
messages = [
{"role": "user", "content": "This is the prompt for test_prompt_id"},
]
return model, messages, non_default_params
elif prompt_id == "prompt_with_variables":
content = "Hello, {name}! You are {age} years old and live in {city}."
content_with_variables = content.format(**(prompt_variables or {}))
messages = [
{"role": "user", "content": content_with_variables},
]
return model, messages, non_default_params
else:
return model, messages, non_default_params
@pytest.fixture
def custom_prompt_manager():
class TestCustomPromptManager(CustomPromptManagement):
def get_chat_completion_prompt(
self,
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[str, List[AllMessageValues], dict]:
print(
"TestCustomPromptManagement: running get_chat_completion_prompt for prompt_id: ",
prompt_id,
)
if prompt_id == "test_prompt_id":
messages = [
{"role": "user", "content": "This is the prompt for test_prompt_id"},
]
return model, messages, non_default_params
elif prompt_id == "prompt_with_variables":
content = "Hello, {name}! You are {age} years old and live in {city}."
content_with_variables = content.format(**(prompt_variables or {}))
messages = [
{"role": "user", "content": content_with_variables},
]
return model, messages, non_default_params
else:
return model, messages, non_default_params

return TestCustomPromptManager()


@pytest.mark.asyncio
async def test_custom_prompt_management_with_prompt_id(monkeypatch):
custom_prompt_management = TestCustomPromptManagement()
litellm.callbacks = [custom_prompt_management]
async def test_custom_prompt_management_with_prompt_id(monkeypatch, custom_prompt_manager):
# Use the fixture instead of instantiating directly
litellm.callbacks = [custom_prompt_manager]

# Mock AsyncHTTPHandler.post method
client = AsyncHTTPHandler()
Expand All @@ -83,9 +87,8 @@ async def test_custom_prompt_management_with_prompt_id(monkeypatch):


@pytest.mark.asyncio
async def test_custom_prompt_management_with_prompt_id_and_prompt_variables():
custom_prompt_management = TestCustomPromptManagement()
litellm.callbacks = [custom_prompt_management]
async def test_custom_prompt_management_with_prompt_id_and_prompt_variables(custom_prompt_manager):
litellm.callbacks = [custom_prompt_manager]

# Mock AsyncHTTPHandler.post method
client = AsyncHTTPHandler()
Expand All @@ -112,9 +115,8 @@ async def test_custom_prompt_management_with_prompt_id_and_prompt_variables():


@pytest.mark.asyncio
async def test_custom_prompt_management_without_prompt_id():
custom_prompt_management = TestCustomPromptManagement()
litellm.callbacks = [custom_prompt_management]
async def test_custom_prompt_management_without_prompt_id(custom_prompt_manager):
litellm.callbacks = [custom_prompt_manager]

# Mock AsyncHTTPHandler.post method
client = AsyncHTTPHandler()
Expand Down
Loading