-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtest_wrapper.py
63 lines (52 loc) · 1.97 KB
/
test_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import _setup_test_env # noqa
import sys
import unittest
import typing
import numpy
import torch
from pytorch_to_returnn import torch as torch_returnn
from pytorch_to_returnn.import_wrapper import wrapped_import_torch_traced
def assert_equal(msg, a, b):
    """
    Print a labelled comparison of two values, then assert that they compare equal.

    :param msg: label printed in front of the compared values
    :param a: left-hand value (printed via repr)
    :param b: right-hand value (printed via repr)
    :raises AssertionError: if ``a == b`` is false (and asserts are enabled)
    """
    report = f"check {msg} {a!r} == {b!r}"
    # Print before asserting so the values are visible even when the check fails.
    print(report)
    assert a == b
def test_torch_traced_wrapped_tensor():
    """
    The traced ``torch`` module must expose ``WrappedTorchTensor`` as its ``Tensor`` class,
    distinct from the real ``torch.Tensor``, and instances must be of that wrapped type.
    """
    from pytorch_to_returnn.import_wrapper.torch_wrappers.tensor import WrappedTorchTensor
    traced = wrapped_import_torch_traced("torch")
    # The traced module must not simply re-export the real tensor class.
    assert torch.Tensor is not traced.Tensor
    assert traced.Tensor is WrappedTorchTensor
    instance = traced.Tensor()
    assert isinstance(instance, traced.Tensor) and isinstance(instance, WrappedTorchTensor)
def test_torch_traced_from_numpy():
    """
    ``from_numpy`` on the traced ``torch`` module must return an instance of the
    traced ``Tensor`` class, i.e. of ``WrappedTorchTensor``.
    """
    from pytorch_to_returnn.import_wrapper.torch_wrappers.tensor import WrappedTorchTensor
    traced = wrapped_import_torch_traced("torch")
    source_array = numpy.array([1, 2, 3])
    converted = traced.from_numpy(source_array)
    assert isinstance(converted, traced.Tensor) and isinstance(converted, WrappedTorchTensor)
def test_torch_traced_wrapped_tensor_new():
    """
    The ``new_zeros`` / ``new`` factory methods of an empty traced tensor must produce
    the same result shapes as the same calls on an empty real ``torch.Tensor``.
    """
    traced = wrapped_import_torch_traced("torch")
    reference = torch.Tensor()
    wrapped = traced.Tensor()
    # (label, factory) pairs; each factory is applied to the reference tensor first,
    # then to the wrapped one, matching the argument order passed to assert_equal.
    checks = (
        ("new_zeros(3).shape:", lambda t: t.new_zeros(3)),
        ("new_zeros([3]).shape:", lambda t: t.new_zeros([3])),
        ("new(3).shape:", lambda t: t.new(3)),
    )
    for label, factory in checks:
        assert_equal(label, factory(reference).shape, factory(wrapped).shape)
if __name__ == "__main__":
    # Usage:
    #   python test_wrapper.py              -> run every module-level ``test_*`` global
    #   python test_wrapper.py ARG [ARG..]  -> for each ARG, call the global of that name,
    #                                          or eval() ARG as Python code otherwise
    if len(sys.argv) > 1:
        assert len(sys.argv) >= 2
        for arg in sys.argv[1:]:
            print("Executing: %s" % arg)
            if arg not in globals():
                # Not a module-level name: treat the argument as Python code.
                # NOTE(review): eval() runs arbitrary command-line input; acceptable only
                # because this is a developer-invoked test script.
                eval(arg)
            else:
                globals()[arg]()  # assume function and execute
    else:
        # Snapshot globals first: calling tests may add module-level names.
        for fn_name, fn in list(globals().items()):
            if not fn_name.startswith("test_"):
                continue
            print("-" * 40)
            print("Executing: %s" % fn_name)
            try:
                fn()
            except unittest.SkipTest as exc:
                # Skips are reported but do not abort the run; other exceptions propagate.
                print("SkipTest:", exc)
            print("-" * 40)
        print("Finished all tests.")