test.py
import argparse
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.utils.data import DataLoader
from dataloaders.dataset import BaseDataSets
from lib.network import ODOC_cdr_graph
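# Test-time configuration; the 'your_own_path' defaults below are placeholders to replace with local paths.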
parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str,
                    default='your_own_path', help='root path of the dataset')
parser.add_argument('--exp', type=str,
                    default='your_own_path', help='experiment_name')
parser.add_argument('--model', type=str,
                    default='your_own_path', help='model_name')
parser.add_argument('--max_iterations', type=int,
                    default=10000, help='maximum iteration number to train')
parser.add_argument('--batch_size', type=int, default=56,
                    help='batch_size per gpu')
parser.add_argument('--deterministic', type=int, default=1,
                    help='whether use deterministic training')
parser.add_argument('--base_lr', type=float, default=0.01,
                    help='segmentation network learning rate')
parser.add_argument('--patch_size', type=list, default=[256, 256],
                    help='patch size of network input')
parser.add_argument('--seed', type=int, default=1337, help='random seed')
# label and unlabel
parser.add_argument('--labeled_bs', type=int, default=28,
                    help='labeled_batch_size per gpu')
parser.add_argument('--labeled_num', type=int, default=10,
                    help='labeled data')
# costs
parser.add_argument('--dropout', type=float,
                    default=0.3, help='dropout rate')
parser.add_argument('--consistency_rampup', type=float,
                    default=200.0, help='consistency_rampup')
# note: argparse's type=bool treats any non-empty string as True
parser.add_argument('--viz', type=bool,
                    default=False, help='visualize predicted masks')
args = parser.parse_args()
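# Restrict execution to a single GPU and build the checkpoint path ('your_own_path' is a placeholder format string).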
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
snapshot_path = "your_own_path".format(
    args.exp, args.labeled_num, args.model, args.batch_size, args.dropout)
saved_model_path = os.path.join(snapshot_path, args.model + '_' + 'best_model_oc.pth')
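# Reproducibility: configure cudnn and fix all random seeds.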
if not args.deterministic:
    cudnn.benchmark = True
    cudnn.deterministic = False
else:
    cudnn.benchmark = False
    cudnn.deterministic = True
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
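# Build the network, move it to the GPU, wrap it in DataParallel, and restore the saved checkpoint (best_model_oc.pth) for evaluation.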
model = ODOC_cdr_graph(channel=64, k1=5000, k2=70, dropout=args.dropout)
model = model.cuda()
model = nn.DataParallel(model)
model.load_state_dict(torch.load(saved_model_path))
model.eval()
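# Load the test split and evaluate one image at a time.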
db_test = BaseDataSets(base_dir=args.root_path, split="test")
testloader = DataLoader(db_test, batch_size=1, shuffle=False,
                        num_workers=1)
# Run inference without gradient tracking; optionally visualize the predicted masks.
with torch.no_grad():
    for i_batch, (sampled_batch, path) in enumerate(testloader):
        volume_batch, label_batch, label_contour, label_cdr = (
            sampled_batch['image'], sampled_batch['label'],
            sampled_batch['con'], sampled_batch['cdr'])
        volume_batch = volume_batch.cuda()
        # The first two outputs are the region predictions and the signed-distance-map (SDM)
        # predictions; the remaining four outputs are not used at test time.
        pred_region, pred_sdm, _, _, _, _ = model(volume_batch)
        # Threshold the region predictions at 0.5 to get binary optic-cup (OC) and optic-disc (OD) masks.
        y_pred_OC_r = pred_region[:, 0, ...].cpu().detach().numpy().squeeze(0)
        y_pred_OC_r = (y_pred_OC_r > 0.5).astype(np.uint8)
        y_pred_OD_r = pred_region[:, 1, ...].cpu().detach().numpy().squeeze(0)
        y_pred_OD_r = (y_pred_OD_r > 0.5).astype(np.uint8)
        # Binarize the SDM predictions by thresholding at zero.
        y_pred_OC_sdm = pred_sdm[:, 0, ...].cpu().detach().numpy().squeeze(0)
        y_pred_OC_sdm[y_pred_OC_sdm > 0] = 1
        y_pred_OD_sdm = pred_sdm[:, 1, ...].cpu().detach().numpy().squeeze(0)
        y_pred_OD_sdm[y_pred_OD_sdm > 0] = 1
        if args.viz:
            plt.imshow(y_pred_OC_r, cmap='gray')
            plt.show()
            plt.imshow(y_pred_OD_r, cmap='gray')
            plt.show()
            plt.imshow(y_pred_OC_sdm, cmap='gray')
            plt.show()
            plt.imshow(y_pred_OD_sdm, cmap='gray')
            plt.show()
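# Example invocation (the paths and experiment names here are placeholders, not values from the repo):
#   python test.py --root_path /path/to/dataset --exp my_experiment --model odoc_cdr_graph --viz True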