diff --git a/nerfies/utils.py b/nerfies/utils.py index b4543f5..93438b8 100644 --- a/nerfies/utils.py +++ b/nerfies/utils.py @@ -334,7 +334,7 @@ def general_loss_with_squared_residual(squared_x, alpha, scale): def shard(xs, device_count=None): """Split data into shards for multiple devices along the first dimension.""" if device_count is None: - jax.local_device_count() + device_count = jax.local_device_count() return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs)