Skip to content

feat: add preprocessing functions for step ingestion #169

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 1 commit into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions literalai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import os
from contextlib import redirect_stdout
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

from traceloop.sdk import Traceloop
from typing_extensions import deprecated
Expand All @@ -29,6 +29,7 @@
MessageStepType,
Step,
StepContextManager,
StepDict,
TrueStepType,
step_decorator,
)
Expand Down Expand Up @@ -94,15 +95,55 @@ def __init__(

def to_sync(self) -> "LiteralClient":
if isinstance(self.api, AsyncLiteralAPI):
return LiteralClient(
sync_client = LiteralClient(
batch_size=self.event_processor.batch_size,
api_key=self.api.api_key,
url=self.api.url,
disabled=self.disabled,
)
if self.event_processor.preprocess_steps_function:
sync_client.event_processor.set_preprocess_steps_function(
self.event_processor.preprocess_steps_function
)

return sync_client
else:
return self # type: ignore

def set_preprocess_steps_function(
self,
preprocess_steps_function: Optional[
Callable[[List["StepDict"]], List["StepDict"]]
],
) -> None:
"""
Set a function that will preprocess steps before sending them to the API.
This can be used for tasks like PII removal or other data transformations.

The preprocess function should:
- Accept a list of StepDict objects as input
- Return a list of modified StepDict objects
- Be thread-safe as it will be called from a background thread
- Handle exceptions internally to avoid disrupting the event processing

Example:
```python
def remove_pii(steps):
# Process steps to remove PII data
for step in steps:
if step.get("content"):
step["content"] = my_pii_removal_function(step["content"])
return steps

client.set_preprocess_steps_function(remove_pii)
```

Args:
preprocess_steps_function (Callable[[List["StepDict"]], List["StepDict"]]):
Function that takes a list of steps and returns a processed list
"""
self.event_processor.set_preprocess_steps_function(preprocess_steps_function)

@deprecated("Use Literal.initialize instead")
def instrument_openai(self):
"""
Expand Down
47 changes: 45 additions & 2 deletions literalai/event_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import threading
import time
import traceback
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, Callable, List, Optional

logger = logging.getLogger(__name__)

Expand All @@ -31,7 +31,15 @@ class EventProcessor:
batch: List["StepDict"]
batch_timeout: float = 5.0

def __init__(self, api: "LiteralAPI", batch_size: int = 1, disabled: bool = False):
def __init__(
self,
api: "LiteralAPI",
batch_size: int = 1,
disabled: bool = False,
preprocess_steps_function: Optional[
Callable[[List["StepDict"]], List["StepDict"]]
] = None,
):
self.stop_event = threading.Event()
self.batch_size = batch_size
self.api = api
Expand All @@ -40,6 +48,7 @@ def __init__(self, api: "LiteralAPI", batch_size: int = 1, disabled: bool = Fals
self.processing_counter = 0
self.counter_lock = threading.Lock()
self.last_batch_time = time.time()
self.preprocess_steps_function = preprocess_steps_function
self.processing_thread = threading.Thread(
target=self._process_events, daemon=True
)
Expand All @@ -56,6 +65,22 @@ async def a_add_events(self, event: "StepDict"):
self.processing_counter += 1
await to_thread(self.event_queue.put, event)

def set_preprocess_steps_function(
self,
preprocess_steps_function: Optional[
Callable[[List["StepDict"]], List["StepDict"]]
],
):
"""
Set a function that will preprocess steps before sending them to the API.
The function should take a list of StepDict objects and return a list of processed StepDict objects.
This can be used for tasks like PII removal.

Args:
preprocess_steps_function (Callable[[List["StepDict"]], List["StepDict"]]): The preprocessing function
"""
self.preprocess_steps_function = preprocess_steps_function

def _process_events(self):
while True:
batch = []
Expand Down Expand Up @@ -83,6 +108,24 @@ def _process_events(self):

def _try_process_batch(self, batch: List):
try:
# Apply preprocessing function if it exists
if self.preprocess_steps_function is not None:
try:
processed_batch = self.preprocess_steps_function(batch)
# Only use the processed batch if it's valid
if processed_batch is not None and isinstance(
processed_batch, list
):
batch = processed_batch
else:
logger.warning(
"Preprocess function returned invalid result, using original batch"
)
except Exception as e:
logger.error(f"Error in preprocess function: {str(e)}")
logger.error(traceback.format_exc())
# Continue with the original batch

return self.api.send_steps(batch)
except Exception:
logger.error(f"Failed to send steps: {traceback.format_exc()}")
Expand Down
2 changes: 1 addition & 1 deletion literalai/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.202"
__version__ = "0.1.3"
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="literalai",
version="0.1.202", # update version in literalai/version.py
version="0.1.3", # update version in literalai/version.py
description="An SDK for observability in Python applications",
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
Expand Down
94 changes: 94 additions & 0 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,3 +801,97 @@ async def test_environment(self, staging_client: LiteralClient):
persisted_run = staging_client.api.get_step(run_id)
assert persisted_run is not None
assert persisted_run.environment == "staging"

@pytest.mark.timeout(5)
async def test_pii_removal(
self, client: LiteralClient, async_client: AsyncLiteralClient
):
"""Test that PII is properly removed by the preprocess function."""
import re

# Define a PII removal function
def remove_pii(steps):
# Patterns for common PII
email_pattern = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b"
phone_pattern = r"\b(\+\d{1,2}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}\b"
ssn_pattern = r"\b\d{3}[-\s]?\d{2}[-\s]?\d{4}\b"

for step in steps:
# Process content field if it exists
if "output" in step and step["output"]["content"]:
# Replace emails with [EMAIL REDACTED]
step["output"]["content"] = re.sub(
email_pattern, "[EMAIL REDACTED]", step["output"]["content"]
)

# Replace phone numbers with [PHONE REDACTED]
step["output"]["content"] = re.sub(
phone_pattern, "[PHONE REDACTED]", step["output"]["content"]
)

# Replace SSNs with [SSN REDACTED]
step["output"]["content"] = re.sub(
ssn_pattern, "[SSN REDACTED]", step["output"]["content"]
)

return steps

# Set the PII removal function on the client
client.set_preprocess_steps_function(remove_pii)

@client.thread
def thread_with_pii():
thread = client.get_current_thread()

# User message with PII
user_step = client.message(
content="My email is test@example.com and my phone is (123) 456-7890. My SSN is 123-45-6789.",
type="user_message",
metadata={"contact_info": "Call me at 987-654-3210"},
)
user_step_id = user_step.id

# Assistant message with PII reference
assistant_step = client.message(
content="I'll contact you at test@example.com", type="assistant_message"
)
assistant_step_id = assistant_step.id

return thread.id, user_step_id, assistant_step_id

# Run the thread
thread_id, user_step_id, assistant_step_id = thread_with_pii()

# Wait for processing to occur
client.flush()

# Fetch the steps and verify PII was removed
user_step = client.api.get_step(id=user_step_id)
assistant_step = client.api.get_step(id=assistant_step_id)

assert user_step
assert assistant_step

user_step_output = user_step.output["content"] # type: ignore

# Check user message
assert "test@example.com" not in user_step_output
assert "(123) 456-7890" not in user_step_output
assert "123-45-6789" not in user_step_output
assert "[EMAIL REDACTED]" in user_step_output
assert "[PHONE REDACTED]" in user_step_output
assert "[SSN REDACTED]" in user_step_output

assistant_step_output = assistant_step.output["content"] # type: ignore

# Check assistant message
assert "test@example.com" not in assistant_step_output
assert "[EMAIL REDACTED]" in assistant_step_output

# Clean up
client.api.delete_step(id=user_step_id)
client.api.delete_step(id=assistant_step_id)
client.api.delete_thread(id=thread_id)

# Reset the preprocess function to avoid affecting other tests
client.set_preprocess_steps_function(None)