18
18
Type ,
19
19
Union ,
20
20
cast ,
21
- get_type_hints ,
22
21
)
23
22
24
23
try :
@@ -283,18 +282,13 @@ def validate_input_type(
283
282
)
284
283
285
284
286
- def get_input_create_model_kwargs (
287
- signature : inspect .Signature , input_types : Dict [str , Any ]
288
- ) -> Dict [str , Any ]:
285
+ def get_input_create_model_kwargs (signature : inspect .Signature ) -> Dict [str , Any ]:
289
286
create_model_kwargs = {}
290
287
291
288
order = 0
292
289
293
290
for name , parameter in signature .parameters .items ():
294
- if name not in input_types :
295
- raise TypeError (f"No input type provided for parameter `{ name } `." )
296
-
297
- InputType = input_types [name ] # pylint: disable=invalid-name
291
+ InputType = parameter .annotation
298
292
299
293
validate_input_type (InputType , name )
300
294
@@ -360,17 +354,13 @@ class Input(BaseModel):
360
354
predict = get_predict (predictor )
361
355
signature = inspect .signature (predict )
362
356
363
- input_types = get_type_hints (predict )
364
- if "return" in input_types :
365
- del input_types ["return" ]
366
-
367
357
return create_model (
368
358
"Input" ,
369
359
__config__ = None ,
370
360
__base__ = BaseInput ,
371
361
__module__ = __name__ ,
372
362
__validators__ = None ,
373
- ** get_input_create_model_kwargs (signature , input_types ),
363
+ ** get_input_create_model_kwargs (signature ),
374
364
) # type: ignore
375
365
376
366
@@ -380,10 +370,9 @@ def get_output_type(predictor: BasePredictor) -> Type[BaseModel]:
380
370
"""
381
371
382
372
predict = get_predict (predictor )
383
-
384
- input_types = get_type_hints (predict )
385
-
386
- if "return" not in input_types :
373
+ signature = inspect .signature (predict )
374
+ OutputType : Type [BaseModel ]
375
+ if signature .return_annotation is inspect .Signature .empty :
387
376
raise TypeError (
388
377
"""You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type.
389
378
@@ -398,7 +387,8 @@ def predict(
398
387
...
399
388
"""
400
389
)
401
- OutputType = input_types .pop ("return" ) # pylint: disable=invalid-name
390
+ else :
391
+ OutputType = signature .return_annotation
402
392
403
393
# The type that goes in the response is a list of the yielded type
404
394
if get_origin (OutputType ) is Iterator :
@@ -462,17 +452,13 @@ class TrainingInput(BaseModel):
462
452
train = get_train (predictor )
463
453
signature = inspect .signature (train )
464
454
465
- input_types = get_type_hints (train )
466
- if "return" in input_types :
467
- del input_types ["return" ]
468
-
469
455
return create_model (
470
456
"TrainingInput" ,
471
457
__config__ = None ,
472
458
__base__ = BaseInput ,
473
459
__module__ = __name__ ,
474
460
__validators__ = None ,
475
- ** get_input_create_model_kwargs (signature , input_types ),
461
+ ** get_input_create_model_kwargs (signature ),
476
462
) # type: ignore
477
463
478
464
@@ -482,9 +468,9 @@ def get_training_output_type(predictor: BasePredictor) -> Type[BaseModel]:
482
468
"""
483
469
484
470
train = get_train (predictor )
471
+ signature = inspect .signature (train )
485
472
486
- input_types = get_type_hints (train )
487
- if "return" not in input_types :
473
+ if signature .return_annotation is inspect .Signature .empty :
488
474
raise TypeError (
489
475
"""You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type.
490
476
@@ -499,6 +485,8 @@ def train(
499
485
...
500
486
"""
501
487
)
488
+ else :
489
+ TrainingOutputType = signature .return_annotation
502
490
503
491
TrainingOutputType = input_types .pop ("return" ) # pylint: disable=invalid-name
504
492
0 commit comments