Extended API of LSTM encoder. (#2030)
Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes #1862.
Nobody responded to my issue. Nevertheless, the change is very small, so I think it doesn't need much discussion.

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding
Pull Request resolved: #2030

Reviewed By: joshim5, ngoyal2707

Differential Revision: D21584250

Pulled By: myleott

fbshipit-source-id: 28f0ccaca0df2860806178dbce02bcc12d7115d4
MichalTurski authored and facebook-github-bot committed May 18, 2020
1 parent 803c0a6 commit 132ee8a
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions fairseq/models/lstm.py
@@ -230,7 +230,23 @@ def __init__(
         if bidirectional:
             self.output_units *= 2
 
-    def forward(self, src_tokens, src_lengths: Tensor):
+    def forward(
+        self,
+        src_tokens: Tensor,
+        src_lengths: Tensor,
+        enforce_sorted: bool = True,
+    ):
+        """
+        Args:
+            src_tokens (LongTensor): tokens in the source language of
+                shape `(batch, src_len)`
+            src_lengths (LongTensor): lengths of each source sentence of
+                shape `(batch)`
+            enforce_sorted (bool, optional): if True, `src_tokens` is
+                expected to contain sequences sorted by length in a
+                decreasing order. If False, this condition is not
+                required. Default: True.
+        """
         if self.left_pad:
             # nn.utils.rnn.pack_padded_sequence requires right-padding;
             # convert left-padding to right-padding
@@ -250,7 +266,9 @@ def forward(self, src_tokens, src_lengths: Tensor):
             x = x.transpose(0, 1)
 
         # pack embedded source tokens into a PackedSequence
-        packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data)
+        packed_x = nn.utils.rnn.pack_padded_sequence(
+            x, src_lengths.data, enforce_sorted=enforce_sorted
+        )
 
         # apply LSTM
         if self.bidirectional:
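
For reference, here is a minimal usage sketch (not part of the commit) of the extended API. The toy dictionary and encoder construction below are illustrative assumptions; only the `enforce_sorted` keyword comes from the diff above. With `enforce_sorted=False`, `pack_padded_sequence` sorts the batch internally, so `src_tokens` no longer has to arrive sorted by decreasing length.

    # Hypothetical sketch: exercising the new enforce_sorted flag on LSTMEncoder.
    import torch
    from fairseq.data import Dictionary
    from fairseq.models.lstm import LSTMEncoder

    # Illustrative toy dictionary; any Dictionary holding the source symbols works.
    d = Dictionary()
    for sym in ("a", "b", "c"):
        d.add_symbol(sym)

    # Small encoder; left_pad=True is the default, so padding goes on the left.
    encoder = LSTMEncoder(dictionary=d, embed_dim=16, hidden_size=16)

    # Batch of two sentences that is NOT sorted by decreasing length.
    src_tokens = torch.tensor([
        [d.pad(), d.pad(), d.index("a")],
        [d.index("a"), d.index("b"), d.index("c")],
    ])
    src_lengths = torch.tensor([1, 3])

    # Previously forward() always used pack_padded_sequence's default
    # enforce_sorted=True, so an unsorted batch like this raised an error.
    encoder_out = encoder(src_tokens, src_lengths, enforce_sorted=False)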
