-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCNN_crf_model.py
109 lines (88 loc) · 3.71 KB
/
CNN_crf_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import tifffile as tiff
from torch.utils.data import Dataset, DataLoader
import os
class TiffDataset(Dataset):
    """Dataset of (image, mask) pairs loaded from two directories of TIFF files.

    Images and masks are paired by position after sorting each directory's
    filenames, so the i-th image (in sorted order) goes with the i-th mask.
    Hidden entries (names starting with ".") and sub-directories are ignored.

    Args:
        image_dir: directory containing the input images.
        mask_dir: directory containing the label masks.

    Raises:
        ValueError: if the two directories hold a different number of files,
            since positional pairing would then be ambiguous.
    """

    def __init__(self, image_dir, mask_dir):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # os.listdir returns entries in arbitrary order; sort both lists so the
        # positional pairing in __getitem__ is deterministic.
        self.image_filenames = self._list_files(image_dir)
        self.mask_filenames = self._list_files(mask_dir)
        if len(self.image_filenames) != len(self.mask_filenames):
            raise ValueError(
                f"{image_dir!r} has {len(self.image_filenames)} files but "
                f"{mask_dir!r} has {len(self.mask_filenames)}; cannot pair them"
            )

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of the regular, non-hidden files in ``directory``."""
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load pair ``idx`` as ``(image float32 tensor, mask int64 tensor)``.

        Tensors keep whatever shape the TIFF file stores (presumably (H, W)
        for single-channel data; confirm against the data).
        """
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])
        # Convert through NumPy first: torch.tensor cannot ingest some TIFF
        # dtypes (e.g. uint16) directly on older PyTorch releases.
        image = np.array(tiff.imread(image_path), dtype=np.float32)
        mask = np.array(tiff.imread(mask_path), dtype=np.int64)
        return torch.from_numpy(image), torch.from_numpy(mask)
# Define the CRF layer in PyTorch
class CRFLayer(nn.Module):
    """Linear-chain CRF over the pixels of an image taken in raster-scan order.

    Logits of shape (B, C, H, W) are flattened row by row into a sequence of
    H * W positions per image. Consecutive positions in that sequence are
    linked by a learned transition score. Only horizontal neighbours (plus the
    wrap from the end of one row to the start of the next) are modelled;
    vertical neighbours are not.

    Attributes:
        n_classes: number of labels C.
        transitions: (C, C) parameter; ``transitions[j, i]`` is the score of
            moving from label ``i`` at one position to label ``j`` at the next.
    """

    def __init__(self, n_classes):
        super(CRFLayer, self).__init__()
        self.n_classes = n_classes
        self.transitions = nn.Parameter(torch.randn(n_classes, n_classes))

    def forward(self, logits, mask):
        """Return ``(loss, decoded_labels)`` for ``logits`` and gold labels ``mask``."""
        return self.compute_log_likelihood(logits, mask), self.decode(logits)

    def _flatten(self, logits):
        """Reshape (B, C, H, W) logits into (B, H*W, C) emission scores."""
        batch_size, n_classes, height, width = logits.shape
        if n_classes != self.n_classes:
            raise ValueError(
                f"logits have {n_classes} channels, expected {self.n_classes}"
            )
        return logits.permute(0, 2, 3, 1).reshape(batch_size, height * width, n_classes)

    def compute_log_likelihood(self, logits, mask):
        """Return the batch-mean *negative* log-likelihood of ``mask``.

        The value is negated so it can be minimised directly as a training loss.

        Args:
            logits: (B, C, H, W) per-pixel emission scores.
            mask: integer labels with B * H * W elements (e.g. (B, H, W)),
                each in ``[0, n_classes)``.

        Returns:
            Scalar tensor ``mean_b(log Z_b - score(mask_b))``.

        Raises:
            ValueError: if ``mask`` does not match the logits' spatial size or
                contains labels outside ``[0, n_classes)``.
        """
        emissions = self._flatten(logits)
        batch_size, seq_len, _ = emissions.shape
        tags = mask.reshape(batch_size, -1).long().to(emissions.device)
        if tags.shape[1] != seq_len:
            raise ValueError(
                f"mask has {tags.shape[1]} labels per image, expected {seq_len}"
            )
        if tags.min() < 0 or tags.max() >= self.n_classes:
            # A common cause is a 0/255 binary mask that was never remapped to 0/1.
            raise ValueError(
                f"mask labels must lie in [0, {self.n_classes}), got "
                f"[{int(tags.min())}, {int(tags.max())}]"
            )

        # Forward algorithm, vectorised over labels:
        #   alpha[b, j] = logsumexp_i(alpha_prev[b, i] + transitions[j, i]) + emit[b, t, j]
        trans = self.transitions.unsqueeze(0)  # (1, C_to, C_from)
        alpha = emissions[:, 0]
        for t in range(1, seq_len):
            alpha = torch.logsumexp(alpha.unsqueeze(1) + trans, dim=2) + emissions[:, t]
        log_partition = torch.logsumexp(alpha, dim=1)  # (B,)

        # Unnormalised score of the gold labelling: emissions + transitions.
        emit_score = emissions.gather(2, tags.unsqueeze(2)).squeeze(2).sum(dim=1)
        trans_score = self.transitions[tags[:, 1:], tags[:, :-1]].sum(dim=1)
        gold_score = emit_score + trans_score

        return (log_partition - gold_score).sum() / batch_size

    def decode(self, logits):
        """Return the highest-scoring (Viterbi) labelling of each image.

        Args:
            logits: (B, C, H, W) per-pixel emission scores.

        Returns:
            (B, H*W) int64 tensor of labels in raster-scan order.
        """
        with torch.no_grad():
            emissions = self._flatten(logits)
            seq_len = emissions.shape[1]
            trans = self.transitions.unsqueeze(0)  # (1, C_to, C_from)
            score = emissions[:, 0]
            backpointers = []
            for t in range(1, seq_len):
                # best_prev[b, j] = best previous label i when label j is at t.
                best_score, best_prev = (score.unsqueeze(1) + trans).max(dim=2)
                score = best_score + emissions[:, t]
                backpointers.append(best_prev)
            best_label = score.argmax(dim=1)
            path = [best_label]
            for best_prev in reversed(backpointers):
                best_label = best_prev.gather(1, best_label.unsqueeze(1)).squeeze(1)
                path.append(best_label)
            path.reverse()
            return torch.stack(path, dim=1)
# Create a simple neural network with a CRF layer
class ToyModel(nn.Module):
    """A 1x1-convolution emission model followed by a CRF over pixels.

    The convolution has ``in_channels=1``, so input is expected to be
    single-channel, e.g. (batch, 1, height, width).

    Args:
        n_classes: number of output labels.
    """

    def __init__(self, n_classes):
        super(ToyModel, self).__init__()
        self.conv = nn.Conv2d(1, n_classes, kernel_size=1)
        self.crf = CRFLayer(n_classes)

    def forward(self, x):
        """Return per-pixel class logits (batch, n_classes, H, W) for a (batch, 1, H, W) input."""
        return self.conv(x)

    def loss(self, x, y):
        """Return the scalar CRF training loss for inputs ``x`` and labels ``y``.

        This calls ``compute_log_likelihood`` directly rather than ``self.crf(...)``.
        ``CRFLayer.forward`` returns a ``(loss, decoded)`` tuple, which cannot
        be passed to ``backward()``.
        """
        logits = self.forward(x)
        return self.crf.compute_log_likelihood(logits, y)

    def predict(self, x):
        """Return the CRF decoding of the logits for ``x`` (no autograd graph is built)."""
        with torch.no_grad():
            return self.crf.decode(self.forward(x))
if __name__ == '__main__':
    # Set the image and mask directories
    image_dir = './data/train_1/imgs'
    mask_dir = './data/train_1/masks'
    # Create the dataset and data loader
    dataset = TiffDataset(image_dir, mask_dir)
    if len(dataset) == 0:
        # Otherwise the epoch loop never assigns `loss` and the print fails.
        raise SystemExit(f"No training images found in {image_dir}")
    data_loader = DataLoader(dataset, batch_size=4, shuffle=True)
    # Instantiate the model and optimizer
    n_classes = 2
    model = ToyModel(n_classes)
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    # Train the model
    num_epochs = 50
    for epoch in range(num_epochs):
        model.train()
        for images, masks in data_loader:
            # Single-channel TIFFs batch up as (B, H, W); Conv2d(1, ...) needs
            # an explicit channel axis, (B, 1, H, W).
            if images.dim() == 3:
                images = images.unsqueeze(1)
            optimizer.zero_grad()
            loss = model.loss(images, masks)
            loss.backward()
            optimizer.step()
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}")
    # Perform inference on a new image (placeholder path: replace before use)
    test_path = 'path/to/test/image.tiff'
    if os.path.exists(test_path):
        image_test = tiff.imread(test_path)
        # Assumes a 2-D (H, W) image; add batch and channel axes -> (1, 1, H, W).
        image_test = torch.tensor(image_test, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        model.eval()
        prediction = model.predict(image_test)
        print(f"Prediction shape: {tuple(prediction.shape)}")
    else:
        print(f"Test image not found: {test_path}; skipping inference")
    print('End CRF')