convergence bug fixes
siddharth9820 committed Oct 3, 2024
1 parent f34f7fd commit 01037d2
Showing 2 changed files with 10 additions and 11 deletions.
5 changes: 2 additions & 3 deletions axonn/intra_layer/__init__.py
@@ -69,20 +69,19 @@ def sync_gradients(
         # so we only need the reduction over the data parallel group
         dist.all_reduce(grad, group=data_parallel_group)
         if mean:
-            grad.div_(G_data * G_intra_d)
+            grad.div_(torch.distributed.get_world_size())

     for grad in grads_to_sync["tensor_parallel_biases"]:
         # biases need to be reduced over both the data parallel
         # and depth parallel groups
         dist.all_reduce(grad, group=data_parallel_group)
         dist.all_reduce(grad, group=depth_parallel_group)
         if mean:
-            grad.div_(G_data * G_intra_d)
+            grad.div_(torch.distributed.get_world_size())

     for grad in grads_to_sync["others"]:
         # all other weights are purely data parallel
         dist.all_reduce(grad)
         if mean:
             grad.div_(torch.distributed.get_world_size())
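For context, a rough sketch of the arithmetic behind the changed divisor: before this commit, gradients of tensor-parallel weights and biases were averaged over G_data * G_intra_d ranks, while the purely data-parallel parameters in the "others" branch were already averaged over the full world size; after the change all three branches use the same divisor. The decomposition below, and the names G_intra_r and G_intra_c, are assumptions about AxoNN's process-group layout and do not appear in this diff.

# Hypothetical sketch, not AxoNN code: compare the old and new averaging divisors.
# Assumes the world is split into row-, column-, depth-tensor-parallel and
# data-parallel groups of sizes G_intra_r, G_intra_c, G_intra_d, G_data.
G_intra_r, G_intra_c, G_intra_d, G_data = 2, 2, 2, 4

world_size = G_intra_r * G_intra_c * G_intra_d * G_data   # 32 ranks overall

old_divisor = G_data * G_intra_d   # 8: counts only the data- and depth-parallel ranks
new_divisor = world_size           # 32: every rank, matching the "others" branch above

print(old_divisor, new_divisor)    # 8 32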


16 changes: 8 additions & 8 deletions axonn/intra_layer/fully_connected.py
@@ -111,17 +111,17 @@ def backward(ctx, grad_output):
                 .mm(input_.view(-1, input_.shape[-1]))
             )

-        grad_weight = grad_weight.reshape(-1)
-        grad_weight = _reduce_scatter(
-            grad_weight,
-            dim=0,
-            process_group=ctx.depth_parallel_group,
-            overlap_comm=overlap_reduce_scatter,
-        )
+            grad_weight = grad_weight.reshape(-1)
+            grad_weight = _reduce_scatter(
+                grad_weight,
+                dim=0,
+                process_group=ctx.depth_parallel_group,
+                overlap_comm=overlap_reduce_scatter,
+            )

         if handle and overlap_all_reduce:
             handle.wait()
-        if overlap_reduce_scatter:
+        if overlap_reduce_scatter and ctx.needs_input_grad[1]:
             overlap_communication.accumulate_later(original_weight, grad_weight)
             grad_weight = None  # weight gradients are not ready yet
         return grad_input, grad_weight, None, None, None, None, None, None, None
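The second hunk hinges on ctx.needs_input_grad: the deferred weight-gradient accumulation (and, by the look of the re-indented block, the reduce-scatter as well) should only run when the weight actually requires a gradient. As a point of reference, here is a minimal standalone illustration of that standard torch.autograd.Function pattern; it is generic PyTorch, not AxoNN's code, and the class and variable names are made up for illustration.

import torch

class MyLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, weight):
        ctx.save_for_backward(input_, weight)
        return input_.matmul(weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight = ctx.saved_tensors
        grad_input = grad_weight = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight)
        if ctx.needs_input_grad[1]:
            # only pay for the weight-gradient matmul (and, in AxoNN's case,
            # the reduce-scatter) when the weight actually needs a gradient
            grad_weight = (
                grad_output.reshape(-1, grad_output.shape[-1])
                .t()
                .mm(input_.reshape(-1, input_.shape[-1]))
            )
        return grad_input, grad_weight

# Freezing the weight makes ctx.needs_input_grad[1] False inside backward,
# so the weight-gradient branch is skipped entirely.
x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(5, 3, requires_grad=False)
MyLinear.apply(x, w).sum().backward()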
