Question about fullgraph=True compatibility ? #32

Closed
mitkotak opened this issue Nov 24, 2024 · 0 comments

I was wondering whether there is any way to work around the following error when compiling with fullgraph=True?

>>> import cuequivariance_torch as cuet
>>> import cuequivariance as cue
>>> m = cuet.FullyConnectedTensorProduct(irreps_in1=cue.Irreps("O3", "32x0e + 32x1o"),
...                                     irreps_in2=cue.Irreps("O3", "32x0e + 32x1o"),
...                                     irreps_out=cue.Irreps("O3", "32x0e + 32x1o"))
/home/mkotak/.local/lib/python3.10/site-packages/cuequivariance/irreps_array/misc_ui.py:25: UserWarning: layout is not specified, defaulting to cue.mul_ir. This is the layout used in the e3nn library. We use it as the default layout for compatibility with e3nn. However, the cue.ir_mul layout is faster and more memory efficient. Please specify the layout explicitly to avoid this warning.
  warnings.warn(
>>> import torch
>>> m_compile = torch.compile(m, fullgraph=True)
>>> input1 = torch.randn(100, m.irreps_in1.dim)
>>> input2 = torch.randn(100, m.irreps_in2.dim)
>>> m_compile(input1, input2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1692, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 899, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1744, in LOAD_ATTR
    self._load_attr(inst)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1734, in _load_attr
    result = BuiltinVariable(getattr).call_function(
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 967, in call_function
    return handler(tx, args, kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 711, in <lambda>
    return lambda tx, args, kwargs: obj.call_function(
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 967, in call_function
    return handler(tx, args, kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 848, in builtin_dispatch
    rv = fn(tx, args, kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 766, in call_self_handler
    result = self_handler(tx, *args, **kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1727, in call_getattr
    return obj.var_getattr(tx, name)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 1062, in var_getattr
    ).call_function(tx, [], {})
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 385, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1967, in UNPACK_SEQUENCE
    unimplemented(f"UNPACK_SEQUENCE {seq}")
  File "/home/mkotak/.local/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 297, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: UNPACK_SEQUENCE FrozenDataClassVariable(MulIrrep)

from user code:
   File "/home/mkotak/.local/lib/python3.10/site-packages/cuequivariance_torch/operations/tp_fully_connected.py", line 151, in forward
    return self.f(weight, x1, x2, use_fallback=use_fallback)
  File "/home/mkotak/.local/lib/python3.10/site-packages/cuequivariance_torch/primitives/equivariant_tensor_product.py", line 153, in forward
    assert a.shape[-1] == b.irreps.dim
  File "/home/mkotak/.local/lib/python3.10/site-packages/cuequivariance/irreps_array/irreps.py", line 209, in dim
    return sum(mul * rep.dim for mul, rep in self)
  File "/home/mkotak/.local/lib/python3.10/site-packages/cuequivariance/irreps_array/irreps.py", line 209, in <genexpr>
    return sum(mul * rep.dim for mul, rep in self)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

>>> 
mitkotak changed the title from "fullgraph=True compatibility ?" to "Question about fullgraph=True compatibility ?" on Nov 24, 2024
mariogeiger added a commit that referenced this issue Dec 2, 2024
* avoid calling irreps.dim and logger in forward

* add tests

* optimize_fallback=True
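
The first bullet of that commit message ("avoid calling irreps.dim ... in forward") suggests a general pattern: evaluate the property once at construction time and keep only a plain integer for the shape check in `forward`. The sketch below illustrates that pattern in pure Python. It is not the code from the commit; `MulRep`, `Reps`, and `Layer` are hypothetical stand-ins, not cuequivariance classes, and a list stands in for a tensor so the example runs without PyTorch.

```python
from dataclasses import dataclass
from typing import Iterator, List, Tuple


@dataclass(frozen=True)
class MulRep:
    """Hypothetical stand-in for a (multiplicity, representation) entry."""

    mul: int
    rep_dim: int

    def __iter__(self) -> Iterator[int]:
        # Enables `mul, rep_dim = entry`, the same kind of tuple unpacking
        # that appears at the bottom of the traceback above.
        yield self.mul
        yield self.rep_dim


class Reps:
    """Hypothetical stand-in for an irreps container."""

    def __init__(self, entries: Tuple[MulRep, ...]) -> None:
        self.entries = entries

    @property
    def dim(self) -> int:
        # Iterates over the entries and unpacks each one on every access.
        return sum(mul * rep_dim for mul, rep_dim in self.entries)


class Layer:
    """Hypothetical module-like class that caches the dimension."""

    def __init__(self, reps: Reps) -> None:
        self.reps = reps
        # Evaluate the property once, at construction time, and keep an int.
        self.in_dim: int = reps.dim

    def forward(self, x: List[float]) -> List[float]:
        # Only an integer comparison remains on the per-call path.
        assert len(x) == self.in_dim
        return x


# "32x0e + 32x1o": 0e has dimension 1 and 1o has dimension 3.
reps = Reps((MulRep(32, 1), MulRep(32, 3)))
layer = Layer(reps)
print(layer.in_dim)  # 128
```

In a real `torch.nn.Module` the same idea would presumably mean storing the integer in `__init__` and comparing `x.shape[-1]` against it, so that the tracer never has to step into the dataclass unpacking. Whether that alone is enough for `fullgraph=True` depends on what else `forward` does.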