Compatibility with multiprocessing / joblib - AttributeError: type object 'Float[Tensor, 'batch_size num_classes']' has no attribute 'index_variadic' #198
Looks like they're not getting de/serialised correctly, so the … If you can provide an MWE, that'd be great. (Or a PR! The fix might just be to implement …)
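Here is one possible shape for such a fix, sketched with the standard library only. It registers a reducer for the metaclass, so each dynamically created class is pickled as "call this factory with these properties" instead of being looked up by name. That lookup fails because no module attribute is called `Float[Tensor, '...']`. Everything below is hypothetical: `_MetaArray`, `AbstractArray`, `make_array_type` and the `dtype`/`dims` attributes are toy stand-ins rather than jaxtyping's internals. The sketch also assumes the factory caches classes, so unpickling hands back the identical class object.

```python
import copyreg
import pickle


class _MetaArray(type):
    """Hypothetical metaclass for runtime-created annotation classes."""


class AbstractArray(metaclass=_MetaArray):
    """Importable base class; it still pickles by name."""


_cache = {}


def make_array_type(dtype, dims):
    """Hypothetical cached factory for classes like Float[Tensor, 'b n']."""
    key = (dtype, dims)
    if key not in _cache:
        cls = _MetaArray(f"{dtype}[Tensor, {dims!r}]", (AbstractArray,), {})
        # Properties are attached after the class object exists.
        cls.dtype = dtype
        cls.dims = dims
        _cache[key] = cls
    return _cache[key]


def _reduce_array_type(cls):
    if "dims" not in vars(cls):
        # Returning a string tells pickle to save the object by name (AbstractArray).
        return cls.__qualname__
    # Otherwise: "call make_array_type(dtype, dims)" when unpickling.
    return make_array_type, (cls.dtype, cls.dims)


# pickle checks copyreg's dispatch table, keyed on the exact metaclass,
# before falling back to saving a class by name.
copyreg.pickle(_MetaArray, _reduce_array_type)

FloatBN = make_array_type("Float", "batch_size num_classes")
restored = pickle.loads(pickle.dumps(FloatBN))
print(restored is FloatBN, restored.dims)  # True batch_size num_classes
```

Because the `copyreg` lookup is keyed on the exact metaclass, the reducer only affects classes built by `_MetaArray`. The sketch only exercises stdlib `pickle`. cloudpickle has its own path for classes, so how it behaves there is untested.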
I'm facing this same issue when trying to save an optax optimizer state using cloudpickle. Hope this issue gets fixed.
Do you have an MWE?
Yes, I'm training a model with JAX and Equinox, and I am trying to save the optimizer state:

```python
lr_scheduler = optax.warmup_cosine_decay_schedule(
    ...
)
...
optimizer_state = optimizer.init(eqx.filter(model, eqx.is_array))
checkpoint_params = {
    ...
}
with open(checkpoint_params_file, "wb") as f:
    ...
```
I am also encountering this issue, but only with … To reproduce:

```shell
pip install jaxtyping jax 'ray[default]'
```

```python
import jax
import ray
from jax import numpy as jnp
from jaxtyping import Int

ray.init()


@ray.remote(max_retries=0)
def f(x: Int[jax.Array, "one two"]):
    return x * 2


a = ray.put(jnp.arange(10))
ray.get(f.remote(a))
```

I tried implementing `__getstate__` and `__setstate__` on the metaclass:

```python
# jaxtyping/_array_types.py
@ft.lru_cache(maxsize=None)
def _make_metaclass(base_metaclass):
    class MetaAbstractArray(_MetaAbstractArray, base_metaclass):
        # ...
        def __getstate__(cls):
            return cls._get_props()

        def __setstate__(cls, props):
            (
                cls.index_variadic,
                cls.dims,
                cls.array_type,
                cls.dtypes,
                cls.dim_str,
            ) = props

        # ...
```

But as best I can tell, neither one gets called at all.
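That matches how the standard library's `pickle` treats classes. An object whose type is a subclass of `type` is saved by module and qualified name, and `__getstate__`/`__setstate__` defined on its metaclass are never consulted. The toy below round-trips a class and records any call to those hooks; `Meta` and `Annotated` are made-up names for the demo. The list stays empty. cloudpickle reconstructs dynamic classes through its own machinery. Presumably the same hooks are bypassed there too, but this sketch does not check that.

```python
import pickle

calls = []  # records any metaclass hook that pickle invokes


class Meta(type):
    """Toy metaclass with the hooks tried in the comment above."""

    def __getstate__(cls):
        calls.append("__getstate__")
        return {}

    def __setstate__(cls, state):
        calls.append("__setstate__")


class Annotated(metaclass=Meta):
    pass


# Instances of a type subclass are saved by module + qualified name,
# so neither hook on the metaclass runs.
restored = pickle.loads(pickle.dumps(Annotated))
print(restored is Annotated, calls)  # True []
```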
It looks like … and then immediately tries to hash it, which fails because at that point the class does not yet have our attributes set.
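The failure mode can be reproduced without cloudpickle. The sketch below assumes, as the comment above implies, that the metaclass's `__hash__` reads per-class attributes. `PropsHashMeta`, `IdentityHashMeta` and `rebuild_like_a_deserializer` are hypothetical names. The function only mimics the order of operations described (create a bare class, hash it, then set attributes) and is not cloudpickle's code.

```python
class PropsHashMeta(type):
    """Hypothetical metaclass whose equality and hash depend on class attributes."""

    def _get_props(cls):
        return (cls.index_variadic, cls.dims, cls.dim_str)

    def __eq__(cls, other):
        if not isinstance(other, PropsHashMeta):
            return NotImplemented
        return cls._get_props() == other._get_props()

    def __hash__(cls):
        # Raises AttributeError until the attributes have been assigned.
        return hash(cls._get_props())


class IdentityHashMeta(type):
    """Same role, but keeps type's default identity-based __eq__/__hash__."""


def rebuild_like_a_deserializer(meta):
    """Create a bare class, use it as a dict key, and only then set attributes."""
    tracker = {}
    skeleton = meta("Float[Tensor, 'batch_size num_classes']", (), {})
    try:
        tracker[skeleton] = "tracked"  # hashing happens here, too early
    except AttributeError as err:
        return f"AttributeError: {err}"
    skeleton.index_variadic = None
    skeleton.dims = ("batch_size", "num_classes")
    skeleton.dim_str = "batch_size num_classes"
    return "ok"


print(rebuild_like_a_deserializer(PropsHashMeta))
print(rebuild_like_a_deserializer(IdentityHashMeta))
```

With an attribute-based hash, the dict insertion raises the same kind of `AttributeError` as in the issue title. With `type`'s default identity hash, the same sequence succeeds. Identity hashing would change equality between separately created but identical annotation classes, so it is a trade-off to weigh rather than a drop-in fix.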
Thank you for the MWE; that was invaluable for figuring this one out! :)
You guys, gals, and nonbinary pals rock!!
I seem to have a related issue with the Grain dataloader, which also involves cloudpickle and `index_variadic`. The error only happens when I set `worker_count > 0`: `ERROR:absl:Error occurred in child process with worker_index: 7`
Ah, this has already been fixed; I just hadn't done a new release for it yet. I've now done a version bump + new release in #246.
I have to remove type hints from type-checked functions that need to be called in `joblib.Parallel` or other multiprocessing pipelines; otherwise I get tracebacks like this: …
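Here is a stdlib-only illustration of why the annotations matter. A class created at runtime, standing in for something like `Int[jax.Array, "one two"]`, cannot be pickled by name. A module-level function annotated with it can be, because `pickle` saves functions by reference and never serializes `__annotations__`. As far as I know, cloudpickle pickles functions defined in `__main__` by value, annotations included. The Ray and Grain reports above go through cloudpickle, and joblib's default loky backend uses it. Defining the worker function in an importable module may therefore be a workaround, but treat that as an untested assumption. `OneTwo`, `double` and `try_pickle` are illustrative names.

```python
import pickle

# Hypothetical stand-in for a subscripted annotation such as
# Int[jax.Array, "one two"]: a class created at runtime whose __qualname__
# does not name any attribute of its module.
OneTwo = type("Int[Array, 'one two']", (), {})


def double(x: OneTwo):
    return x * 2


def try_pickle(obj):
    """Return True if obj survives a stdlib pickle round trip."""
    try:
        pickle.loads(pickle.dumps(obj))
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


print(try_pickle(OneTwo))  # class: saved by name, and the name lookup fails
print(try_pickle(double))  # function: saved by name, annotation never touched
```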