-
Notifications
You must be signed in to change notification settings - Fork 2
/
drop_scheduler.py
executable file
·25 lines (20 loc) · 1006 Bytes
/
drop_scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import numpy as np
class _DropScheduleConfigError(ValueError, AssertionError):
    """Invalid argument passed to ``drop_scheduler``.

    Subclasses ``AssertionError`` because argument checks used to be plain
    ``assert`` statements, and ``ValueError`` because an out-of-range
    ``cutoff_epoch`` used to surface as numpy's ``ValueError``; callers that
    catch either type keep working.
    """


def drop_scheduler(drop_rate, epochs, niter_per_ep, cutoff_epoch=0,
                   mode="standard", schedule="constant"):
    """Build a per-iteration dropout-rate schedule.

    Args:
        drop_rate: Rate used while dropout is active.
        epochs: Total number of epochs.
        niter_per_ep: Number of iterations per epoch.
        cutoff_epoch: Epoch boundary between the early and late phases.
            Ignored when ``mode == "standard"``; otherwise must satisfy
            ``0 <= cutoff_epoch <= epochs``.
        mode: ``"standard"`` uses ``drop_rate`` for every iteration;
            ``"early"`` uses it only for iterations before ``cutoff_epoch``
            (zero afterwards); ``"late"`` uses zero before ``cutoff_epoch``
            and ``drop_rate`` afterwards.
        schedule: Shape of the active phase. ``"early"`` mode accepts
            ``"constant"`` or ``"linear"`` (``np.linspace`` from ``drop_rate``
            down to 0, hitting 0 on the last early iteration); ``"late"`` mode
            accepts only ``"constant"``; ignored in ``"standard"`` mode.

    Returns:
        1-D numpy array of length ``epochs * niter_per_ep`` holding the rate
        for each iteration.

    Raises:
        ValueError: (an ``AssertionError`` subclass as well) for an unknown
            ``mode``, a ``schedule`` the chosen mode does not support, or an
            out-of-range ``cutoff_epoch``.
    """
    total_iters = epochs * niter_per_ep

    # Explicit raises instead of `assert`: asserts vanish under `python -O`,
    # which previously let bad input fall through to an UnboundLocalError.
    if mode not in ("standard", "early", "late"):
        raise _DropScheduleConfigError(
            f"unknown mode {mode!r}; expected 'standard', 'early' or 'late'")
    if mode == "standard":
        return np.full(total_iters, drop_rate)

    allowed = ("constant", "linear") if mode == "early" else ("constant",)
    if schedule not in allowed:
        raise _DropScheduleConfigError(
            f"schedule {schedule!r} not supported in {mode!r} mode; "
            f"expected one of {allowed}")
    if not 0 <= cutoff_epoch <= epochs:
        raise _DropScheduleConfigError(
            f"cutoff_epoch must be in [0, {epochs}], got {cutoff_epoch!r}")

    early_iters = cutoff_epoch * niter_per_ep
    late_iters = total_iters - early_iters

    if mode == "early":
        if schedule == "linear":
            active = np.linspace(drop_rate, 0, early_iters)
        else:
            active = np.full(early_iters, drop_rate)
        final_schedule = np.concatenate((active, np.zeros(late_iters)))
    else:  # mode == "late"
        final_schedule = np.concatenate(
            (np.zeros(early_iters), np.full(late_iters, drop_rate)))

    # Internal invariant, not input validation.
    assert len(final_schedule) == total_iters
    return final_schedule