Moves the utils.host_to_global_device_array call from `SpmdTrainer._run_step` to `SpmdTrainer.run`.

This makes it easier for subclasses of SpmdTrainer to override `_run_step`.
ruomingp committed Nov 5, 2023
1 parent b2ccd7b commit c00c632
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions axlearn/common/trainer.py
@@ -417,7 +417,7 @@ def _should_start_trace():
                     stop_trace_step = self.step + 3
                 self._step = self._step + 1
                 self.vlog(3, "Start step %s", self.step)
-                output = self._run_step(input_batch)
+                output = self._run_step(utils.host_to_global_device_array(input_batch))
                 self.vlog(3, "Done step %s", self.step)
                 num_steps += 1
                 if num_steps % 100 == 0:
@@ -688,13 +688,11 @@ def _run_step(self, input_batch: NestedTensor) -> NestedTensor:
         """Runs a single training step.
 
         Args:
-            input_batch: a NestedTensor.
+            input_batch: a NestedTensor containing global arrays.
 
         Returns:
             A dict containing 'loss' and 'aux' outputs.
         """
-        input_batch = utils.host_to_global_device_array(input_batch)
-
         with jax.profiler.StepTraceAnnotation("train", step_num=self.step):
             # Note(Jan 2022):
             # pjit currently requires all parameters to be specified as positional args.
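For readers extending the trainer, here is a minimal sketch of the override pattern this change enables. The `LoggingSpmdTrainer` name and its logging body are hypothetical and not part of this commit; the imports assume `SpmdTrainer` and `NestedTensor` live where the diff references them.

```python
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.utils import NestedTensor


class LoggingSpmdTrainer(SpmdTrainer):
    """Hypothetical subclass that wraps the default step with extra logging."""

    def _run_step(self, input_batch: NestedTensor) -> NestedTensor:
        # After this commit, `run` already applies utils.host_to_global_device_array
        # before dispatching, so `input_batch` holds global arrays and the override
        # does not need to repeat the conversion.
        outputs = super()._run_step(input_batch)
        self.vlog(2, "loss=%s", outputs.get("loss"))
        return outputs
```

Keeping the host-to-global conversion in `run` means every such override sees the same global-array contract without duplicating the conversion in each subclass.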
