Skip to content

Commit b5c7443

Browse files
Cleanup in get_enum_from_fn (#8852)
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
1 parent 8bed9d8 commit b5c7443

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

torchvision/models/_api.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from functools import partial
88
from inspect import signature
99
from types import ModuleType
10-
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union
10+
from typing import Any, Callable, Dict, get_args, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union
1111

1212
from torch import nn
1313

@@ -168,14 +168,13 @@ def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]:
168168
if "weights" not in sig.parameters:
169169
raise ValueError("The method is missing the 'weights' argument.")
170170

171-
ann = signature(fn).parameters["weights"].annotation
171+
ann = sig.parameters["weights"].annotation
172172
weights_enum = None
173173
if isinstance(ann, type) and issubclass(ann, WeightsEnum):
174174
weights_enum = ann
175175
else:
176176
# handle cases like Union[Optional, T]
177-
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
178-
for t in ann.__args__: # type: ignore[union-attr]
177+
for t in get_args(ann): # type: ignore[union-attr]
179178
if isinstance(t, type) and issubclass(t, WeightsEnum):
180179
weights_enum = t
181180
break

0 commit comments

Comments
 (0)