-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmymodel.py
31 lines (23 loc) · 918 Bytes
/
mymodel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
class CCANet(torch.nn.Module):
    """Small two-stage convolutional classifier.

    Architecture: two ``Conv2d(3x3, stride 1, padding 1) -> ReLU -> MaxPool(2x2)``
    stages (3 -> 20 -> 50 channels), then a 250-unit hidden linear layer with
    ReLU and a final linear layer producing ``num_class`` logits.

    Args:
        num_class: Number of output logits per sample.
        input_size: Height/width of the square input images. The first linear
            layer's fan-in is derived from it; the default of 128 reproduces
            the original hard-coded ``50 * 32 * 32`` features.

    Raises:
        ValueError: If ``input_size`` is too small to survive both pooling
            stages (i.e. less than 4).
    """

    def __init__(self, num_class=5, input_size=128):
        super().__init__()
        if input_size < 4:
            raise ValueError(
                f"input_size must be >= 4 to survive two 2x2 pools, got {input_size}"
            )
        # network layers
        # 3x3 convs with stride 1 and padding 1 preserve spatial size; only the
        # two 2x2/stride-2 max-pools shrink it (each floors by a factor of 2).
        self.conv1 = torch.nn.Conv2d(3, 20, kernel_size=3, stride=1, padding=1)
        self.conv2 = torch.nn.Conv2d(20, 50, kernel_size=3, stride=1, padding=1)
        self.relu = torch.nn.ReLU(inplace=True)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        pooled = input_size // 2 // 2
        self.linear1 = torch.nn.Linear(50 * pooled * pooled, 250)
        self.linear2 = torch.nn.Linear(250, num_class)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Tensor of shape ``(batch, 3, input_size, input_size)``.

        Returns:
            Tensor of shape ``(batch, num_class)`` of unnormalized scores.
        """
        x = self.maxpool(self.relu(self.conv1(x)))
        x = self.maxpool(self.relu(self.conv2(x)))
        # Flatten per sample, keeping dim 0. A bare ``x.flatten()`` would fold
        # the batch into the feature axis and break for batch sizes > 1.
        x = torch.flatten(x, start_dim=1)
        h = self.relu(self.linear1(x))
        return self.linear2(h)
if __name__ == '__main__':
    # Smoke run: push one random 3-channel 128x128 image through a freshly
    # initialized (untrained) model and print the resulting logits.
    net = CCANet()
    sample = torch.randn(1, 3, 128, 128)
    print(net(sample))