When attempting to train on an MPS device, the following error occurs due to an issue with torch.compile(): pytorch/pytorch#96976
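To resolve this issue, you can either remove the torch.compile() call or follow the suggestion in the error message and suppress the error by importing torch._dynamo and setting torch._dynamo.config.suppress_errors = True, which falls back to eager execution.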
No sentence-transformers model found with name bert-base-uncased.
The checkpoint does not contain a linear projection layer. Adding one with output dimensions (768, 128).
Created a PyLate model from base encoder.
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The tokenizer does not support resizing the token embeddings, the prefixes token have not been added to vocabulary.
0%| | 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/__init__.py", line 2234, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1521, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 588, in aot_dispatch_autograd
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1350, in fw_compiler_base
return _fw_compiler_base(model, example_inputs, is_inference)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1421, in _fw_compiler_base
return inner_compile(
^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 475, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 661, in _compile_fx_inner
compiled_graph = FxGraphCache.load(
^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 1334, in load
compiled_graph = compile_fx_fn(
^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 570, in codegen_and_compile
compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 859, in fx_codegen_and_compile
graph.run(*example_inputs)
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 780, in run
return super().run(*args)
^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 146, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1319, in run_node
result = super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 203, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1024, in call_function
raise LoweringException(e, target, args, kwargs).with_traceback(
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1021, in call_function
out = lowerings[target](*args, **kwargs) # type: ignore[index]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 361, in wrapped
out = decomp_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 5264, in var_mean
return var_mean_helper_(
^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 5249, in var_mean_helper_
else var_mean_welford_(**kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 5203, in var_mean_welford_
mean, m2, _ = ir.WelfordReduction.create(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/ir.py", line 1608, in create
hint, split = Reduction.num_splits(
^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/ir.py", line 851, in num_splits
not V.graph.has_feature(device, BackendFeature.REDUCE_TO_SINGLE_ELEMENT)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 465, in has_feature
return feature in self.get_backend_features(get_device_type(device))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/common.py", line 170, in get_backend_features
return scheduling(None).get_backend_features(device)
^^^^^^^^^^^^^^^^
torch._inductor.exc.LoweringException: TypeError: 'NoneType' object is not callable
target: aten.var_mean.correction
args[0]: TensorBox(StorageBox(
ComputedBuffer(name='buf0', layout=FlexibleLayout('mps', torch.float32, size=[2, 32, 768], stride=[24576, 768, 1]), data=Pointwise(
'mps',
torch.float32,
def inner_fn(index):
i0, i1, i2 = index
tmp0 = ops.load(primals_2, i1 + 32 * i0)
tmp1 = ops.load(primals_4, i2 + 768 * tmp0)
tmp2 = ops.load(primals_1, i1 + 32 * i0)
tmp3 = ops.load(primals_5, i2 + 768 * tmp2)
tmp4 = tmp1 + tmp3
tmp5 = ops.load(primals_3, i1)
tmp6 = ops.load(primals_6, i2 + 768 * tmp5)
tmp7 = tmp4 + tmp6
return tmp7
,
ranges=[2, 32, 768],
origin_node=add_1,
origins=OrderedSet([embedding_2, embedding, add, add_1, embed...
))
))
args[1]: [2]
kwargs: {'correction': 0, 'keepdim': True}
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/A200009373/Documents/Coding/pylate/test2-add-embeddings.py", line 79, in <module>
trainer.train()
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2123, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2481, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3579, in training_step
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/sentence_transformers/trainer.py", line 393, in compute_loss
loss = loss_fn(features, labels)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/pylate/losses/distillation.py", line 83, in forward
self.model(sentence_features[0])["token_embeddings"], p=2, dim=-1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1064, in __call__
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
return _compile(
^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
transformations(instructions, code_options)
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
tracer.run()
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
super().run()
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
while self.step():
^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
self.dispatch_table[inst.opcode](self, inst)
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2987, in RETURN_VALUE
self._return(inst)
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2972, in _return
self.output.compile_subgraph(
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1142, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
return self._call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: TypeError: 'NoneType' object is not callable
target: aten.var_mean.correction
args[0]: TensorBox(StorageBox(
ComputedBuffer(name='buf0', layout=FlexibleLayout('mps', torch.float32, size=[2, 32, 768], stride=[24576, 768, 1]), data=Pointwise(
'mps',
torch.float32,
def inner_fn(index):
i0, i1, i2 = index
tmp0 = ops.load(primals_2, i1 + 32 * i0)
tmp1 = ops.load(primals_4, i2 + 768 * tmp0)
tmp2 = ops.load(primals_1, i1 + 32 * i0)
tmp3 = ops.load(primals_5, i2 + 768 * tmp2)
tmp4 = tmp1 + tmp3
tmp5 = ops.load(primals_3, i1)
tmp6 = ops.load(primals_6, i2 + 768 * tmp5)
tmp7 = tmp4 + tmp6
return tmp7
,
ranges=[2, 32, 768],
origin_node=add_1,
origins=OrderedSet([embedding_2, embedding, add, add_1, embed...
))
))
args[1]: [2]
kwargs: {'correction': 0, 'keepdim': True}
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Yes, I think MPS does not support all the compilation backends, so it does not work properly.
torch.compile(model) can also fail for other reasons (e.g., ModernBERT compiles internally, so compile should not be called explicitly).
I have run into this message for various other reasons, and I always comment out the compile call when it happens. This might be confusing for people who are not familiar with it, so I guess the best option is to add a comment everywhere (boilerplates and docs) saying to comment out the line if this error message appears (as I would still like compiling to remain the default).
Also, I need to check, but I believe the issue with compiling models is now fixed in ST (cc @tomaarsen), so this may be a good time to offload this operation to the ST parameter, because it would work better with models such as ModernBERT.
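One way to keep compilation as the default while sparing MPS users the traceback would be to skip the compile step when the model sits on an MPS device. This is a minimal sketch, not PyLate's actual code: `maybe_compile` is a hypothetical helper, and the device type is passed in as a plain string (reading it from something like `model.device.type` is an assumption about the model object).

```python
def maybe_compile(model, device_type: str):
    """Return torch.compile(model), or the model unchanged on MPS.

    Hypothetical helper for illustration only; not part of PyLate.
    """
    if device_type == "mps":
        # Inductor lowering failed on MPS in the traceback above,
        # so stay in eager mode on that device.
        return model

    # Lazy import: torch is only needed when we actually compile.
    import torch

    return torch.compile(model)
```

In a training script this would replace the bare `torch.compile(model)` call, for example `model = maybe_compile(model, device_type="mps")`. Note that torch.compile is lazy, so wrapping the call itself in try/except would not catch this error; the failure only surfaces on the first forward pass, which is why the check is done on the device up front.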
Setting torch._dynamo.config.suppress_errors = True bypasses the error, so training can proceed on the MPS device.
pylate/README.md
Line 165 in e349918