diff --git a/fairseq/optim/adafactor.py b/fairseq/optim/adafactor.py index 680ac371b9..00e6ed3138 100644 --- a/fairseq/optim/adafactor.py +++ b/fairseq/optim/adafactor.py @@ -121,7 +121,7 @@ def _rms(self, tensor): return tensor.norm(2) / (tensor.numel() ** 0.5) def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col, output): - r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1)).rsqrt_().unsqueeze(-1) + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1).unsqueeze(-1)).rsqrt_().unsqueeze(-1) c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() torch.mul(r_factor, c_factor, out=output)