13
13
# limitations under the License.
14
14
#
15
15
# SPDX-License-Identifier: Apache-2.0
16
+ import sys # noqa
17
+
18
+ if sys .version_info < (3 , 11 ): # noqa
19
+ from typing_extensions import Self # noqa
20
+ else : # noqa
21
+ from typing import Self
22
+
16
23
17
24
from unittest .mock import MagicMock
18
25
19
26
import pytest
20
27
28
+
21
29
from genkit .plugins .compat_oai .models import OpenAIModel
22
30
from genkit .plugins .compat_oai .models .model_info import GPT_4
23
- from genkit .types import (
31
+ from genkit .core . typing import (
24
32
GenerateResponse ,
25
33
GenerateResponseChunk ,
26
34
Role ,
35
+ GenerateRequest
27
36
)
28
37
29
38
30
- def test_get_messages (sample_request ) :
39
+ def test_get_messages (sample_request : GenerateRequest ) -> None :
31
40
"""Test _get_messages method.
32
41
Ensures the method correctly converts GenerateRequest messages into OpenAI-compatible ChatMessage format.
33
42
"""
@@ -41,7 +50,7 @@ def test_get_messages(sample_request):
41
50
assert messages [1 ]['content' ] == 'Hello, world!'
42
51
43
52
44
- def test_get_openai_config (sample_request ) :
53
+ def test_get_openai_config (sample_request : GenerateRequest ) -> None :
45
54
"""
46
55
Test _get_openai_request_config method.
47
56
Ensures the method correctly constructs the OpenAI API configuration dictionary.
@@ -55,7 +64,7 @@ def test_get_openai_config(sample_request):
55
64
assert isinstance (openai_config ['messages' ], list )
56
65
57
66
58
- def test_generate (sample_request ) :
67
+ def test_generate (sample_request : GenerateRequest ) -> None :
59
68
"""
60
69
Test generate method calls OpenAI API and returns GenerateResponse.
61
70
"""
@@ -74,11 +83,13 @@ def test_generate(sample_request):
74
83
75
84
mock_client .chat .completions .create .assert_called_once ()
76
85
assert isinstance (response , GenerateResponse )
86
+ assert response .message is not None
77
87
assert response .message .role == Role .MODEL
88
+ assert response .message .content
78
89
assert response .message .content [0 ].root .text == 'Hello, user!'
79
90
80
91
81
- def test_generate_stream (sample_request ) :
92
+ def test_generate_stream (sample_request : GenerateRequest ) -> None :
82
93
"""Test generate_stream method ensures it processes streamed responses correctly."""
83
94
mock_client = MagicMock ()
84
95
@@ -87,10 +98,10 @@ def __init__(self, data: list[str]) -> None:
87
98
self ._data = data
88
99
self ._current = 0
89
100
90
- def __iter__ (self ):
101
+ def __iter__ (self ) -> Self :
91
102
return self
92
103
93
- def __next__ (self ):
104
+ def __next__ (self ) -> MagicMock :
94
105
if self ._current >= len (self ._data ):
95
106
raise StopIteration
96
107
@@ -112,7 +123,7 @@ def __next__(self):
112
123
model = OpenAIModel (model = GPT_4 , client = mock_client , registry = MagicMock ())
113
124
collected_chunks = []
114
125
115
- def callback (chunk : GenerateResponseChunk ):
126
+ def callback (chunk : GenerateResponseChunk ) -> None :
116
127
collected_chunks .append (chunk .content [0 ].root .text )
117
128
118
129
model .generate_stream (sample_request , callback )
0 commit comments