fix: to solve #4726 (#4727)
To solve #4726, I changed the dtype of the loss tensor to float32 in the last stage of the pipeline.

**Test result**

Before `dist.broadcast`:

```
[2023-11-24 14:06:04,709] [INFO] [engine.py:590:_aggregate_total_loss] [Rank 2] before dist.broadcast(is_last_stage) (tensor([2.3203, 2.3203], device='cuda:2'), torch.float32, device(type='cuda', index=2)), src_rank=2 (1, 2)
[2023-11-24 14:06:04,710] [INFO] [engine.py:590:_aggregate_total_loss] [Rank 3] before dist.broadcast(is_last_stage) (tensor([2.3203, 2.3203], device='cuda:3'), torch.float32, device(type='cuda', index=3)), src_rank=3 (1, 2)
```

After `dist.broadcast`, you can see that the result is correctly shared between rank 2 and rank 0, as well as between rank 3 and rank 1:
```
[2023-11-24 14:06:05,016] [INFO] [engine.py:608:_aggregate_total_loss] [Rank 1] after dist.broadcast(other stage) (tensor([2.3203, 2.3203], device='cuda:1'), torch.float32)
[2023-11-24 14:06:05,043] [INFO] [engine.py:608:_aggregate_total_loss] [Rank 0] after dist.broadcast(other stage) (tensor([2.3203, 2.3203], device='cuda:0'), torch.float32)
```

For more information, please refer to #4726.
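
For context, here is a minimal standalone sketch of the broadcast pattern involved (not the DeepSpeed source itself; the function name, arguments, and group setup are illustrative assumptions). The point it shows: under fp16/bf16 training the last stage's stacked losses are half precision, while the other stages allocate a float32 receive buffer, so the sender must cast before `dist.broadcast` for the dtypes to agree across ranks.

```python
# Minimal sketch of the broadcast pattern, not the DeepSpeed implementation.
# Assumes torch.distributed is already initialized; broadcast_stage_losses,
# its arguments, and the group handle are illustrative.
import torch
import torch.distributed as dist


def broadcast_stage_losses(is_last_stage, dp_group_loss=None, agg_loss=None,
                           src_rank=0, group=None, device="cuda"):
    if is_last_stage:
        # Under fp16/bf16 training these losses may be half precision;
        # cast to float32 so the sender's dtype matches the receivers' buffer.
        losses = torch.stack([dp_group_loss, agg_loss]).float()
    else:
        # Non-last stages allocate a float32 buffer to receive into.
        losses = torch.zeros(2, dtype=torch.float32, device=device)
    dist.broadcast(tensor=losses, src=src_rank, group=group)
    return losses
```

Without the cast on the sending rank, the broadcast is issued with mismatched dtypes across ranks, which is what produced the incorrect aggregated loss reported in #4726.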

Co-authored-by: ryan <ruanzhixiang1@huawei.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
3 people authored Nov 30, 2023
1 parent 8640b8e commit 02288bc
Showing 1 changed file with 1 addition and 1 deletion: `deepspeed/runtime/pipe/engine.py`
```diff
@@ -556,7 +556,7 @@ def _aggregate_total_loss(self):
                 agg_loss /= self.dp_world_size
 
             assert self.global_rank in self.grid.pp_group
-            losses = torch.stack([self.dp_group_loss, agg_loss])
+            losses = torch.stack([self.dp_group_loss, agg_loss]).float()
             if self.is_pipe_parallel:
                 dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group())
         else:
```
