From c00c632b99e6a2d87ee7ba94f295b39e0871a577 Mon Sep 17 00:00:00 2001 From: Ruoming Pang Date: Sun, 5 Nov 2023 10:00:33 -0500 Subject: [PATCH] Moves the `utils.host_to_global_device_array` call from `SpmdTrainer._run_step` to `SpmdTrainer.run`. This makes it easier for subclasses of SpmdTrainer to override `_run_step`. --- axlearn/common/trainer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index 5f1360a5..042de797 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -417,7 +417,7 @@ def _should_start_trace(): stop_trace_step = self.step + 3 self._step = self._step + 1 self.vlog(3, "Start step %s", self.step) - output = self._run_step(input_batch) + output = self._run_step(utils.host_to_global_device_array(input_batch)) self.vlog(3, "Done step %s", self.step) num_steps += 1 if num_steps % 100 == 0: @@ -688,13 +688,11 @@ def _run_step(self, input_batch: NestedTensor) -> NestedTensor: """Runs a single training step. Args: - input_batch: a NestedTensor. + input_batch: a NestedTensor containing global arrays. Returns: A dict containing 'loss' and 'aux' outputs. """ - input_batch = utils.host_to_global_device_array(input_batch) - with jax.profiler.StepTraceAnnotation("train", step_num=self.step): # Note(Jan 2022): # pjit currently requires all parameters to be specified as positional args.