Error while fine-tuning wav2vec2 with BART #96

Open
arvindmn01 opened this issue Aug 28, 2023 · 1 comment

arvindmn01 commented Aug 28, 2023

I tried to fine-tune a wav2vec2 model together with a BART model on my custom dataset using the following command:

python run_flax_speech_recognition_seq2seq.py ...

but I got this error:

    main()
  File "run_flax_speech_recognition_seq2seq.py", line 1189, in main
    state, train_metric = p_train_step(state, batch)
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/api.py", line 1779, in cache_miss
    execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 411, in xla_pmap_impl_lazy
    compiled_fun, fingerprint = parallel_callable(
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 678, in parallel_callable
    pmap_computation = lower_parallel_callable(
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 825, in lower_parallel_callable
    jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 748, in stage_parallel_callable
    jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2233, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "run_flax_speech_recognition_seq2seq.py", line 1051, in train_step
    "encoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["encoder"])),
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/api.py", line 306, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 486, in norm
    check_arraylike("jnp.linalg.norm", x)
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 328, in check_arraylike
    raise TypeError(msg.format(fun_name, type(arg), pos))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "run_flax_speech_recognition_seq2seq.py", line 1266, in <module>
    main()
  File "run_flax_speech_recognition_seq2seq.py", line 1189, in main
    state, train_metric = p_train_step(state, batch)
  File "run_flax_speech_recognition_seq2seq.py", line 1051, in train_step
    "encoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["encoder"])),
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 486, in norm
    check_arraylike("jnp.linalg.norm", x)
  File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 328, in check_arraylike
    raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.

Versions:
jax==0.4.13
jaxlib==0.4.13
flax==0.7.2
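For context on the final TypeError: `jax.tree_util.tree_leaves` returns a plain Python list, and `jnp.linalg.norm` in this JAX version rejects list arguments, which is exactly what the last frame reports. The sketch below reproduces the error and shows one possible workaround, converting the leaves into a single array before taking the norm. It assumes `layer_grad_norm` is a tree of per-parameter norms with top-level "encoder" and "decoder" keys, as the failing line suggests; the `grad` dict is a hypothetical stand-in for the script's real gradients, and this is not a fix confirmed by the maintainers.

```python
import jax
import jax.numpy as jnp

# Hypothetical gradient pytree standing in for the script's real gradients;
# only the top-level "encoder"/"decoder" keys mirror the failing line.
grad = {
    "encoder": {"conv": jnp.ones((3, 4)), "proj": jnp.full((2,), 2.0)},
    "decoder": {"embed": jnp.ones((5,))},
}

# Assumed setup: per-parameter L2 norms, same tree structure as `grad`.
layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad)

# tree_leaves returns a plain Python list of scalar arrays.
leaves = jax.tree_util.tree_leaves(layer_grad_norm["encoder"])

# Passing that list straight to jnp.linalg.norm reproduces the TypeError.
try:
    jnp.linalg.norm(leaves)
    list_rejected = False
except TypeError:
    list_rejected = True

# Converting the list into one array first avoids the error; the result is
# the L2 norm over all encoder gradients, sqrt(12 + 8) for this toy tree.
encoder_grad_norm = jnp.linalg.norm(jnp.array(leaves))
print(list_rejected, float(encoder_grad_norm))
```

Applied to the script, the same idea would mean wrapping the `jax.tree_util.tree_leaves(...)` call on line 1051 (and any similar call) in `jnp.array(...)` or `jnp.stack(...)`; both produce a 1-D array of the per-layer norms.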

@sanchit-gandhi
Owner

Hey @arvindmn01 - could you provide a reproducible code snippet for this error?
