Commit f5d9142 1 parent b93e2ec commit f5d9142 Copy full SHA for f5d9142
File tree 1 file changed +11
-0
lines changed
1 file changed +11
-0
lines changed Original file line number Diff line number Diff line change @@ -356,12 +356,23 @@ def error_on_warning():
356
356
yield
357
357
358
358
359
def get_physical_device_indices(devices: List[int]) -> List[int]:
    """Translate logical CUDA device indices into physical GPU indices.

    When ``CUDA_VISIBLE_DEVICES`` is set, logical index ``i`` refers to the
    ``i``-th entry of that comma-separated list. This helper applies that
    mapping so callers can address GPUs by their physical index.

    Args:
        devices: Logical device indices, as seen by the current process.

    Returns:
        ``devices`` itself (unchanged) when ``CUDA_VISIBLE_DEVICES`` is not
        set. Otherwise a new list with the physical index of every logical
        index that has an entry in ``CUDA_VISIBLE_DEVICES``; logical indices
        without an entry are dropped.

    Raises:
        ValueError: If a non-empty entry of ``CUDA_VISIBLE_DEVICES`` is not
            an integer (for example a GPU UUID), since it cannot be mapped
            to a physical index here.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return devices

    visible_indices: List[int] = []
    for entry in visible_devices.split(","):
        entry = entry.strip()
        if not entry:
            # An empty entry (e.g. CUDA_VISIBLE_DEVICES="" or a trailing
            # comma) is not a device. Stop here instead of crashing in
            # int(""); this mirrors CUDA, which ignores everything from the
            # first invalid entry onwards.
            break
        visible_indices.append(int(entry))

    # Logical index -> physical index.
    index_mapping = dict(enumerate(visible_indices))
    return [index_mapping[i] for i in devices if i in index_mapping]
367
+
368
+
359
369
@_nvml ()
360
370
def wait_for_gpu_memory_to_clear (devices : List [int ],
361
371
threshold_bytes : int ,
362
372
timeout_s : float = 120 ) -> None :
363
373
# Use nvml instead of pytorch to reduce measurement error from torch cuda
364
374
# context.
375
+ devices = get_physical_device_indices (devices )
365
376
start_time = time .time ()
366
377
while True :
367
378
output : Dict [int , str ] = {}
You can’t perform that action at this time.
0 commit comments