-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathfreeinit_utils.py
118 lines (97 loc) · 3.59 KB
/
freeinit_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import torch.fft as fft
import math
def freq_mix_3d(x, noise, LPF):
    """
    Noise reinitialization: blend two tensors in the 3D frequency domain.

    The low-frequency content (as selected by ``LPF``) is taken from ``x``
    and the complementary high-frequency content (``1 - LPF``) is taken from
    ``noise``. The transform runs over the last three dimensions of each
    input.

    Args:
        x: diffused latent
        noise: randomly sampled noise
        LPF: low pass filter mask, multiplied element-wise with the
            zero-centered spectra (must broadcast against them)

    Returns:
        The real part of the inverse transform of the mixed spectrum.
    """
    axes = (-3, -2, -1)

    # Forward FFT, shifted so the zero-frequency bin sits at the center,
    # which is where the low pass mask is expected to be centered.
    latent_spectrum = fft.fftshift(fft.fftn(x, dim=axes), dim=axes)
    noise_spectrum = fft.fftshift(fft.fftn(noise, dim=axes), dim=axes)

    # Low band from the latent, complementary high band from the noise.
    blended = latent_spectrum * LPF + noise_spectrum * (1 - LPF)

    # Undo the shift, return to the signal domain, drop the imaginary part.
    return fft.ifftn(fft.ifftshift(blended, dim=axes), dim=axes).real
def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25):
    """
    Compute the gaussian low pass filter mask.

    Each position ``(t, h, w)`` of the last three dimensions is mapped to a
    normalized coordinate ``2 * i / N - 1``; the temporal coordinate is
    additionally scaled by ``d_s / d_t``. The mask value is
    ``exp(-d_square / (2 * d_s**2))`` where ``d_square`` is the sum of the
    squared coordinates. Leading dimensions all receive the same values.

    Args:
        shape: shape of the filter (volume)
        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
        d_t: normalized stop frequency for temporal dimension (0.0-1.0)

    Returns:
        Tensor of the given shape; all zeros if ``d_s`` or ``d_t`` is 0.
    """
    T, H, W = shape[-3], shape[-2], shape[-1]
    mask = torch.zeros(shape)
    if d_s == 0 or d_t == 0:
        return mask

    # Squared normalized coordinate along each axis, computed once per index.
    temporal_sq = [((d_s / d_t) * (2 * t / T - 1)) ** 2 for t in range(T)]
    row_sq = [(2 * h / H - 1) ** 2 for h in range(H)]
    col_sq = [(2 * w / W - 1) ** 2 for w in range(W)]
    exponent_scale = -1 / (2 * d_s**2)

    for t, t_term in enumerate(temporal_sq):
        for h, h_term in enumerate(row_sq):
            for w, w_term in enumerate(col_sq):
                dist_sq = t_term + h_term + w_term
                mask[..., t, h, w] = math.exp(exponent_scale * dist_sq)
    return mask
def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25):
    """
    Compute the butterworth low pass filter mask.

    Each position ``(t, h, w)`` of the last three dimensions is mapped to a
    normalized coordinate ``2 * i / N - 1``; the temporal coordinate is
    additionally scaled by ``d_s / d_t``. The mask value is
    ``1 / (1 + (d_square / d_s**2) ** n)`` where ``d_square`` is the sum of
    the squared coordinates. Leading dimensions all receive the same values.

    Args:
        shape: shape of the filter (volume)
        n: order of the filter, larger n ~ ideal, smaller n ~ gaussian
        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
        d_t: normalized stop frequency for temporal dimension (0.0-1.0)

    Returns:
        Tensor of the given shape; all zeros if ``d_s`` or ``d_t`` is 0.
    """
    T, H, W = shape[-3], shape[-2], shape[-1]
    mask = torch.zeros(shape)
    if d_s == 0 or d_t == 0:
        return mask

    # Squared normalized coordinate along each axis, computed once per index.
    temporal_sq = [((d_s / d_t) * (2 * t / T - 1)) ** 2 for t in range(T)]
    row_sq = [(2 * h / H - 1) ** 2 for h in range(H)]
    col_sq = [(2 * w / W - 1) ** 2 for w in range(W)]
    cutoff_sq = d_s**2

    for t, t_term in enumerate(temporal_sq):
        for h, h_term in enumerate(row_sq):
            for w, w_term in enumerate(col_sq):
                dist_sq = t_term + h_term + w_term
                mask[..., t, h, w] = 1 / (1 + (dist_sq / cutoff_sq) ** n)
    return mask
def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25):
    """
    Compute the ideal low pass filter mask.

    Each position ``(t, h, w)`` of the last three dimensions is mapped to a
    normalized coordinate ``2 * i / N - 1``; the temporal coordinate is
    additionally scaled by ``d_s / d_t``. The mask is 1 where the squared
    distance from the center is at most ``d_s**2`` and 0 elsewhere. Leading
    dimensions all receive the same values.

    Args:
        shape: shape of the filter (volume)
        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
        d_t: normalized stop frequency for temporal dimension (0.0-1.0)

    Returns:
        Tensor of the given shape; all zeros if ``d_s`` or ``d_t`` is 0.
    """
    T, H, W = shape[-3], shape[-2], shape[-1]
    mask = torch.zeros(shape)
    if d_s == 0 or d_t == 0:
        return mask

    # Squared normalized coordinate per axis, in float64 so the comparison
    # below has the same precision as plain Python float arithmetic.
    t_sq = ((d_s / d_t) * (2 * torch.arange(T, dtype=torch.float64) / T - 1)) ** 2
    h_sq = (2 * torch.arange(H, dtype=torch.float64) / H - 1) ** 2
    w_sq = (2 * torch.arange(W, dtype=torch.float64) / W - 1) ** 2

    # Broadcast the three axes into a (T, H, W) volume of squared distances.
    d_square = t_sq[:, None, None] + h_sq[None, :, None] + w_sq[None, None, :]

    # d_square is a *squared* distance, so it is compared with the squared
    # cutoff d_s**2 (not d_s*2, which would widen the pass band to a radius
    # of sqrt(2 * d_s)).
    pass_band = (d_square <= d_s**2).to(mask.dtype)

    # Assignment broadcasts the (T, H, W) volume over any leading dimensions.
    mask[...] = pass_band
    return mask
def box_low_pass_filter(shape, d_s=0.25, d_t=0.25):
    """
    Compute the ideal low pass filter mask (approximated version).

    The mask is 1 inside an axis-aligned box centered at
    ``(T // 2, H // 2, W // 2)`` of the last three dimensions and 0
    elsewhere. The half-extent along each axis is
    ``round((N // 2) * cutoff)``, with ``d_t`` used for the temporal axis
    and ``d_s`` for both spatial axes. Leading dimensions all receive the
    same values.

    Args:
        shape: shape of the filter (volume)
        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
        d_t: normalized stop frequency for temporal dimension (0.0-1.0)

    Returns:
        Tensor of the given shape; all zeros if ``d_s`` or ``d_t`` is 0.
    """
    T, H, W = shape[-3], shape[-2], shape[-1]
    mask = torch.zeros(shape)
    if d_s == 0 or d_t == 0:
        return mask

    cframe, crow, ccol = T // 2, H // 2, W // 2

    # One half-extent per axis: deriving the width extent from H would give
    # non-square volumes a different normalized cutoff along W than along H.
    threshold_t = round(cframe * d_t)
    threshold_h = round(crow * d_s)
    threshold_w = round(ccol * d_s)

    # Clamp the lower bounds at 0 so a cutoff above 1.0 passes the whole axis
    # instead of producing a negative start that wraps around.
    mask[
        ...,
        max(cframe - threshold_t, 0):cframe + threshold_t,
        max(crow - threshold_h, 0):crow + threshold_h,
        max(ccol - threshold_w, 0):ccol + threshold_w,
    ] = 1.0
    return mask