forked from Farama-Foundation/ViZDoom
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest_pytorch.py
128 lines (101 loc) · 3.47 KB
/
test_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# E. Culurciello
# August 2017
from __future__ import division
from __future__ import print_function
from vizdoom import *
import itertools as it
from random import sample, randint, random
from time import time, sleep
import numpy as np
import skimage.color, skimage.transform
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.autograd import Variable
# ---- Settings ---------------------------------------------------------------
# Presumably carried over from the training script: none of these five values
# is read by the evaluation code in this file.
batch_size = 64                # NN learning settings
test_episodes_per_epoch = 100  # training regime
save_model = True
load_model = False
skip_learning = False

# Values the evaluation code below actually uses.
frame_repeat = 12                    # advance_action() calls per chosen action
resolution = (30, 45)                # (rows, cols) each screen is resized to
episodes_to_watch = 10               # episodes played with the loaded model
model_savefile = "./model-doom.pth"  # path handed to torch.load

# Scenario configuration handed to initialize_vizdoom(). Alternatives:
#   "../../scenarios/rocket_basic.cfg"
#   "../../scenarios/basic.cfg"
config_file_path = "../../scenarios/simpler_basic.cfg"
# Converts and down-samples the input image
def preprocess(img):
    """Down-sample a screen buffer to ``resolution`` and cast it to float32.

    ``skimage.transform.resize`` produces a floating-point image of shape
    ``resolution`` (rows, cols); the cast yields the float32 array that is
    later turned into the network input with ``torch.from_numpy``.
    """
    return skimage.transform.resize(img, resolution).astype(np.float32)
class Net(nn.Module):
    """Convolutional Q-network: one screen in, one Q-value per action out.

    For a single-channel input of shape (N, 1, 30, 45) -- the default
    ``resolution`` -- conv1 (kernel 6, stride 3) yields (N, 8, 9, 14) and
    conv2 (kernel 3, stride 2) yields (N, 8, 4, 6), i.e. the 192 features
    per sample that ``fc1`` expects.

    The script restores a whole pickled model with ``torch.load``, so the
    layer attribute names used by ``forward`` (conv1, conv2, fc1, fc2) must
    stay in sync with the saved object.
    """

    def __init__(self, available_actions_count):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=6, stride=3)
        self.conv2 = nn.Conv2d(8, 8, kernel_size=3, stride=2)
        self.fc1 = nn.Linear(192, 128)
        self.fc2 = nn.Linear(128, available_actions_count)

    def forward(self, x):
        """Return Q-values of shape (N, available_actions_count).

        ``x`` is a batch of shape (N, 1, H, W). An unbatched (1, H, W) input,
        which recent PyTorch convolutions accept, is treated as N == 1.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # Flatten each sample on its own. A plain view(-1, 192) would silently
        # re-split a wrongly sized input into a different number of rows (one
        # 30x78 screen would become two "samples"); this way a size mismatch
        # surfaces as a shape error in fc1 instead.
        if x.dim() == 4:
            x = x.view(x.size(0), -1)
        else:
            x = x.view(1, -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
def get_q_values(state):
    """Evaluate the module-level ``model`` on the NumPy array ``state``.

    The script passes a float32 array reshaped to (1, 1, rows, cols). It is
    wrapped as a tensor sharing the array's memory (``torch.from_numpy``)
    and forwarded through the network; the network output is returned as-is.
    """
    return model(Variable(torch.from_numpy(state)))
def get_best_action(state):
    """Return the index of the highest Q-value for the first sample in ``state``.

    Q-values come from ``get_q_values``; the arg-max is taken along
    dimension 1 and element 0 of the resulting index array is returned.
    """
    _, best = torch.max(get_q_values(state), 1)
    return best.data.numpy()[0]
# Creates and initializes ViZDoom environment.
def initialize_vizdoom(config_file_path):
    """Build a DoomGame from ``config_file_path`` and start it.

    After the config file is loaded, the window is made visible and the game
    is set to PLAYER mode with 640x480 GRAY8 screens; then ``init()`` is
    called. Progress messages are printed before and after. Returns the
    initialized DoomGame.
    """
    print("Initializing doom...")
    doom = DoomGame()
    doom.load_config(config_file_path)
    # Applied after load_config so these explicit choices are the ones in
    # effect when init() runs.
    doom.set_window_visible(True)
    doom.set_mode(Mode.PLAYER)
    doom.set_screen_format(ScreenFormat.GRAY8)
    doom.set_screen_resolution(ScreenResolution.RES_640X480)
    doom.init()
    print("Doom initialized.")
    return doom
if __name__ == '__main__':
    import inspect

    # Create Doom instance
    game = initialize_vizdoom(config_file_path)

    # Action = which buttons are pressed: every on/off combination of the
    # available buttons, indexed by the network's output position.
    n = game.get_available_buttons_size()
    actions = [list(a) for a in it.product([0, 1], repeat=n)]

    print("Loading model from: ", model_savefile)
    # torch.load restores a whole pickled nn.Module here, so only load files
    # you trust. PyTorch 2.6+ defaults to weights_only=True, which refuses such
    # pickles; request full unpickling where the keyword exists. Older
    # versions without the keyword already unpickle fully.
    load_kwargs = {}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(model_savefile, **load_kwargs)
    print("======================================")
    print("Testing trained neural network!")

    # Reinitialize the game with window visible, in ASYNC_PLAYER mode.
    # ViZDoom ignores these setters while the game is running and init() does
    # not restart a running game, so close it first -- otherwise the game
    # silently stays in the PLAYER mode chosen by initialize_vizdoom().
    game.close()
    game.set_window_visible(True)
    game.set_mode(Mode.ASYNC_PLAYER)
    game.init()

    for _ in range(episodes_to_watch):
        game.new_episode()
        while not game.is_episode_finished():
            state = preprocess(game.get_state().screen_buffer)
            state = state.reshape([1, 1, resolution[0], resolution[1]])
            best_action_index = get_best_action(state)
            # Instead of make_action(a, frame_repeat) in order to make the animation smooth
            game.set_action(actions[best_action_index])
            # In ASYNC_PLAYER mode the game advances at its real-time tic
            # rate, so advance_action() paces playback by itself; an extra
            # sleep here would let the game run additional tics per decision.
            for _ in range(frame_repeat):
                game.advance_action()
        # Sleep between episodes
        sleep(1.0)
        score = game.get_total_reward()
        print("Total score: ", score)