-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
75 lines (53 loc) · 2.11 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import argparse
import torch
import dataset
import model
def test(device, net, criterion, loader):
net.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in loader:
data = data.to(device)
target = target.to(device)
# Forward.
output = net(data)
# Sum up batch loss.
test_loss += criterion(output, target).item()
# Get the index of the max log-probability.
prediction = output.argmax(dim=1, keepdim=True)
correct += prediction.eq(target.view_as(prediction)).sum().item()
test_loss /= len(loader.dataset)
print('Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss,
correct,
len(loader.dataset),
100.0 * correct / len(loader.dataset)))
def main():
    """Parse CLI options, load a saved model and report its test-set metrics.

    Prints every parsed setting, seeds torch's RNG, loads the state dict for
    ``model.Model`` from ``<model-name>.pt``, and evaluates it on
    ``dataset.test_set()`` with a cross-entropy criterion via ``test``.
    """
    parser = argparse.ArgumentParser(description='MNIST PyTorch: Testing')
    parser.add_argument('--seed', type=int, default=1, help='Random seed (default: 1)')
    parser.add_argument('--test-batch-size', type=int, default=1000, help='Test batch size (default: 1000)')
    parser.add_argument('--model-name', type=str, default="mnist",
                        help='Model file name on disk, without the .pt extension (default: mnist)')
    parser.add_argument('--device', type=str, default='cpu',
                        help='Device to evaluate on, e.g. cpu or cuda (default: cpu)')
    args = parser.parse_args()
    print('+------------------------------')
    print('| Settings')
    print('+------------------------------')
    for arg in vars(args):
        print('{}: {}'.format(arg.replace('_', ' ').capitalize(), getattr(args, arg)))
    # Set the seed for generating random numbers.
    torch.manual_seed(args.seed)
    # Validate the device string early, before any expensive work.
    device = torch.device(args.device)
    # Load the model from disk.
    net = model.Model()
    # map_location remaps tensors saved on another device (e.g. a GPU used
    # for training) so the checkpoint also loads on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True on PyTorch versions that have it).
    state_dict = torch.load(args.model_name + '.pt', map_location=device)
    net.load_state_dict(state_dict)
    net.to(device)
    # Loss function.
    criterion = torch.nn.CrossEntropyLoss()
    print('Criterion: {}'.format(criterion))
    # Data loader.
    test_loader = torch.utils.data.DataLoader(dataset.test_set(), batch_size=args.test_batch_size)
    # Testing
    print('+------------------------------')
    print('| Testing')
    print('+------------------------------')
    test(device, net, criterion, test_loader)
# Run the evaluation only when executed as a script, never on import.
if __name__ == "__main__":
    main()