forked from yqsun91/WACCM-Emulation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathModel.py
106 lines (80 loc) · 2.68 KB
/
Model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import netCDF4 as nc
import numpy as np
import scipy.stats as st
import xarray as xr
import torch
from torch import nn
import torch.nn.utils.prune as prune
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
# Wraps paired feature/label arrays so a torch DataLoader can feed the NN.
class myDataset(Dataset):
    """Dataset of (feature, label) column pairs.

    X and Y are array-likes laid out feature-major: shapes
    (n_features, n_samples) and (n_labels, n_samples). Samples are
    COLUMNS, so indexing selects along the second axis.
    """

    def __init__(self, X, Y):
        # Copy inputs into float64 tensors to match the model's dtype.
        self.features = torch.tensor(X, dtype=torch.float64)
        self.labels = torch.tensor(Y, dtype=torch.float64)

    def __len__(self):
        # Number of samples = number of columns. The original used
        # len(self.features.T), which built a transposed view on every
        # call just to count columns; shape[1] is direct and equivalent.
        return self.features.shape[1]

    def __getitem__(self, idx):
        # One sample: column idx of both the feature and label tensors.
        return self.features[:, idx], self.labels[:, idx]
# The NN model.
class FullyConnected(nn.Module):
    """Fully connected float64 MLP with SiLU activations.

    The defaults reproduce the original hard-coded architecture exactly
    (same layer order, count, and Sequential indices, so checkpoints
    load unchanged): 564 -> 5000 -> 500 -> [500 x 10] -> 140.

    Parameters
    ----------
    in_features : int
        Size of the input vector (default 564).
    out_features : int
        Size of the output vector (default 140).
    first_hidden : int
        Width of the first (widest) hidden layer (default 5000).
    hidden : int
        Width of the repeated hidden layers (default 500).
    n_hidden_layers : int
        Number of hidden-to-hidden Linear+SiLU blocks (default 10).
    """

    def __init__(self, in_features=564, out_features=140,
                 first_hidden=5000, hidden=500, n_hidden_layers=10):
        super().__init__()
        layers = [
            nn.Linear(in_features, first_hidden, dtype=torch.float64),
            nn.SiLU(),
            nn.Linear(first_hidden, hidden, dtype=torch.float64),
            nn.SiLU(),
        ]
        # The original wrote these ten identical blocks out by hand;
        # generating them avoids the copy-paste repetition.
        for _ in range(n_hidden_layers):
            layers.append(nn.Linear(hidden, hidden, dtype=torch.float64))
            layers.append(nn.SiLU())
        # Output head: no activation (raw regression targets).
        layers.append(nn.Linear(hidden, out_features, dtype=torch.float64))
        self.linear_stack = nn.Sequential(*layers)

    def forward(self, X):
        """Apply the stack. X: float64 tensor of shape (..., in_features)."""
        return self.linear_stack(X)
# training loop
def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one optimization epoch; return the mean per-batch loss.

    The loss recorded per batch is the value BEFORE the optimizer step,
    matching the original behavior.
    NOTE(review): model.train() is not called here — harmless for the
    pure Linear/SiLU model in this file, but confirm callers set the
    mode if dropout/batch-norm layers are ever added.
    """
    avg_loss = 0.0
    # The original's `size = len(dataloader.dataset)` and the enumerate
    # batch index were never used; both are dropped.
    for X, Y in dataloader:
        pred = model(X)
        loss = loss_fn(pred, Y)
        # Backpropagation; set_to_none=True releases gradient memory.
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        # .item() already detaches, so the original's torch.no_grad()
        # wrapper around this accumulation was unnecessary.
        avg_loss += loss.item()
    # Guard the empty-dataloader case instead of raising ZeroDivisionError.
    return avg_loss / max(len(dataloader), 1)
# validating loop
def val_loop(dataloader, model, loss_fn):
    """Evaluate the model over dataloader; return the mean per-batch loss.

    Gradient tracking is disabled for the whole pass.
    NOTE(review): model.eval() is not called here — harmless for the
    pure Linear/SiLU model in this file, but confirm callers set the
    mode if dropout/batch-norm layers are ever added.
    """
    avg_loss = 0.0
    with torch.no_grad():
        # The enumerate batch index in the original was unused; dropped.
        for X, Y in dataloader:
            avg_loss += loss_fn(model(X), Y).item()
    # Guard the empty-dataloader case instead of raising ZeroDivisionError.
    return avg_loss / max(len(dataloader), 1)