-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcore.py
145 lines (96 loc) · 3.1 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from typing import Callable, Dict, List, TypeVar, Union
from inspect import signature
import torch
from torch import flatten, nn
# Generic type variables used as placeholders in the helper signatures of
# this module. In map_dict / reduce_dict, T1 is the dict key type, T2 the dict
# value type, and T3 the mapped / accumulated result type.
T1 = TypeVar('T1')
T2 = TypeVar('T2')
T3 = TypeVar('T3')
class Flatten(nn.Module):
    """Flatten a tensor, or each tensor of a tuple/list, via ``torch.nn.Flatten``.

    Unlike ``torch.nn.Flatten``, this module also accepts a tuple or list of
    tensors and flattens every element with the same dim range.
    """

    def __init__(self, start_dim: int, end_dim: int = -1) -> None:
        """
        Args:
            start_dim: First dim to flatten (forwarded to ``torch.nn.Flatten``).
            end_dim: Last dim to flatten. Defaults to -1, which matches
                ``torch.nn.Flatten``'s own default, so callers that only pass
                ``start_dim`` get the same behaviour as before.
        """
        super().__init__()
        self.flatten = torch.nn.Flatten(start_dim=start_dim, end_dim=end_dim)

    def forward(self, input):
        """Flatten ``input``.

        Args:
            input: A tensor, or a tuple/list of tensors.

        Returns:
            The flattened tensor. For a tuple input, a plain ``tuple`` of
            flattened tensors is returned (namedtuples are not rebuilt, as
            before). For a list input, a ``list`` is returned.
        """
        if isinstance(input, tuple):
            # Plain tuple() on purpose: namedtuple constructors cannot take a
            # single iterable argument.
            return tuple(self.flatten(x) for x in input)
        if isinstance(input, list):
            return [self.flatten(x) for x in input]
        return self.flatten(input)
def rename_dict(dict: dict, func: Callable[[str], str]):
    """Return a copy of ``dict`` whose keys are passed through ``func``.

    Values are kept unchanged and ``func`` is called once per key in
    iteration order. If two keys map to the same new key, the later entry
    wins.
    """
    return {func(old_key): value for old_key, value in dict.items()}
def take(dict: dict, names: List[str]):
    """Select the entries of ``dict`` whose keys appear in ``names``.

    Names that are not keys of ``dict`` are silently skipped. The result is
    ordered by each name's first occurrence in ``names``.
    """
    return {wanted: dict[wanted] for wanted in names if wanted in dict}
def give(dict: dict, names: List[str]):
    """Partition ``dict`` by whether each key is listed in ``names``.

    Returns:
        A ``(given, left)`` pair of new dicts. ``given`` holds the entries whose
        key is in ``names`` and ``left`` holds the rest. Both keep ``dict``'s
        iteration order.
    """
    given, left = {}, {}
    for key, value in dict.items():
        # One membership test per key; pick the destination bucket from it.
        bucket = given if key in names else left
        bucket[key] = value
    return given, left
def filter_out_dict(dict: dict, names: List[str]):
    """Return a copy of ``dict`` without the entries whose keys are in ``names``.

    Names that are not keys of ``dict`` have no effect, and the surviving
    entries keep their original order.
    """
    return {key: value for key, value in dict.items() if key not in names}
def map_dict(
    dict: Dict[T1, T2],
    func: Union[Callable[[T2], T3], Callable[[T2, T1], T3]]
) -> Dict[T1, T3]:
    """Return a new dict with ``func`` applied to every value of ``dict``.

    Keys and their order are kept. ``func`` is called as ``func(value)`` when
    it declares exactly one parameter, and as ``func(value, key)`` when it
    declares exactly two.

    Raises:
        ValueError: If ``func`` declares any other number of parameters. The
            check runs before iterating, so it also fires for an empty dict.
    """
    # signature().parameters lists every declared parameter, including ones
    # with defaults and *args/**kwargs, so e.g. `lambda v, k=None: ...`
    # counts as two and receives the key.
    func_params_count = len(signature(func).parameters)
    if func_params_count == 1:
        return {key: func(val) for key, val in dict.items()}  # type: ignore
    if func_params_count == 2:
        return {key: func(val, key) for key, val in dict.items()}  # type: ignore
    raise ValueError(
        'Wrong map function: expected 1 or 2 parameters, got %d'
        % func_params_count
    )
def reduce_dict(
    dict: Dict[T1, T2],
    func: Union[Callable[[T3, T2], T3], Callable[[T3, T2, T1], T3]],
    initial_value: T3
):
    """Fold the values of ``dict`` into a single result, starting from ``initial_value``.

    ``func`` receives ``(accumulator, value)`` when it declares two
    parameters, or ``(accumulator, value, key)`` when it declares three.
    Entries are visited in the dict's iteration order.

    Raises:
        ValueError: If ``func`` declares any other number of parameters. The
            check runs before iterating.
    """
    arity = len(signature(func).parameters)
    if arity not in (2, 3):
        raise ValueError('Wrong reduce function')
    accumulator = initial_value
    for key, val in dict.items():
        if arity == 2:
            accumulator = func(accumulator, val)  # type: ignore
        else:
            accumulator = func(accumulator, val, key)  # type: ignore
    return accumulator
def zip_dicts(dict1: Dict[T1, T2], dict2: Dict[T1, T2]):
    """Pair up the values of two dicts by key.

    Keys (and their order) come from ``dict1``, and each maps to
    ``(dict1[key], dict2[key])``. Keys present only in ``dict2`` are ignored.
    A ``dict1`` key missing from ``dict2`` raises ``KeyError``.
    """
    return {key: (first, dict2[key]) for key, first in dict1.items()}
def split_by_dict(values: List[T2], dict: Dict[T1, int]):
    """Cut ``values`` into consecutive slices sized by the counts in ``dict``.

    Each key of ``dict`` (in iteration order) receives the next ``count``
    items of ``values``. Counts are not validated: a shortfall leaves trailing
    values unassigned, and an excess yields short or empty slices.
    """
    # offsets[i] is where the i-th slice starts; offsets[i + 1] is where it ends.
    offsets = [0]
    for count in dict.values():
        offsets.append(offsets[-1] + count)
    return {
        key: values[start:stop]
        for key, start, stop in zip(dict, offsets, offsets[1:])
    }
def split_by_arrays(values: List[T2], counts: List[List[int]]):
    """Cut ``values`` into consecutive slices, grouped like ``counts``.

    ``counts`` is a list of groups of slice sizes. The result has the same
    nesting: one list of slices per group, consumed left to right from a
    single running position in ``values``. Sizes are not validated against
    ``len(values)``.
    """
    position = 0

    def _next_chunk(size):
        # Take the next `size` items and advance the shared cursor.
        nonlocal position
        chunk = values[position:position + size]
        position += size
        return chunk

    return [[_next_chunk(size) for size in group] for group in counts]