From ef28ceb4fe51e24ef37c0be57bc369c28173d0c8 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Wed, 31 Jul 2024 22:31:15 -0400 Subject: [PATCH] Added sharding parameter: fix for JAX 0.4.31 (#25) --- quax/examples/zero/_core.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/quax/examples/zero/_core.py b/quax/examples/zero/_core.py index 407141a..b8087c3 100644 --- a/quax/examples/zero/_core.py +++ b/quax/examples/zero/_core.py @@ -64,8 +64,9 @@ def _(value: Zero, *, broadcast_dimensions, shape) -> Zero: @quax.register(lax.convert_element_type_p) -def _(value: Zero, *, new_dtype, weak_type) -> Zero: - del weak_type +def _(value: Zero, *, new_dtype, weak_type, sharding=None) -> Zero: + # sharding was added around JAX 0.4.31, it seems. + del weak_type, sharding return Zero(value.shape, new_dtype)