19
19
import shutil
20
20
import subprocess
21
21
import time
22
+ import traceback
22
23
import uuid
23
24
import warnings
24
25
from collections import defaultdict , deque
25
26
from functools import lru_cache
26
27
from itertools import product
27
28
from platform import architecture as platform_architecture
28
29
from platform import system as platform_system
30
+ from typing import Dict , Any , Union , Optional
29
31
30
32
import numpy as np
31
33
@@ -392,6 +394,8 @@ def __init__(
392
394
server_class = None ,
393
395
gpu_device = None ,
394
396
platform = None ,
397
+ server_timeout : Optional [float ] = 100.0 ,
398
+ server_start_timeout : float = 300.0 ,
395
399
** unity_initialization_parameters ,
396
400
):
397
401
self .receptacle_nearest_pivot_points = {}
@@ -401,6 +405,11 @@ def __init__(
401
405
self .width = width
402
406
self .height = height
403
407
408
+ self .server_timeout = server_timeout
409
+ self .server_start_timeout = server_start_timeout
410
+ assert self .server_timeout is None or 0 < self .server_timeout
411
+ assert 0 < self .server_start_timeout
412
+
404
413
self .last_event = None
405
414
self .scene = None
406
415
self ._scenes_in_build = None
@@ -577,7 +586,7 @@ def __init__(
577
586
init_return = event .metadata ["actionReturn" ]
578
587
if init_return :
579
588
self .server .set_init_params (init_return )
580
- logging .info ("Initialize return: {}" . format ( init_return ) )
589
+ logging .info (f "Initialize return: { init_return } " )
581
590
582
591
def _build_server (self , host , port , width , height ):
583
592
@@ -593,20 +602,22 @@ def _build_server(self, host, port, width, height):
593
602
594
603
if self .server_class == ai2thor .wsgi_server .WsgiServer :
595
604
self .server = ai2thor .wsgi_server .WsgiServer (
596
- host ,
605
+ host = host ,
606
+ timeout = self .server_timeout ,
597
607
port = port ,
598
- depth_format = self .depth_format ,
599
- add_depth_noise = self .add_depth_noise ,
600
608
width = width ,
601
609
height = height ,
610
+ depth_format = self .depth_format ,
611
+ add_depth_noise = self .add_depth_noise ,
602
612
)
603
613
604
614
elif self .server_class == ai2thor .fifo_server .FifoServer :
605
615
self .server = ai2thor .fifo_server .FifoServer (
606
- depth_format = self .depth_format ,
607
- add_depth_noise = self .add_depth_noise ,
608
616
width = width ,
609
617
height = height ,
618
+ timeout = self .server_timeout ,
619
+ depth_format = self .depth_format ,
620
+ add_depth_noise = self .add_depth_noise ,
610
621
)
611
622
612
623
def __enter__ (self ):
@@ -913,9 +924,9 @@ def multi_step_physics(self, action, timeStep=0.05, max_steps=20):
913
924
914
925
return events
915
926
916
- def step (self , action = None , ** action_args ):
927
+ def step (self , action : Union [ str , Dict [ str , Any ]] = None , ** action_args ):
917
928
918
- if type (action ) is dict :
929
+ if isinstance (action , Dict ) :
919
930
action = copy .deepcopy (action ) # prevent changes from leaking
920
931
else :
921
932
action = dict (action = action )
@@ -960,18 +971,27 @@ def step(self, action=None, **action_args):
960
971
self .server .send (action )
961
972
try :
962
973
self .last_event = self .server .receive ()
963
- except UnityCrashException as e :
974
+ except UnityCrashException :
964
975
self .server .stop ()
965
976
self .server = None
966
977
# we don't need to pass port or host, since this Exception
967
978
# is only thrown from the FifoServer, start_unity is also
968
979
# not passed since Unity would have to have been started
969
980
# for this to be thrown
970
- message = "Restarting unity due to crash: %s" % e
981
+ message = (
982
+ f"Restarting unity due to crash when when running action { action } "
983
+ f" in scene { self .last_event .metadata ['sceneName' ]} :\n { traceback .format_exc ()} "
984
+ )
971
985
warnings .warn (message )
972
986
self .start (width = self .width , height = self .height , x_display = self .x_display )
973
987
self .reset ()
974
988
raise RestartError (message )
989
+ except Exception as e :
990
+ self .server .stop ()
991
+ raise (TimeoutError if isinstance (e , TimeoutError ) else RuntimeError )(
992
+ f"Error encountered when running action { action } "
993
+ f" in scene { self .last_event .metadata ['sceneName' ]} ."
994
+ )
975
995
976
996
if not self .last_event .metadata ["lastActionSuccess" ]:
977
997
if self .last_event .metadata ["errorCode" ] in [
@@ -1056,7 +1076,7 @@ def _start_unity_thread(self, env, width, height, server_params, image_name):
1056
1076
pass
1057
1077
1058
1078
self .unity_pid = proc .pid
1059
- atexit .register (lambda : proc . poll () is None and proc . kill ())
1079
+ atexit .register (lambda : self . server . stop ())
1060
1080
1061
1081
@property
1062
1082
def tmp_dir (self ):
@@ -1336,7 +1356,7 @@ def start(
1336
1356
self ._start_unity_thread (env , width , height , unity_params , image_name )
1337
1357
1338
1358
# receive the first request
1339
- self .last_event = self .server .receive ()
1359
+ self .last_event = self .server .receive (timeout = self . server_start_timeout )
1340
1360
1341
1361
# we should be able to get rid of this since we check the resolution in .reset()
1342
1362
if self .server .unity_proc is not None and (height < 300 or width < 300 ):
0 commit comments