
Commit

Fix memory leak in AsyncCompletions.parse() with dynamically created models

This commit fixes a memory leak in AsyncCompletions.parse() when it is repeatedly
called with Pydantic models created via create_model(). The leak occurred because
the schema representations built for those models were retained indefinitely.

The fix introduces a WeakKeyDictionary cache, so a cached schema can be
garbage-collected once the model type it belongs to is no longer referenced
anywhere else in the code.
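
The behaviour the fix relies on can be shown in isolation (a standalone sketch,
assuming nothing else holds a strong reference to the model class):

import gc
import weakref

from pydantic import create_model

cache = weakref.WeakKeyDictionary()

Model = create_model("Model", value=(str, ...))
cache[Model] = {"type": "json_schema"}  # the key is held only weakly
assert len(cache) == 1

# Once the class itself becomes unreachable, its entry disappears after a
# garbage-collection pass.
del Model
gc.collect()
assert len(cache) == 0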

Added test cases to verify the fix prevents memory leaks with dynamically created
models.
mousberg committed Feb 27, 2025
1 parent 3e69750 commit f1e3865
Showing 2 changed files with 115 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/openai/lib/_parsing/_completions.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
import weakref
from typing import TYPE_CHECKING, Any, Iterable, cast
from typing_extensions import TypeVar, TypeGuard, assert_never

@@ -28,6 +29,9 @@
from ...types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
from ...types.chat.chat_completion_message_tool_call import Function

# Cache to store weak references to schema objects
_schema_cache = weakref.WeakKeyDictionary()

ResponseFormatT = TypeVar(
"ResponseFormatT",
# if it isn't given then we don't do any parsing
@@ -243,6 +247,10 @@ def type_to_response_format_param(
# can only be a `type`
response_format = cast(type, response_format)

# Check if we already have a schema for this type in the cache
if response_format in _schema_cache:
return _schema_cache[response_format]

json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None

if is_basemodel_type(response_format):
@@ -254,11 +262,16 @@
else:
raise TypeError(f"Unsupported response_format type - {response_format}")

return {
schema_param = {
"type": "json_schema",
"json_schema": {
"schema": to_strict_json_schema(json_schema_type),
"name": name,
"strict": True,
},
}

# Store a weak reference to the schema parameter
_schema_cache[response_format] = schema_param

return schema_param
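
With the cache in place, converting the same model type twice reuses the stored
entry instead of rebuilding the schema. Roughly (a usage sketch, not part of the
commit):

from pydantic import create_model
from openai.lib._parsing import type_to_response_format_param
from openai.lib._parsing._completions import _schema_cache

Answer = create_model("Answer", text=(str, ...))

first = type_to_response_format_param(Answer)
second = type_to_response_format_param(Answer)

assert first is second          # the second call returns the cached dict
assert len(_schema_cache) == 1  # one entry per model type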
101 changes: 101 additions & 0 deletions tests/lib/_parsing/test_memory_leak.py
@@ -0,0 +1,101 @@
import unittest
import gc
import sys
from unittest.mock import AsyncMock, patch, MagicMock
from typing import List

import pytest
from pydantic import Field, create_model

from openai.resources.beta.chat.completions import AsyncCompletions
from openai.lib._parsing import type_to_response_format_param
from openai.lib._parsing._completions import _schema_cache

class TestMemoryLeak(unittest.TestCase):
def setUp(self):
# Clear the schema cache before each test
_schema_cache.clear()

def test_schema_cache_with_models(self):
"""Test if schema cache properly handles dynamic models and prevents memory leak"""

StepModel = create_model(
"Step",
explanation=(str, Field()),
output=(str, Field()),
)

# Create several models and ensure they're cached properly
models = []
for i in range(5):
model = create_model(
f"MathResponse{i}",
steps=(List[StepModel], Field()),
final_answer=(str, Field()),
)
models.append(model)

# Convert model to response format param
param = type_to_response_format_param(model)

# Check if the model is in the cache
self.assertIn(model, _schema_cache)

# Test that all models are in the cache
self.assertEqual(len(_schema_cache), 5)

# Let the models go out of scope and trigger garbage collection
models = None
gc.collect()

# After garbage collection, the cache should be empty or reduced
# since we're using weakref.WeakKeyDictionary
self.assertLess(len(_schema_cache), 5)

@pytest.mark.asyncio
async def test_async_completions_parse_memory():
"""Test if AsyncCompletions.parse() doesn't leak memory with dynamic models"""
StepModel = create_model(
"Step",
explanation=(str, Field()),
output=(str, Field()),
)

# Clear the cache and record initial state
_schema_cache.clear()
initial_cache_size = len(_schema_cache)

# Create a mock client
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock()

# Create the AsyncCompletions instance with our mock client
completions = AsyncCompletions(mock_client)

# Simulate the issue by creating multiple models and making calls
models = []
for i in range(10):
# Create a new dynamic model each time
new_model = create_model(
f"MathResponse{i}",
steps=(List[StepModel], Field()),
final_answer=(str, Field()),
)
models.append(new_model)

# Convert to response format and check if it's in the cache
type_to_response_format_param(new_model)
assert new_model in _schema_cache

# Record cache size with all models referenced
cache_size_with_references = len(_schema_cache)

# Let the models go out of scope and trigger garbage collection
models = None
gc.collect()

# After garbage collection, the cache should be significantly reduced
cache_size_after_gc = len(_schema_cache)
assert cache_size_after_gc < cache_size_with_references
# The cache size should be close to the initial size (with some tolerance)
assert cache_size_after_gc < cache_size_with_references / 2
