-
Notifications
You must be signed in to change notification settings - Fork 0
/
ModuleList.patch
113 lines (96 loc) · 3.75 KB
/
ModuleList.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
--- /opt/conda/lib/python3.6/site-packages/torch/nn/modules/container.py
+++ /opt/conda/lib/python3.6/site-packages/torch/nn/modules/container.py
@@ -5,7 +5,7 @@
modules it contains are properly registered, and will be visible by all
:class:`~torch.nn.Module` methods.
- Args:
+ Arguments:
modules (iterable, optional): an iterable of modules to add
Example::
@@ -22,9 +22,7 @@
return x
"""
- _modules: Dict[str, Module] # type: ignore[assignment]
-
- def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
+ def __init__(self, modules=None):
super(ModuleList, self).__init__()
if modules is not None:
self += modules
@@ -38,18 +36,17 @@
idx += len(self)
return str(idx)
- @_copy_to_script_wrapper
- def __getitem__(self, idx: int) -> Union[Module, 'ModuleList']:
+ def __getitem__(self, idx):
if isinstance(idx, slice):
return self.__class__(list(self._modules.values())[idx])
else:
return self._modules[self._get_abs_string_index(idx)]
- def __setitem__(self, idx: int, module: Module) -> None:
+ def __setitem__(self, idx, module):
idx = self._get_abs_string_index(idx)
return setattr(self, str(idx), module)
- def __delitem__(self, idx: Union[int, slice]) -> None:
+ def __delitem__(self, idx):
if isinstance(idx, slice):
for k in range(len(self._modules))[idx]:
delattr(self, str(k))
@@ -59,33 +56,24 @@
str_indices = [str(i) for i in range(len(self._modules))]
self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
- @_copy_to_script_wrapper
- def __len__(self) -> int:
+ def __len__(self):
return len(self._modules)
- @_copy_to_script_wrapper
- def __iter__(self) -> Iterator[Module]:
+ def __iter__(self):
return iter(self._modules.values())
- def __iadd__(self, modules: Iterable[Module]) -> 'ModuleList':
+ def __iadd__(self, modules):
return self.extend(modules)
- def __add__(self, other: Iterable[Module]) -> 'ModuleList':
- combined = ModuleList()
- for i, module in enumerate(chain(self, other)):
- combined.add_module(str(i), module)
- return combined
-
- @_copy_to_script_wrapper
def __dir__(self):
keys = super(ModuleList, self).__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
- def insert(self, index: int, module: Module) -> None:
+ def insert(self, index, module):
r"""Insert a given module before a given index in the list.
- Args:
+ Arguments:
index (int): index to insert.
module (nn.Module): module to insert
"""
@@ -93,19 +81,19 @@
self._modules[str(i)] = self._modules[str(i - 1)]
self._modules[str(index)] = module
- def append(self, module: Module) -> 'ModuleList':
+ def append(self, module):
r"""Appends a given module to the end of the list.
- Args:
+ Arguments:
module (nn.Module): module to append
"""
self.add_module(str(len(self)), module)
return self
- def extend(self, modules: Iterable[Module]) -> 'ModuleList':
+ def extend(self, modules):
r"""Appends modules from a Python iterable to the end of the list.
- Args:
+ Arguments:
modules (iterable): iterable of modules to append
"""
if not isinstance(modules, container_abcs.Iterable):
@@ -116,5 +104,3 @@
self.add_module(str(offset + i), module)
return self
- # remove forward alltogether to fallback on Module's _forward_unimplemented
-