-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
80 lines (68 loc) · 3.04 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import argparse
import os
import random
import socket
import yaml
import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
import numpy as np
import torchvision
import models
import datasets
import utils
from models import DenoisingDiffusion
# Example invocations. --test_set takes a single name; the names in the
# "(alternatives: ...)" notes are other --test_set values listed with that config.
#CUDA_VISIBLE_DEVICES=1,2 python train.py --config daytime_64.yml --test_set Raindrop_DiT   (alternatives: RDiffusion, onego)
#CUDA_VISIBLE_DEVICES=1,2 python train.py --config daytime_128.yml --test_set IDT           (alternatives: restormer)
#CUDA_VISIBLE_DEVICES=1,2 python train.py --config daytime_256.yml --test_set ICRA256       (alternatives: Uformer, atgan)
#CUDA_VISIBLE_DEVICES=1,2 python train.py --config daytime_64.yml --test_set Raindrop_DiT
#CUDA_VISIBLE_DEVICES=1,2 python train.py --config nighttime_256.yml --test_set atgan
def parse_args_and_config(argv=None):
    """Parse command-line arguments and load the YAML training config.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps argparse's normal
            behaviour of reading the real command line.

    Returns:
        A ``(args, config)`` tuple. ``args`` is the ``argparse.Namespace`` of
        CLI options. ``config`` is the YAML file ``configs/<args.config>``
        converted into nested ``argparse.Namespace`` objects by
        ``dict2namespace``.

    Raises:
        OSError: if the config file cannot be opened (e.g. it does not exist).
        ValueError: if the config file is empty or its top level is not a
            YAML mapping.
    """
    parser = argparse.ArgumentParser(description='Training Raindrop Clarity with Patch-Based Denoising Diffusion Models')
    parser.add_argument("--config", type=str, default='nighttime_64.yml',
                        help="Path to the config file")
    parser.add_argument('--resume', default='', type=str,
                        help='Path for checkpoint to load and resume')
    parser.add_argument("--sampling_timesteps", type=int, default=25,
                        help="Number of implicit sampling steps for validation image patches")
    parser.add_argument("--test_set", type=str, default='Raindrop_DiT',
                        help="restoration test set results: ['Raindrop_DiT', 'RDiffusion', 'IDT', 'restormer', 'Uformer', 'ICRA256', 'onego', 'atgan']")
    parser.add_argument("--image_folder", default='results/images/', type=str,
                        help="Location to save restored validation image patches")
    parser.add_argument('--seed', default=61, type=int, metavar='N',
                        help='Seed for initializing training (default: 61)')
    args = parser.parse_args(argv)

    # --config is resolved inside the "configs" directory; an absolute path
    # passes through os.path.join unchanged.
    config_path = os.path.join("configs", args.config)
    # Explicit encoding so the YAML is decoded the same way on every platform
    # instead of depending on the locale's default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # safe_load returns None for an empty document (and a scalar/list for
    # non-mapping YAML). Fail here with a clear message instead of an
    # AttributeError inside dict2namespace.
    if not isinstance(config, dict):
        raise ValueError(
            "Config file {!r} must contain a YAML mapping at the top level, "
            "got {}".format(config_path, type(config).__name__))
    new_config = dict2namespace(config)
    return args, new_config
def dict2namespace(config):
    """Recursively turn a mapping into an ``argparse.Namespace``.

    Every value that is itself a ``dict`` becomes a nested namespace, so
    settings can be read with attribute access (``config.data.dataset``).
    All other values, including lists, are stored as-is.
    """
    fields = {
        name: dict2namespace(item) if isinstance(item, dict) else item
        for name, item in config.items()
    }
    return argparse.Namespace(**fields)
def main():
    """Script entry point: parse options, seed RNGs, build the dataset and train.

    Steps:
        1. Load ``args``/``config`` via ``parse_args_and_config``.
        2. Store the selected torch device on ``config.device``.
        3. Seed Python's ``random``, NumPy and torch (all CUDA devices when
           available) with ``--seed``.
        4. Instantiate the dataset class named by ``config.data.dataset``
           from the ``datasets`` module.
        5. Hand the dataset to ``DenoisingDiffusion(args, config).train``.
    """
    args, config = parse_args_and_config()

    # setup device to run
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print("Using device: {}".format(device))
    config.device = device

    # set random seed. Python's ``random`` is seeded too: it is imported by
    # this script but was previously never seeded, so any code drawing from
    # it (presumably data augmentation in other modules -- not visible here)
    # would not have been reproducible.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    # Benchmark mode lets cuDNN time candidate convolution algorithms and use
    # the fastest. Per PyTorch's reproducibility notes, the choice can differ
    # between runs, so results are not guaranteed bit-exact even with the
    # seeds above.
    torch.backends.cudnn.benchmark = True

    # data loading: the dataset class is looked up by name in the datasets module
    print("=> using dataset '{}'".format(config.data.dataset))
    DATASET = datasets.__dict__[config.data.dataset](config)

    # create model
    print("=> creating model...")
    diffusion = DenoisingDiffusion(args, config)
    diffusion.train(DATASET)


if __name__ == "__main__":
    main()