forked from MidKnightXI/BlurWarp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path trainer.py
106 lines (78 loc) · 3.18 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import pandas as pd
import torch
import torch.backends.mps
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torchvision import transforms
from torchvision.datasets.folder import default_loader
import torchvision.models as models
def setup_device():
    """Pick the torch device to train on and report the choice.

    Preference order: CUDA (NVIDIA GPUs), then Apple's MPS backend, then the
    CPU as a fallback. The checks form an if/elif chain so a higher-priority
    backend is never overridden by a lower-priority one.

    Returns:
        torch.device: the device the model and batches should be moved to.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Training on {device}")
    return device
class BlurDetectionResNet(nn.Module):
    """ResNet-50 with a single-output head, squashed through a sigmoid.

    The stock ResNet-50 classifier layer (``fc``) is replaced by a
    ``Linear(in_features, 1)`` layer. ``forward`` therefore yields one value in
    (0, 1) per input image. The training code treats it as the probability of
    the "blurry" class and thresholds it at 0.5.

    NOTE: the backbone must remain registered under the attribute name
    ``resnet``; saved state dicts are keyed ``resnet.*``.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet50()
        # Swap the classifier for a one-unit head of matching input width.
        backbone.fc = nn.Linear(backbone.fc.in_features, 1)
        self.resnet = backbone

    def forward(self, x):
        """Return the sigmoid of the single-unit head's output for batch ``x``."""
        return self.resnet(x).sigmoid()
class BlurrySharpDataset(torch.utils.data.Dataset):
    """Map-style dataset of (image, label) pairs driven by an annotations table.

    Rows of ``annotations`` are read positionally via ``.iloc``. Column 0 is the
    image path and column 1 the label text. A label equal to ``'blurry'``
    becomes target ``1``; any other value becomes ``0``.

    Args:
        annotations: table indexed with ``.iloc[row, col]``, presumably a
            pandas DataFrame with at least two columns (TODO confirm).
        transform: optional callable applied to each loaded image; skipped
            when falsy.
    """

    def __init__(self, annotations, transform=None):
        self.data = annotations
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        path = self.data.iloc[idx, 0]
        is_blurry = self.data.iloc[idx, 1] == 'blurry'
        picture = default_loader(path)
        picture = self.transform(picture) if self.transform else picture
        return picture, (1 if is_blurry else 0)
def _train_one_epoch(model, loader, criterion, optimizer, device, epoch, epochs, log_every):
    # Run one optimisation pass over `loader`; return (mean loss, accuracy) for the epoch.
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    for batch_idx, (data, target) in enumerate(loader):
        data = data.to(device)
        # Reshape targets to (batch, 1) floats so they match the model output and the loss.
        label_batch = target.to(device).view(-1, 1).float()

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, label_batch)
        loss.backward()
        optimizer.step()

        # Threshold the sigmoid output at 0.5 to get hard 0/1 predictions.
        pred = (output.detach() >= 0.5).float()
        correct = pred.eq(label_batch).sum().item()
        count = label_batch.size(0)
        batch_loss = loss.item()

        total_loss += batch_loss * count
        total_correct += correct
        total_seen += count

        if log_every and batch_idx % log_every == 0:
            # correct / count is accuracy (share of right predictions), not precision.
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {batch_loss:.4f}, Accuracy: {correct / count:.4f}')

    if total_seen == 0:
        return float("nan"), float("nan")
    return total_loss / total_seen, total_correct / total_seen


def train_blur_detection_model(annotations_path, output_model_path, *, epochs=30,
                               batch_size=64, learning_rate=0.001, log_every=100):
    """Train a BlurDetectionResNet on annotated images and save its weights.

    Args:
        annotations_path: CSV read with ``pandas.read_csv``. It is fed to
            ``BlurrySharpDataset``, which takes column 0 as the image path and
            column 1 as the label (``'blurry'`` -> 1, anything else -> 0).
        output_model_path: file the model's ``state_dict`` is written to with
            ``torch.save``, after moving the model to the CPU.
        epochs: number of passes over the dataset (default 30).
        batch_size: DataLoader mini-batch size (default 64).
        learning_rate: initial Adam learning rate (default 0.001). ``StepLR``
            halves it every 3 epochs.
        log_every: print batch loss/accuracy every this many batches; 0
            disables the per-batch log (epoch summaries are always printed).

    Raises:
        ValueError: if the annotations file has no rows. Without this check,
            training would silently save an untrained model.
    """
    device = setup_device()

    annotations = pd.read_csv(annotations_path)
    if annotations.empty:
        raise ValueError(f"No annotations found in {annotations_path}")

    transform = transforms.Compose([
        transforms.Resize((256, 256), antialias=True),
        transforms.ToTensor(),
    ])
    train_dataset = BlurrySharpDataset(annotations, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    model = BlurDetectionResNet()
    model.to(device)
    # The model already outputs sigmoid probabilities, so binary cross-entropy
    # on probabilities is the matching loss for 0/1 targets. MSE on sigmoid
    # outputs gives vanishing gradients when predictions saturate.
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)

    for epoch in range(epochs):
        mean_loss, accuracy = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch, epochs, log_every
        )
        scheduler.step()
        print(f'Epoch [{epoch+1}/{epochs}] finished, Mean loss: {mean_loss:.4f}, Accuracy: {accuracy:.4f}')

    # Save from the CPU so the checkpoint loads without the training accelerator.
    model.to(torch.device("cpu"))
    torch.save(model.state_dict(), output_model_path)
    print(f'Model saved to {output_model_path}')
def parse_cli_args(argv=None):
    """Parse the training script's command-line arguments.

    Both positional arguments are optional. Their defaults are the previously
    hard-coded paths, so running the script with no arguments behaves as before.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``annotations_path`` and ``output_model_path``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Train the blur-detection ResNet.")
    parser.add_argument(
        "annotations_path",
        nargs="?",
        default='dataset/annotations.csv',
        help="CSV of image paths and labels (default: %(default)s)",
    )
    parser.add_argument(
        "output_model_path",
        nargs="?",
        default='blur_detection_model.tch',
        help="where to save the trained state_dict (default: %(default)s)",
    )
    return parser.parse_args(argv)


if __name__ == '__main__':
    cli_args = parse_cli_args()
    train_blur_detection_model(cli_args.annotations_path, cli_args.output_model_path)