-
Notifications
You must be signed in to change notification settings - Fork 4
/
convlstm_net.py
75 lines (56 loc) · 2.63 KB
/
convlstm_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import torch.nn as nn
class ConvLSTMCell(nn.Module):
    """A single convolutional LSTM cell.

    One shared Conv2d computes the pre-activations of all four LSTM gates
    (input, forget, output, candidate) from the channel-wise concatenation
    of the current input frame and the previous hidden state. Hidden and
    cell states are stored on the module and re-zeroed when ``forward`` is
    called with ``first_step=True``.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size):
        """
        Args:
            input_dim: number of channels in the input tensor.
            hidden_dim: number of channels in the hidden/cell state.
            kernel_size: convolution kernel size — an (kh, kw) pair or a
                single int (applied to both dimensions).
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # Generalization: also accept a bare int kernel size, as nn.Conv2d does.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        # "Same" padding so spatial dimensions are preserved (assumes odd kernel sizes).
        pad = kernel_size[0] // 2, kernel_size[1] // 2
        # Single convolution producing all four gate pre-activations at once.
        self.conv = nn.Conv2d(in_channels=input_dim + hidden_dim,
                              out_channels=4 * hidden_dim,
                              kernel_size=kernel_size,
                              padding=pad)

    def initialize(self, inputs):
        """Zero the hidden and cell states, matching the batch/spatial shape
        and device of ``inputs``.

        Note: the original also allocated a ``memory_state`` tensor here that
        was never read or written elsewhere; it has been removed as dead code.
        """
        device = inputs.device
        batch_size, _, height, width = inputs.size()
        self.hidden_state = torch.zeros(batch_size, self.hidden_dim, height, width, device=device)
        self.cell_state = torch.zeros(batch_size, self.hidden_dim, height, width, device=device)

    def forward(self, inputs, first_step=False):
        """Advance the cell by one time step.

        Args:
            inputs: tensor of shape (batch, input_dim, height, width).
            first_step: when True, reset hidden/cell states to zeros before
                processing (call this on the first frame of a sequence).

        Returns:
            The new hidden state, shape (batch, hidden_dim, height, width).
        """
        if first_step:
            self.initialize(inputs)
        # Standard ConvLSTM gate equations.
        combined = torch.cat([inputs, self.hidden_state], dim=1)
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)   # input gate
        f = torch.sigmoid(cc_f)   # forget gate
        o = torch.sigmoid(cc_o)   # output gate
        g = torch.tanh(cc_g)      # candidate cell update
        self.cell_state = f * self.cell_state + i * g
        self.hidden_state = o * torch.tanh(self.cell_state)
        return self.hidden_state
class ConvLSTM(nn.Module):
    """Stacked ConvLSTM that encodes a frame sequence and predicts one frame.

    The input sequence is fed step by step through ``len(hidden_dim)``
    stacked ``ConvLSTMCell`` layers; the top layer's final hidden state is
    projected with a 1x1 convolution and clamped to [0, 1].
    """

    def __init__(self, input_dim, hidden_dim, output_dim, kernel_size):
        """
        Args:
            input_dim: channels of each input frame.
            hidden_dim: sequence of hidden-channel counts, one per layer.
            output_dim: channels of the predicted frame.
            kernel_size: convolution kernel size, shared by all layers.
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = len(hidden_dim)
        layers = []
        for i in range(self.num_layers):
            # Layer 0 consumes the raw frame; deeper layers consume the
            # hidden state of the layer below.
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
            layers.append(ConvLSTMCell(input_dim=cur_input_dim,
                                       hidden_dim=self.hidden_dim[i],
                                       kernel_size=kernel_size))
        self.layers = nn.ModuleList(layers)
        # 1x1 projection from the top layer's hidden state to the output frame.
        self.conv_output = nn.Conv2d(self.hidden_dim[-1], output_dim, kernel_size=1)

    def forward(self, input_x, land_mask):
        """Run the full sequence and return the next-frame prediction.

        Args:
            input_x: tensor of shape (batch, time, channels, height, width).
            land_mask: accepted for interface compatibility but currently
                unused by this implementation.

        Returns:
            Tensor of shape (batch, 1, output_dim, height, width), values
            clamped to [0, 1].

        Raises:
            ValueError: if ``input_x`` is not 5-dimensional.
        """
        # Explicit validation instead of `assert`, which is stripped under -O.
        if input_x.dim() != 5:
            raise ValueError(
                f"expected 5D input (batch, time, channels, height, width), "
                f"got {input_x.dim()}D")
        input_frames = input_x.size(1)
        for t in range(input_frames):
            frame = input_x[:, t]
            # Reset every cell's state at the start of the sequence.
            first_step = (t == 0)
            for layer in self.layers:
                frame = layer(frame, first_step=first_step)
        # Project the final hidden state and restore a time axis of length 1.
        output = self.conv_output(frame)[:, None]
        return torch.clamp(output, 0, 1)