Error When Using Training Example on MPS Device #82

Open
sam-hey opened this issue Jan 26, 2025 · 1 comment
Labels
bug Something isn't working

Comments

sam-hey commented Jan 26, 2025

When attempting to train on an MPS device (Apple's Metal Performance Shaders backend), the error shown below occurs because of an issue with torch.compile() (see pytorch/pytorch#96976).

To work around this issue, you can either remove the torch.compile() call or, as the error message itself suggests, suppress the error with the following snippet:

```python
import torch._dynamo

# Fall back to eager execution instead of raising when the compile backend fails
torch._dynamo.config.suppress_errors = True
```

With this setting, Dynamo falls back to eager execution instead of raising, so training can proceed on the MPS device.
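A third option, sketched here as an assumption rather than as what the training example does, is to keep torch.compile() on other devices and skip it only on MPS, so the model simply runs in eager mode there. The helper `maybe_compile` is hypothetical, and the compile function is passed in so the logic can be checked without a GPU:

```python
from typing import Any, Callable


def maybe_compile(model: Any, device_type: str, compile_fn: Callable[[Any], Any]) -> Any:
    """Compile `model` with `compile_fn` unless it runs on an MPS device.

    Hypothetical helper, not part of PyLate: on MPS the model is returned
    unchanged and therefore runs in eager mode.
    """
    if device_type == "mps":
        # Skip compilation: the default inductor backend failed on MPS in this issue.
        return model
    return compile_fn(model)
```

In a training script this might be called as `model = maybe_compile(model, model.device.type, torch.compile)`, assuming the model has already been placed on its target device and exposes a `device` attribute (SentenceTransformer-based models do).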

The error is raised as soon as training starts on a model compiled with:

```python
model = torch.compile(model)
```

No sentence-transformers model found with name bert-base-uncased.
The checkpoint does not contain a linear projection layer. Adding one with output dimensions (768, 128).
Created a PyLate model from base encoder.
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The tokenizer does not support resizing the token embeddings, the prefixes token have not been added to vocabulary.
  0%|                                                                                                                                                  | 0/100 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/__init__.py", line 2234, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1521, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 588, in aot_dispatch_autograd
    compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1350, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1421, in _fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 475, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 661, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
                     ^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 1334, in load
    compiled_graph = compile_fx_fn(
                     ^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 570, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 859, in fx_codegen_and_compile
    graph.run(*example_inputs)
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 780, in run
    return super().run(*args)
           ^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1319, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1024, in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1021, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 361, in wrapped
    out = decomp_fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 5264, in var_mean
    return var_mean_helper_(
           ^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 5249, in var_mean_helper_
    else var_mean_welford_(**kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 5203, in var_mean_welford_
    mean, m2, _ = ir.WelfordReduction.create(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/ir.py", line 1608, in create
    hint, split = Reduction.num_splits(
                  ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/ir.py", line 851, in num_splits
    not V.graph.has_feature(device, BackendFeature.REDUCE_TO_SINGLE_ELEMENT)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 465, in has_feature
    return feature in self.get_backend_features(get_device_type(device))
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/common.py", line 170, in get_backend_features
    return scheduling(None).get_backend_features(device)
           ^^^^^^^^^^^^^^^^
torch._inductor.exc.LoweringException: TypeError: 'NoneType' object is not callable
  target: aten.var_mean.correction
  args[0]: TensorBox(StorageBox(
    ComputedBuffer(name='buf0', layout=FlexibleLayout('mps', torch.float32, size=[2, 32, 768], stride=[24576, 768, 1]), data=Pointwise(
      'mps',
      torch.float32,
      def inner_fn(index):
          i0, i1, i2 = index
          tmp0 = ops.load(primals_2, i1 + 32 * i0)
          tmp1 = ops.load(primals_4, i2 + 768 * tmp0)
          tmp2 = ops.load(primals_1, i1 + 32 * i0)
          tmp3 = ops.load(primals_5, i2 + 768 * tmp2)
          tmp4 = tmp1 + tmp3
          tmp5 = ops.load(primals_3, i1)
          tmp6 = ops.load(primals_6, i2 + 768 * tmp5)
          tmp7 = tmp4 + tmp6
          return tmp7
      ,
      ranges=[2, 32, 768],
      origin_node=add_1,
      origins=OrderedSet([embedding_2, embedding, add, add_1, embed...
    ))
  ))
  args[1]: [2]
  kwargs: {'correction': 0, 'keepdim': True}

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/A200009373/Documents/Coding/pylate/test2-add-embeddings.py", line 79, in <module>
    trainer.train()
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2123, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2481, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3579, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/sentence_transformers/trainer.py", line 393, in compute_loss
    loss = loss_fn(features, labels)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/pylate/losses/distillation.py", line 83, in forward
    self.model(sentence_features[0])["token_embeddings"], p=2, dim=-1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1064, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
           ^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2987, in RETURN_VALUE
    self._return(inst)
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2972, in _return
    self.output.compile_subgraph(
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1142, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
    return self._call_user_compiler(gm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: TypeError: 'NoneType' object is not callable
  target: aten.var_mean.correction
  args[0]: TensorBox(StorageBox(
    ComputedBuffer(name='buf0', layout=FlexibleLayout('mps', torch.float32, size=[2, 32, 768], stride=[24576, 768, 1]), data=Pointwise(
      'mps',
      torch.float32,
      def inner_fn(index):
          i0, i1, i2 = index
          tmp0 = ops.load(primals_2, i1 + 32 * i0)
          tmp1 = ops.load(primals_4, i2 + 768 * tmp0)
          tmp2 = ops.load(primals_1, i1 + 32 * i0)
          tmp3 = ops.load(primals_5, i2 + 768 * tmp2)
          tmp4 = tmp1 + tmp3
          tmp5 = ops.load(primals_3, i1)
          tmp6 = ops.load(primals_6, i2 + 768 * tmp5)
          tmp7 = tmp4 + tmp6
          return tmp7
      ,
      ranges=[2, 32, 768],
      origin_node=add_1,
      origins=OrderedSet([embedding_2, embedding, add, add_1, embed...
    ))
  ))
  args[1]: [2]
  kwargs: {'correction': 0, 'keepdim': True}

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
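The innermost frames of the log end in `get_backend_features` calling `scheduling(None)` and failing with `'NoneType' object is not callable`, which suggests that inductor (the default torch.compile() backend) had no code-generation backend registered for the `mps` device in the PyTorch version used here. A possible middle ground between removing compilation and suppressing errors is to choose the backend per device. This sketch assumes that `"aot_eager"`, which traces the graph with AOTAutograd but executes it eagerly instead of generating kernels through inductor, sidesteps this particular failure; that has not been verified on MPS here, and `pick_compile_backend` is a hypothetical helper:

```python
def pick_compile_backend(device_type: str) -> str:
    """Return a torch.compile backend name for the given device type.

    Hypothetical helper. Assumption: "aot_eager" avoids inductor's per-device
    code generation, which is where the MPS failure in this issue was raised.
    "inductor" is torch.compile's default backend.
    """
    if device_type == "mps":
        return "aot_eager"
    return "inductor"
```

It would be used as `model = torch.compile(model, backend=pick_compile_backend(model.device.type))`. Whether "aot_eager" gives any speedup on MPS is a separate question; its main value here would be keeping the same code path on every device.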
Collaborator

NohTow commented Jan 27, 2025

Hello,

Yes, I think MPS does not support all the compilation backends, so compilation does not work properly there.
torch.compile(model) can also fail for other reasons (e.g., ModernBERT compiles internally, so compile should not be called explicitly).
I have run into this message for various other reasons, and I always comment out the compile call when it happens. Since this might be confusing for people who are not familiar with it, I think the best option is to add a comment everywhere (boilerplates and docs) saying to comment out that line if this error message appears, as I would still like compilation to be the default.
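Besides a comment telling users to comment out the line, a script could keep compilation as the default while making it easy to skip. A minimal sketch, in which the environment variable name and the set of model types are hypothetical (the `"modernbert"` entry assumes that is ModernBERT's Hugging Face `config.model_type`):

```python
import os
from typing import Mapping

# Hypothetical: model types assumed to call torch.compile internally.
COMPILES_INTERNALLY = frozenset({"modernbert"})

# Hypothetical opt-out variable for the user's own script; not a PyLate or PyTorch setting.
OPT_OUT_VAR = "DISABLE_TORCH_COMPILE"


def should_compile(model_type: str, env: Mapping[str, str] = os.environ) -> bool:
    """Return True unless the user opted out or the model already compiles itself."""
    if env.get(OPT_OUT_VAR, "").strip().lower() in {"1", "true", "yes"}:
        return False
    return model_type.lower() not in COMPILES_INTERNALLY
```

The compile line would then become `if should_compile(model_type): model = torch.compile(model)`, with `model_type` read from the underlying Hugging Face config, so users hitting this error can set the variable instead of editing the script.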

Also, I need to double-check, but I believe the issue with compiling models is now fixed in Sentence Transformers (ST) (cc @tomaarsen), so this might be a good time to offload this operation to the corresponding ST parameter, which would work better with models such as ModernBERT.
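Delegating compilation to the trainer instead of calling torch.compile() in the script could look like the sketch below. It assumes the relevant settings are the `torch_compile` and `torch_compile_backend` fields that Hugging Face `TrainingArguments` defines and `SentenceTransformerTrainingArguments` inherits; whether these are the ST parameter meant above is an assumption, and the helper name is hypothetical:

```python
from typing import Optional


def compile_training_kwargs(enabled: bool, backend: Optional[str] = None) -> dict:
    """Build keyword arguments for the trainer's compile settings.

    Hypothetical helper: the trainer, not the script, then decides when to
    wrap the model with torch.compile.
    """
    kwargs = {"torch_compile": enabled}
    if enabled and backend is not None:
        # Only pass a backend when compilation is actually enabled.
        kwargs["torch_compile_backend"] = backend
    return kwargs
```

Usage would be along the lines of `args = SentenceTransformerTrainingArguments(output_dir="output", **compile_training_kwargs(enabled=True))`, with the explicit `model = torch.compile(model)` line removed from the example.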

NohTow self-assigned this Jan 27, 2025
NohTow added the bug label Jan 27, 2025