diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index 5f1360a5..042de797 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -417,7 +417,7 @@ def _should_start_trace(): stop_trace_step = self.step + 3 self._step = self._step + 1 self.vlog(3, "Start step %s", self.step) - output = self._run_step(input_batch) + output = self._run_step(utils.host_to_global_device_array(input_batch)) self.vlog(3, "Done step %s", self.step) num_steps += 1 if num_steps % 100 == 0: @@ -688,13 +688,11 @@ def _run_step(self, input_batch: NestedTensor) -> NestedTensor: """Runs a single training step. Args: - input_batch: a NestedTensor. + input_batch: a NestedTensor containing global arrays. Returns: A dict containing 'loss' and 'aux' outputs. """ - input_batch = utils.host_to_global_device_array(input_batch) - with jax.profiler.StepTraceAnnotation("train", step_num=self.step): # Note(Jan 2022): # pjit currently requires all parameters to be specified as positional args.