diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index f08657b793d7..c8d6a0bff444 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -556,7 +556,7 @@ def _aggregate_total_loss(self): agg_loss /= self.dp_world_size assert self.global_rank in self.grid.pp_group - losses = torch.stack([self.dp_group_loss, agg_loss]) + losses = torch.stack([self.dp_group_loss, agg_loss]).float() if self.is_pipe_parallel: dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group()) else: