-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
51 lines (40 loc) · 1.37 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import json
import copy
def load_json(file):
    """Read the JSON document stored at *file* and return the decoded object.

    Args:
        file: Path (or any value accepted by ``open``) of the JSON file.

    Returns:
        The Python object produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # JSON text is UTF-8 by specification; do not depend on the platform's
    # locale encoding, which would mis-decode non-ASCII content on some
    # systems.
    with open(file, 'r', encoding='utf-8') as f:
        return json.load(f)
def save_json(data, file):
    """Serialize *data* as JSON into *file*, indented by four spaces.

    The target is opened in write mode, so any existing content is replaced.
    """
    with open(file, 'w') as out_handle:
        json.dump(data, fp=out_handle, indent=4)
class Statistics(object):
    """Accumulate per-key sums and report per-update averages.

    Attributes are kept as plain instance fields:
      global_update   -- number of ``update_dict`` calls since creation/reset
      previous_update -- value of ``global_update`` at the last non-empty
                         ``local_mean`` call
      global_stat     -- running sum of every value seen, per key
      previous_stat   -- snapshot of ``global_stat`` taken at the last
                         non-empty ``local_mean`` call
    """

    def __init__(self):
        self.global_update = self.previous_update = 0
        self.global_stat, self.previous_stat = {}, {}

    def update_dict(self, stat_dict):
        """Add each value of *stat_dict* to the running sum for its key."""
        totals = self.global_stat
        for name, amount in stat_dict.items():
            totals[name] = totals.get(name, 0) + amount
        # One update is counted per call, regardless of how many keys it had.
        self.global_update += 1

    def mean(self):
        """Return ``{key: sum / number_of_updates}`` over all updates.

        Returns an empty dict when no update has been recorded.
        """
        count = self.global_update
        if count == 0:
            return {}
        return {name: total / count
                for name, total in self.global_stat.items()}

    def local_mean(self):
        """Return per-key averages over the updates since the last call.

        When no update happened since the previous non-empty call, an empty
        dict is returned and the stored snapshot is left untouched.
        Otherwise the snapshot is advanced to the current state so the next
        call starts a new window.
        """
        window = self.global_update - self.previous_update
        if window == 0:
            return {}
        baseline = self.previous_stat
        averages = {name: (total - baseline.get(name, 0)) / window
                    for name, total in self.global_stat.items()}
        # Start the next window from the current totals.
        self.previous_update = self.global_update
        self.previous_stat = copy.deepcopy(self.global_stat)
        return averages

    def reset(self):
        """Discard all accumulated sums, counters and snapshots."""
        self.global_update = self.previous_update = 0
        self.global_stat, self.previous_stat = {}, {}