Skip to content

Commit 49fc5a5

Browse files
committed
Revert "Handle predictors with deferred annotations (#1772)"
This reverts commit 05900a7. This is needed to avoid breaking predictors that rely on __signature__ or partial.
1 parent 1d1daf6 commit 49fc5a5

File tree

3 files changed

+13
-63
lines changed

3 files changed

+13
-63
lines changed

python/cog/predictor.py

+13-25
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
Type,
1919
Union,
2020
cast,
21-
get_type_hints,
2221
)
2322

2423
try:
@@ -283,18 +282,13 @@ def validate_input_type(
283282
)
284283

285284

286-
def get_input_create_model_kwargs(
287-
signature: inspect.Signature, input_types: Dict[str, Any]
288-
) -> Dict[str, Any]:
285+
def get_input_create_model_kwargs(signature: inspect.Signature) -> Dict[str, Any]:
289286
create_model_kwargs = {}
290287

291288
order = 0
292289

293290
for name, parameter in signature.parameters.items():
294-
if name not in input_types:
295-
raise TypeError(f"No input type provided for parameter `{name}`.")
296-
297-
InputType = input_types[name] # pylint: disable=invalid-name
291+
InputType = parameter.annotation
298292

299293
validate_input_type(InputType, name)
300294

@@ -360,17 +354,13 @@ class Input(BaseModel):
360354
predict = get_predict(predictor)
361355
signature = inspect.signature(predict)
362356

363-
input_types = get_type_hints(predict)
364-
if "return" in input_types:
365-
del input_types["return"]
366-
367357
return create_model(
368358
"Input",
369359
__config__=None,
370360
__base__=BaseInput,
371361
__module__=__name__,
372362
__validators__=None,
373-
**get_input_create_model_kwargs(signature, input_types),
363+
**get_input_create_model_kwargs(signature),
374364
) # type: ignore
375365

376366

@@ -380,10 +370,9 @@ def get_output_type(predictor: BasePredictor) -> Type[BaseModel]:
380370
"""
381371

382372
predict = get_predict(predictor)
383-
384-
input_types = get_type_hints(predict)
385-
386-
if "return" not in input_types:
373+
signature = inspect.signature(predict)
374+
OutputType: Type[BaseModel]
375+
if signature.return_annotation is inspect.Signature.empty:
387376
raise TypeError(
388377
"""You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type.
389378
@@ -398,7 +387,8 @@ def predict(
398387
...
399388
"""
400389
)
401-
OutputType = input_types.pop("return") # pylint: disable=invalid-name
390+
else:
391+
OutputType = signature.return_annotation
402392

403393
# The type that goes in the response is a list of the yielded type
404394
if get_origin(OutputType) is Iterator:
@@ -462,17 +452,13 @@ class TrainingInput(BaseModel):
462452
train = get_train(predictor)
463453
signature = inspect.signature(train)
464454

465-
input_types = get_type_hints(train)
466-
if "return" in input_types:
467-
del input_types["return"]
468-
469455
return create_model(
470456
"TrainingInput",
471457
__config__=None,
472458
__base__=BaseInput,
473459
__module__=__name__,
474460
__validators__=None,
475-
**get_input_create_model_kwargs(signature, input_types),
461+
**get_input_create_model_kwargs(signature),
476462
) # type: ignore
477463

478464

@@ -482,9 +468,9 @@ def get_training_output_type(predictor: BasePredictor) -> Type[BaseModel]:
482468
"""
483469

484470
train = get_train(predictor)
471+
signature = inspect.signature(train)
485472

486-
input_types = get_type_hints(train)
487-
if "return" not in input_types:
473+
if signature.return_annotation is inspect.Signature.empty:
488474
raise TypeError(
489475
"""You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type.
490476
@@ -499,6 +485,8 @@ def train(
499485
...
500486
"""
501487
)
488+
else:
489+
TrainingOutputType = signature.return_annotation
502490

503491
TrainingOutputType = input_types.pop("return") # pylint: disable=invalid-name
504492

test-integration/test_integration/fixtures/future-annotations-project/predict.py

-8
This file was deleted.

test-integration/test_integration/test_predict.py

-30
Original file line numberDiff line numberDiff line change
@@ -288,33 +288,3 @@ def test_predict_path_list_input(tmpdir_factory):
288288
)
289289
assert "test1" in result.stdout
290290
assert "test2" in result.stdout
291-
292-
293-
def test_predict_works_with_deferred_annotations():
294-
project_dir = Path(__file__).parent / "fixtures/future-annotations-project"
295-
296-
subprocess.check_call(
297-
["cog", "predict", "-i", "input=world"],
298-
cwd=project_dir,
299-
timeout=DEFAULT_TIMEOUT,
300-
)
301-
302-
303-
def test_predict_int_none_output():
304-
project_dir = Path(__file__).parent / "fixtures/int-none-output-project"
305-
306-
subprocess.check_call(
307-
["cog", "predict"],
308-
cwd=project_dir,
309-
timeout=DEFAULT_TIMEOUT,
310-
)
311-
312-
313-
def test_predict_string_none_output():
314-
project_dir = Path(__file__).parent / "fixtures/string-none-output-project"
315-
316-
subprocess.check_call(
317-
["cog", "predict"],
318-
cwd=project_dir,
319-
timeout=DEFAULT_TIMEOUT,
320-
)

0 commit comments

Comments
 (0)