Skip to content

Commit 79e92c5

Browse files
authoredJul 26, 2022
Merge pull request #1043 from allenai/nanna-timeout
Improvements to how timeouts are handled.
2 parents f44c798 + 6303e01 commit 79e92c5

File tree

6 files changed

+226
-68
lines changed

6 files changed

+226
-68
lines changed
 

‎ai2thor/controller.py

+32-12
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919
import shutil
2020
import subprocess
2121
import time
22+
import traceback
2223
import uuid
2324
import warnings
2425
from collections import defaultdict, deque
2526
from functools import lru_cache
2627
from itertools import product
2728
from platform import architecture as platform_architecture
2829
from platform import system as platform_system
30+
from typing import Dict, Any, Union, Optional
2931

3032
import numpy as np
3133

@@ -392,6 +394,8 @@ def __init__(
392394
server_class=None,
393395
gpu_device=None,
394396
platform=None,
397+
server_timeout: Optional[float] = 100.0,
398+
server_start_timeout: float = 300.0,
395399
**unity_initialization_parameters,
396400
):
397401
self.receptacle_nearest_pivot_points = {}
@@ -401,6 +405,11 @@ def __init__(
401405
self.width = width
402406
self.height = height
403407

408+
self.server_timeout = server_timeout
409+
self.server_start_timeout = server_start_timeout
410+
assert self.server_timeout is None or 0 < self.server_timeout
411+
assert 0 < self.server_start_timeout
412+
404413
self.last_event = None
405414
self.scene = None
406415
self._scenes_in_build = None
@@ -577,7 +586,7 @@ def __init__(
577586
init_return = event.metadata["actionReturn"]
578587
if init_return:
579588
self.server.set_init_params(init_return)
580-
logging.info("Initialize return: {}".format(init_return))
589+
logging.info(f"Initialize return: {init_return}")
581590

582591
def _build_server(self, host, port, width, height):
583592

@@ -593,20 +602,22 @@ def _build_server(self, host, port, width, height):
593602

594603
if self.server_class == ai2thor.wsgi_server.WsgiServer:
595604
self.server = ai2thor.wsgi_server.WsgiServer(
596-
host,
605+
host=host,
606+
timeout=self.server_timeout,
597607
port=port,
598-
depth_format=self.depth_format,
599-
add_depth_noise=self.add_depth_noise,
600608
width=width,
601609
height=height,
610+
depth_format=self.depth_format,
611+
add_depth_noise=self.add_depth_noise,
602612
)
603613

604614
elif self.server_class == ai2thor.fifo_server.FifoServer:
605615
self.server = ai2thor.fifo_server.FifoServer(
606-
depth_format=self.depth_format,
607-
add_depth_noise=self.add_depth_noise,
608616
width=width,
609617
height=height,
618+
timeout=self.server_timeout,
619+
depth_format=self.depth_format,
620+
add_depth_noise=self.add_depth_noise,
610621
)
611622

612623
def __enter__(self):
@@ -913,9 +924,9 @@ def multi_step_physics(self, action, timeStep=0.05, max_steps=20):
913924

914925
return events
915926

916-
def step(self, action=None, **action_args):
927+
def step(self, action: Union[str, Dict[str, Any]]=None, **action_args):
917928

918-
if type(action) is dict:
929+
if isinstance(action, Dict):
919930
action = copy.deepcopy(action) # prevent changes from leaking
920931
else:
921932
action = dict(action=action)
@@ -960,18 +971,27 @@ def step(self, action=None, **action_args):
960971
self.server.send(action)
961972
try:
962973
self.last_event = self.server.receive()
963-
except UnityCrashException as e:
974+
except UnityCrashException:
964975
self.server.stop()
965976
self.server = None
966977
# we don't need to pass port or host, since this Exception
967978
# is only thrown from the FifoServer, start_unity is also
968979
# not passed since Unity would have to have been started
969980
# for this to be thrown
970-
message = "Restarting unity due to crash: %s" % e
981+
message = (
982+
f"Restarting unity due to crash when when running action {action}"
983+
f" in scene {self.last_event.metadata['sceneName']}:\n{traceback.format_exc()}"
984+
)
971985
warnings.warn(message)
972986
self.start(width=self.width, height=self.height, x_display=self.x_display)
973987
self.reset()
974988
raise RestartError(message)
989+
except Exception as e:
990+
self.server.stop()
991+
raise (TimeoutError if isinstance(e, TimeoutError) else RuntimeError)(
992+
f"Error encountered when running action {action}"
993+
f" in scene {self.last_event.metadata['sceneName']}."
994+
)
975995

976996
if not self.last_event.metadata["lastActionSuccess"]:
977997
if self.last_event.metadata["errorCode"] in [
@@ -1056,7 +1076,7 @@ def _start_unity_thread(self, env, width, height, server_params, image_name):
10561076
pass
10571077

10581078
self.unity_pid = proc.pid
1059-
atexit.register(lambda: proc.poll() is None and proc.kill())
1079+
atexit.register(lambda: self.server.stop())
10601080

10611081
@property
10621082
def tmp_dir(self):
@@ -1336,7 +1356,7 @@ def start(
13361356
self._start_unity_thread(env, width, height, unity_params, image_name)
13371357

13381358
# receive the first request
1339-
self.last_event = self.server.receive()
1359+
self.last_event = self.server.receive(timeout=self.server_start_timeout)
13401360

13411361
# we should be able to get rid of this since we check the resolution in .reset()
13421362
if self.server.unity_proc is not None and (height < 300 or width < 300):

‎ai2thor/fifo_server.py

+73-18
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,22 @@
55
Handles all communication with Unity through a Flask service. Messages
66
are sent to the controller using a pair of request/response queues.
77
"""
8-
import ai2thor.server
98
import json
10-
import msgpack
119
import os
10+
import select
11+
import struct
1212
import tempfile
13-
from ai2thor.exceptions import UnityCrashException
14-
from enum import IntEnum, unique
13+
import time
1514
from collections import defaultdict
16-
import struct
15+
from enum import IntEnum, unique
16+
from io import TextIOWrapper
17+
from typing import Optional
18+
19+
import msgpack
20+
21+
import ai2thor.server
22+
from ai2thor.exceptions import UnityCrashException
23+
1724

1825
# FifoFields
1926
@unique
@@ -45,17 +52,18 @@ class FifoServer(ai2thor.server.Server):
4552

4653
def __init__(
4754
self,
48-
width,
49-
height,
55+
width: int,
56+
height: int,
57+
timeout: Optional[float] = 100.0,
5058
depth_format=ai2thor.server.DepthFormat.Meters,
51-
add_depth_noise=False,
59+
add_depth_noise: bool = False,
5260
):
5361

5462
self.tmp_dir = tempfile.TemporaryDirectory()
5563
self.server_pipe_path = os.path.join(self.tmp_dir.name, "server.pipe")
5664
self.client_pipe_path = os.path.join(self.tmp_dir.name, "client.pipe")
57-
self.server_pipe = None
58-
self.client_pipe = None
65+
self.server_pipe: Optional[TextIOWrapper] = None
66+
self.client_pipe: Optional[TextIOWrapper] = None
5967
self.raw_metadata = None
6068
self.raw_files = None
6169
self._last_action_message = None
@@ -93,19 +101,52 @@ def __init__(
93101
}
94102

95103
self.eom_header = self._create_header(FieldType.END_OF_MESSAGE, b"")
96-
super().__init__(width, height, depth_format, add_depth_noise)
104+
super().__init__(
105+
width=width,
106+
height=height,
107+
timeout=timeout,
108+
depth_format=depth_format,
109+
add_depth_noise=add_depth_noise
110+
)
97111

98112
def _create_header(self, message_type, body):
99113
return struct.pack(self.header_format, message_type, len(body))
100114

101-
def _recv_message(self):
115+
def _read_with_timeout(self, server_pipe, message_size: int, timeout: Optional[float]):
116+
if timeout is None:
117+
return server_pipe.read(message_size)
118+
119+
start_t = time.time()
120+
message = b""
121+
122+
while message_size > 0:
123+
r, w, e = select.select([server_pipe], [], [], timeout)
124+
if server_pipe in r:
125+
part = os.read(server_pipe.fileno(), message_size)
126+
message_size -= len(part)
127+
message = message + part
128+
129+
cur_t = time.time()
130+
if timeout is not None and cur_t - start_t > timeout:
131+
break
132+
133+
if message_size != 0:
134+
raise TimeoutError(f"Reading from AI2-THOR backend timed out (using {timeout}s) timeout.")
135+
136+
return message
137+
138+
def _recv_message(self, timeout: Optional[float]):
102139
if self.server_pipe is None:
103140
self.server_pipe = open(self.server_pipe_path, "rb")
104141

105142
metadata = None
106143
files = defaultdict(list)
107144
while True:
108-
header = self.server_pipe.read(self.header_size) # message type + length
145+
header = self._read_with_timeout(
146+
server_pipe=self.server_pipe,
147+
message_size=self.header_size,
148+
timeout=self.timeout if timeout is None else timeout
149+
) # message type + length
109150
if len(header) == 0:
110151
self.unity_proc.wait(timeout=5)
111152
returncode = self.unity_proc.returncode
@@ -131,7 +172,13 @@ def _recv_message(self):
131172
# print("got header %s" % header)
132173
field_type_int, message_length = struct.unpack(self.header_format, header)
133174
field_type = self.field_types[field_type_int]
134-
body = self.server_pipe.read(message_length)
175+
176+
body = self._read_with_timeout(
177+
server_pipe=self.server_pipe,
178+
message_size=message_length,
179+
timeout=self.timeout if timeout is None else timeout
180+
)
181+
135182
# print("field type")
136183
# print(field_type)
137184
if field_type is FieldType.METADATA:
@@ -177,9 +224,11 @@ def _send_message(self, message_type, body):
177224
self.client_pipe.write(header + body + self.eom_header)
178225
self.client_pipe.flush()
179226

180-
def receive(self):
227+
def receive(self, timeout: Optional[float] = None):
181228

182-
metadata, files = self._recv_message()
229+
metadata, files = self._recv_message(
230+
timeout=self.timeout if timeout is None else timeout
231+
)
183232

184233
if metadata is None:
185234
raise ValueError("no metadata received from recv_message")
@@ -215,5 +264,11 @@ def unity_params(self):
215264
return params
216265

217266
def stop(self):
218-
self.client_pipe.close()
219-
self.server_pipe.close()
267+
if self.client_pipe is not None:
268+
self.client_pipe.close()
269+
270+
if self.server_pipe is not None:
271+
self.server_pipe.close()
272+
273+
if self.unity_proc is not None and self.unity_proc.poll() is None:
274+
self.unity_proc.kill()

0 commit comments

Comments
 (0)