diff --git a/onnxscript/ir/_convenience.py b/onnxscript/ir/_convenience.py index 52ed5fb88..d59bfe479 100644 --- a/onnxscript/ir/_convenience.py +++ b/onnxscript/ir/_convenience.py @@ -360,7 +360,7 @@ def tensor( elif str(type(value)) == "": # NOTE: We use str(type(...)) and do not import torch for type checking # as it creates overhead during import - return tensor_adapters.TorchTensor(value, name=name, doc_string=doc_string) + return tensor_adapters.TorchTensor(value, name=name, doc_string=doc_string) # type: ignore[arg-type] elif isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)): return _core.Tensor(value, dtype=dtype, name=name, doc_string=name) diff --git a/onnxscript/ir/_convenience_test.py b/onnxscript/ir/_convenience_test.py index 37b08608f..c293a0097 100644 --- a/onnxscript/ir/_convenience_test.py +++ b/onnxscript/ir/_convenience_test.py @@ -11,7 +11,7 @@ class ConvenienceTest(unittest.TestCase): def test_tensor_accepts_torch_tensor(self): - import torch as some_random_name + import torch as some_random_name # pylint: disable=import-outside-toplevel torch_tensor = some_random_name.tensor([1, 2, 3]) tensor = _convenience.tensor(torch_tensor)