-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathutils.py
34 lines (31 loc) · 1000 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import queue
import traceback
from _thread import start_new_thread
from functools import wraps

import torch.multiprocessing as mp


def prepare_mp(graph):
    """
    Touch a few query methods of ``graph`` once, discarding their results.

    Calls, in this order, ``graph.in_degrees(0)``, ``graph.out_degrees(0)``
    and ``graph.find_edges([0])``; nothing is returned. Judging by the name,
    the intent is to get ``graph`` ready before it is used across processes --
    presumably these calls force lazily built internal structures to exist
    up front. NOTE(review): confirm against the graph library's docs.

    Args:
        graph: object exposing ``in_degrees``, ``out_degrees`` and
            ``find_edges``. Presumably node/edge ID 0 must exist -- TODO confirm.
    """
    warmup_calls = (
        ("in_degrees", 0),
        ("out_degrees", 0),
        ("find_edges", [0]),
    )
    for method_name, probe_arg in warmup_calls:
        # Look up and invoke one method at a time, preserving the original order.
        getattr(graph, method_name)(probe_arg)


def fix_openmp(func):
    """
    Wraps a process entry point to make it work with OpenMP.

    Each call to the decorated function runs ``func(*args, **kwargs)`` in a
    freshly started thread (``_thread.start_new_thread``) and blocks the
    calling thread until that worker finishes. NOTE(review): presumably this
    works around OpenMP runtimes misbehaving when used from the main thread
    of a forked process -- confirm.

    Returns:
        The value returned by ``func``, handed back as the same object.

    Raises:
        If ``func`` raises an ``Exception``: an exception of the same class
        whose only argument is the worker thread's formatted traceback,
        chained (``__cause__``) to the original exception object. If that
        class cannot be built from a single string, the original exception
        is re-raised instead. Any other ``BaseException`` (``SystemExit``,
        ``KeyboardInterrupt``, ...) is re-raised unchanged.
    """
    @wraps(func)
    def decorated_function(*args, **kwargs):
        # Both ends live in the same process, so a plain thread-safe queue is
        # enough: no pickling, no feeder thread/pipe, and the result object is
        # returned as-is (unpicklable results no longer hang the caller).
        results = queue.Queue(maxsize=1)

        def _queue_result():
            exception, trace, res = None, None, None
            try:
                res = func(*args, **kwargs)
            except BaseException as e:
                # Catch everything: if nothing reaches the queue, the caller
                # would block on get() forever.
                exception = e
                trace = traceback.format_exc()
            results.put((res, exception, trace))

        start_new_thread(_queue_result, ())
        result, exception, trace = results.get()
        if exception is None:
            return result
        if not isinstance(exception, Exception):
            # SystemExit / KeyboardInterrupt etc.: propagate untouched so exit
            # codes and interrupt semantics are preserved.
            raise exception
        try:
            # Same exception class, with the worker's traceback as its message.
            wrapped = exception.__class__(trace)
        except Exception:
            # Constructor needs more than one argument (e.g. UnicodeDecodeError);
            # re-raise the original, whose __traceback__ still holds the
            # worker-thread frames.
            raise exception from None
        raise wrapped from exception
    return decorated_function