Skip to content

Commit

Permalink
Fix keyword arguments in translation_moe task
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: facebookresearch/fairseq#1546

Differential Revision: D19225548

Pulled By: myleott

fbshipit-source-id: 43240cb90ca477ab7a790386ab2d9f4fd14e2625
  • Loading branch information
Myle Ott authored and yzpang committed Feb 19, 2021
1 parent d42d6cf commit 2f5e6a1
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions fairseq/tasks/translation_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,19 @@ def _get_loss(self, sample, model, criterion):
bsz = sample['target'].size(0)

def get_lprob_y(encoder_out, prev_output_tokens_k):
net_output = model.decoder(prev_output_tokens_k, encoder_out)
net_output = model.decoder(
prev_output_tokens=prev_output_tokens_k,
encoder_out=encoder_out,
)
loss, _ = criterion.compute_loss(model, net_output, sample, reduce=False)
loss = loss.view(bsz, -1)
return -loss.sum(dim=1, keepdim=True) # -> B x 1

def get_lprob_yz(winners=None):
encoder_out = model.encoder(sample['net_input']['src_tokens'], sample['net_input']['src_lengths'])
encoder_out = model.encoder(
src_tokens=sample['net_input']['src_tokens'],
src_lengths=sample['net_input']['src_lengths'],
)

if winners is None:
lprob_y = []
Expand Down

0 comments on commit 2f5e6a1

Please # to comment.