From fb7ff4ff37dd437af8ed54f6abf41cd52d275011 Mon Sep 17 00:00:00 2001 From: Akhilesh Gotmare Date: Tue, 10 Sep 2019 17:04:54 +0800 Subject: [PATCH] Minor fix for >2d conv kernels: add the missing .unsqueeze(-1) in line 124. Without this change, we'll encounter a runtime error for >2d convolutional kernels; with this fix, we're applying Adafactor's 2d logic to the two final dimensions. --- fairseq/optim/adafactor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/optim/adafactor.py b/fairseq/optim/adafactor.py index 680ac371b9..00e6ed3138 100644 --- a/fairseq/optim/adafactor.py +++ b/fairseq/optim/adafactor.py @@ -121,7 +121,7 @@ def _rms(self, tensor): return tensor.norm(2) / (tensor.numel() ** 0.5) def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col, output): - r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1)).rsqrt_().unsqueeze(-1) + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1).unsqueeze(-1)).rsqrt_().unsqueeze(-1) c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() torch.mul(r_factor, c_factor, out=output)