Commit 6b5785a

fix(py/genkit): fixing type check errors of mypy
1 parent b9bd210 commit 6b5785a

File tree

3 files changed: +22 -10 lines changed


py/plugins/compat-oai/tests/test_handler.py

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@

 import pytest

-from genkit.ai import ActionRunContext
+from genkit.core.action import ActionRunContext
 from genkit.plugins.compat_oai.models import OpenAIModelHandler
 from genkit.plugins.compat_oai.models.model import OpenAIModel
 from genkit.plugins.compat_oai.models.model_info import (

py/plugins/compat-oai/tests/test_model.py

Lines changed: 19 additions & 8 deletions

@@ -13,21 +13,30 @@
 # limitations under the License.
 #
 # SPDX-License-Identifier: Apache-2.0
+import sys  # noqa
+
+if sys.version_info < (3, 11):  # noqa
+    from typing_extensions import Self  # noqa
+else:  # noqa
+    from typing import Self
+

 from unittest.mock import MagicMock

 import pytest

+
 from genkit.plugins.compat_oai.models import OpenAIModel
 from genkit.plugins.compat_oai.models.model_info import GPT_4
-from genkit.types import (
+from genkit.core.typing import (
     GenerateResponse,
     GenerateResponseChunk,
     Role,
+    GenerateRequest
 )


-def test_get_messages(sample_request):
+def test_get_messages(sample_request: GenerateRequest) -> None:
     """Test _get_messages method.
     Ensures the method correctly converts GenerateRequest messages into OpenAI-compatible ChatMessage format.
     """
@@ -41,7 +50,7 @@ def test_get_messages(sample_request):
     assert messages[1]['content'] == 'Hello, world!'


-def test_get_openai_config(sample_request):
+def test_get_openai_config(sample_request: GenerateRequest) -> None:
     """
     Test _get_openai_request_config method.
     Ensures the method correctly constructs the OpenAI API configuration dictionary.
@@ -55,7 +64,7 @@ def test_get_openai_config(sample_request):
     assert isinstance(openai_config['messages'], list)


-def test_generate(sample_request):
+def test_generate(sample_request: GenerateRequest) -> None:
     """
     Test generate method calls OpenAI API and returns GenerateResponse.
     """
@@ -74,11 +83,13 @@ def test_generate(sample_request):

     mock_client.chat.completions.create.assert_called_once()
     assert isinstance(response, GenerateResponse)
+    assert response.message is not None
     assert response.message.role == Role.MODEL
+    assert response.message.content
     assert response.message.content[0].root.text == 'Hello, user!'


-def test_generate_stream(sample_request):
+def test_generate_stream(sample_request: GenerateRequest) -> None:
     """Test generate_stream method ensures it processes streamed responses correctly."""
     mock_client = MagicMock()

@@ -87,10 +98,10 @@ def __init__(self, data: list[str]) -> None:
             self._data = data
             self._current = 0

-        def __iter__(self):
+        def __iter__(self) -> Self:
             return self

-        def __next__(self):
+        def __next__(self) -> MagicMock:
             if self._current >= len(self._data):
                 raise StopIteration

@@ -112,7 +123,7 @@ def __next__(self):
     model = OpenAIModel(model=GPT_4, client=mock_client, registry=MagicMock())
     collected_chunks = []

-    def callback(chunk: GenerateResponseChunk):
+    def callback(chunk: GenerateResponseChunk) -> None:
         collected_chunks.append(chunk.content[0].root.text)

     model.generate_stream(sample_request, callback)
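
Side note on the Self shim added at the top of this file: typing.Self only exists on Python 3.11+, so older interpreters have to fall back to the typing_extensions backport. Below is a minimal, self-contained sketch of the same annotation pattern; the _FakeStream class and its string payload are illustrative only and are not taken from the test.

import sys

# typing.Self was added in Python 3.11; use the typing_extensions backport before that.
if sys.version_info < (3, 11):
    from typing_extensions import Self
else:
    from typing import Self


class _FakeStream:
    """Toy iterator showing why __iter__ is annotated with Self."""

    def __init__(self, data: list[str]) -> None:
        self._data = data
        self._current = 0

    def __iter__(self) -> Self:
        # Self (rather than the literal class name) stays correct in subclasses
        # and satisfies mypy's expectation that __iter__ returns the iterator itself.
        return self

    def __next__(self) -> str:
        if self._current >= len(self._data):
            raise StopIteration
        item = self._data[self._current]
        self._current += 1
        return item


# Usage: the fake stream behaves like any other iterator.
assert list(_FakeStream(['a', 'b'])) == ['a', 'b']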

py/plugins/compat-oai/tests/test_tool_calling.py

Lines changed: 2 additions & 1 deletion

@@ -22,7 +22,7 @@

 from genkit.plugins.compat_oai.models import OpenAIModel
 from genkit.plugins.compat_oai.models.model_info import GPT_4
-from genkit.types import GenerateRequest, GenerateResponseChunk, TextPart, ToolRequestPart
+from genkit.core.typing import GenerateRequest, GenerateResponseChunk, TextPart, ToolRequestPart


 def test_generate_with_tool_calls_executes_tools(sample_request: GenerateRequest) -> None:
@@ -143,4 +143,5 @@ def callback(chunk: GenerateResponseChunk):
     assert tool_part.tool_request.ref == 'tool123'

     accumulated_output = reduce(lambda res, tool_call: res + tool_call.tool_request.input, collected_chunks, '')
+    assert accumulated_output is not None
     assert json.loads(accumulated_output) == {'a': 1}
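
The asserts added in this commit (assert response.message is not None in test_model.py and assert accumulated_output is not None here) are the usual way to narrow a possibly-None value for mypy: after the assert, the checker treats the value as non-optional, so the attribute access and json.loads call that follow type-check cleanly. A small sketch of the idea, using a hypothetical maybe_payload helper rather than anything from these tests:

import json
from typing import Optional


def maybe_payload(raw: Optional[str]) -> Optional[str]:
    # Hypothetical stand-in for any call that may return None.
    return raw


def test_optional_narrowing() -> None:
    value = maybe_payload('{"a": 1}')
    # Without this assert, mypy reports an arg-type error: json.loads expects str, not Optional[str].
    assert value is not None
    assert json.loads(value) == {'a': 1}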
