From 0097ea387db3b90f57910bbfb927cb6dcccab6d3 Mon Sep 17 00:00:00 2001 From: JamesPerlman Date: Mon, 14 Feb 2022 17:39:57 -0800 Subject: [PATCH] Save result from jax.local_device_count --- nerfies/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nerfies/utils.py b/nerfies/utils.py index b4543f5..93438b8 100644 --- a/nerfies/utils.py +++ b/nerfies/utils.py @@ -334,7 +334,7 @@ def general_loss_with_squared_residual(squared_x, alpha, scale): def shard(xs, device_count=None): """Split data into shards for multiple devices along the first dimension.""" if device_count is None: - jax.local_device_count() + device_count = jax.local_device_count() return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs)