You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
I tried to fine-tune a wav2vec2 model along with a BART model on my custom dataset using the following command: `python run_flax_speech_recognition_seq2seq.py ...`, but I got this error.
main()
File "run_flax_speech_recognition_seq2seq.py", line 1189, in main
state, train_metric = p_train_step(state, batch)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/api.py", line 1779, in cache_miss
execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 411, in xla_pmap_impl_lazy
compiled_fun, fingerprint = parallel_callable(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
ans = call(fun, *args)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 678, in parallel_callable
pmap_computation = lower_parallel_callable(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 825, in lower_parallel_callable
jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 748, in stage_parallel_callable
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2233, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "run_flax_speech_recognition_seq2seq.py", line 1051, in train_step
"encoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["encoder"])),
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/api.py", line 306, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
ans = call(fun, *args)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 486, in norm
check_arraylike("jnp.linalg.norm", x)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 328, in check_arraylike
raise TypeError(msg.format(fun_name, type(arg), pos))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "run_flax_speech_recognition_seq2seq.py", line 1266, in <module>
main()
File "run_flax_speech_recognition_seq2seq.py", line 1189, in main
state, train_metric = p_train_step(state, batch)
File "run_flax_speech_recognition_seq2seq.py", line 1051, in train_step
"encoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["encoder"])),
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 486, in norm
check_arraylike("jnp.linalg.norm", x)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 328, in check_arraylike
raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.
jax version is jax==0.4.13
jaxlib version is jaxlib==0.4.13
flax version is flax==0.7.2
The text was updated successfully, but these errors were encountered:
I tried to fine-tune a wav2vec2 model along with a BART model on my custom dataset using the following command:
python run_flax_speech_recognition_seq2seq.py ...
but I got this error. jax version is
jax==0.4.13
jaxlib version is
jaxlib==0.4.13
flax version is
flax==0.7.2
The text was updated successfully, but these errors were encountered: