From 2f5e6a18adf80d0e3db02b7a23d9d6cf8937ec82 Mon Sep 17 00:00:00 2001
From: Myle Ott
Date: Tue, 24 Dec 2019 12:12:21 -0800
Subject: [PATCH] Fix keyword arguments in translation_moe task

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1546

Differential Revision: D19225548

Pulled By: myleott

fbshipit-source-id: 43240cb90ca477ab7a790386ab2d9f4fd14e2625
---
 fairseq/tasks/translation_moe.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/fairseq/tasks/translation_moe.py b/fairseq/tasks/translation_moe.py
index ae8817a3..a4d980fb 100644
--- a/fairseq/tasks/translation_moe.py
+++ b/fairseq/tasks/translation_moe.py
@@ -121,13 +121,19 @@ def _get_loss(self, sample, model, criterion):
         bsz = sample['target'].size(0)
 
         def get_lprob_y(encoder_out, prev_output_tokens_k):
-            net_output = model.decoder(prev_output_tokens_k, encoder_out)
+            net_output = model.decoder(
+                prev_output_tokens=prev_output_tokens_k,
+                encoder_out=encoder_out,
+            )
             loss, _ = criterion.compute_loss(model, net_output, sample, reduce=False)
             loss = loss.view(bsz, -1)
             return -loss.sum(dim=1, keepdim=True)  # -> B x 1
 
         def get_lprob_yz(winners=None):
-            encoder_out = model.encoder(sample['net_input']['src_tokens'], sample['net_input']['src_lengths'])
+            encoder_out = model.encoder(
+                src_tokens=sample['net_input']['src_tokens'],
+                src_lengths=sample['net_input']['src_lengths'],
+            )
             if winners is None:
                 lprob_y = []