diff --git a/hypercoast/chla.py b/hypercoast/chla.py index c95506a..ce4e1a6 100644 --- a/hypercoast/chla.py +++ b/hypercoast/chla.py @@ -14,3 +14,61 @@ from rasterio.transform import from_origin from rasterio.warp import reproject, Resampling from scipy.interpolate import griddata + + +class VAE(nn.Module): + def __init__(self, input_dim, output_dim): + super().__init__() + + # encoder + self.encoder_layer = nn.Sequential( + nn.Linear(input_dim, 64), + nn.BatchNorm1d(64), + nn.LeakyReLU(0.2), + nn.Linear(64, 64), + nn.BatchNorm1d(64), + nn.LeakyReLU(0.2), + ) + + self.fc1 = nn.Linear(64, 32) + self.fc2 = nn.Linear(64, 32) + + # decoder + self.decoder = nn.Sequential( + nn.Linear(32, 64), + nn.BatchNorm1d(64), + nn.LeakyReLU(0.2), + nn.Linear(64, 64), + nn.BatchNorm1d(64), + nn.LeakyReLU(0.2), + nn.Linear(64, output_dim), + nn.Softplus(), + ) + + def encode(self, x): + x = self.encoder_layer(x) + mu = self.fc1(x) + log_var = self.fc2(x) + return mu, log_var + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + z = mu + eps * std + return z + + def decode(self, z): + return self.decoder(z) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_reconstructed = self.decode(z) + return x_reconstructed, mu, log_var + + +def loss_function(recon_x, x, mu, log_var): + L1 = F.l1_loss(recon_x, x, reduction="mean") + BCE = F.mse_loss(recon_x, x, reduction="mean") + KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) + return L1