Add masked LSTM support #2030

Open: q-ycong-p wants to merge 4 commits into base: main

q-ycong-p
Contributor

Masking is an intra-layer behavior in the TF LSTM [1] but not an intra-op behavior in the ONNX LSTM [2]. As a result, a masked TF LSTM layer is currently converted to a Loop op. This over-complicates the ONNX model and hurts inference performance in ORT, because the LSTM optimizations are not leveraged. (issue #1871)

This commit adds support for converting masked LSTMs correctly, under the important assumption that the input is post-padded, which is the most common use case. The masking information is conveyed to the ONNX LSTM op as sequence_lens, which is computed dynamically by summing the number of non-skipped timesteps for each batch entry, for each LSTM. This behavior is implemented with reference to keras2onnx PR #386 [3]. Additional logic is added for backward LSTMs so that the input sequence is reversed correctly given sequence_lens.
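To illustrate the idea (this is not the converter code from this PR), here is a minimal pure-Python sketch of deriving sequence_lens from a mask. The helper names `sequence_lens_from_mask` and `is_post_padded` are hypothetical, and the token value 0 is assumed to be the padding value:

```python
def sequence_lens_from_mask(mask):
    """Count the non-skipped timesteps in each batch entry.

    mask: list of rows, one per batch entry; True marks a real timestep.
    """
    return [sum(1 for keep in row if keep) for row in mask]


def is_post_padded(mask):
    """Return True if no real timestep follows a padded one in any row."""
    for row in mask:
        seen_pad = False
        for keep in row:
            if not keep:
                seen_pad = True
            elif seen_pad:
                return False
    return True


# Assumption for this example: token id 0 is the padding value.
token_ids = [
    [5, 3, 9, 0, 0],
    [7, 0, 0, 0, 0],
    [2, 4, 6, 8, 1],
]
mask = [[t != 0 for t in row] for row in token_ids]
seq_lens = sequence_lens_from_mask(mask)
```

The count only describes the mask faithfully when `is_post_padded(mask)` holds, which is the assumption stated above.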

Note that if masking is enabled and the LSTM input is pre-padded or randomly padded, the converted ONNX model will produce incorrect inference results. Unless ONNX adds a new attribute (e.g. mask_enabled) to its RNN ops, the converter alone may not be able to handle generic masking while keeping the RNN ops, since masking alters intra-op behavior. Given this limitation, I'd like to share this PR for further comments and suggestions.
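A small sketch of why the padding position matters, assuming that a length passed as sequence_lens makes the LSTM consume only the first timesteps of a row. The helpers below are hypothetical and only model which timesteps are consumed, not the LSTM itself:

```python
PAD = 0  # assumed padding value for this example


def steps_kept_by_mask(row, pad=PAD):
    """Timesteps a true mask would keep, wherever the padding sits."""
    return [t for t in row if t != pad]


def steps_seen_with_sequence_lens(row, pad=PAD):
    """Timesteps consumed when only a length is passed: the first seq_len."""
    seq_len = sum(1 for t in row if t != pad)
    return row[:seq_len]


post_padded = [5, 3, 9, 0, 0]
pre_padded = [0, 0, 5, 3, 9]

# Post-padding: the first seq_len steps are exactly the unmasked ones.
post_ok = steps_seen_with_sequence_lens(post_padded) == steps_kept_by_mask(post_padded)
# Pre-padding: the first seq_len steps include padding and drop real tokens.
pre_ok = steps_seen_with_sequence_lens(pre_padded) == steps_kept_by_mask(pre_padded)
```

Here `post_ok` is True and `pre_ok` is False, which is the mismatch described for pre-padded input.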

[1] https://www.tensorflow.org/guide/keras/masking_and_padding#masking
[2] https://github.com/onnx/onnx/blob/main/docs/Operators.md#LSTM
[3] onnx/keras-onnx#386


Details:

Forward LSTM

Here's a minimal example with an LSTM fed by an Embedding layer (mask_zero=True):

  • H5 model:

[Screenshot: Screen Shot 2022-08-26 at 6 18 13 PM]

  • tf2onnx-converted ONNX model, before proposed change:

[Screenshot: Screen Shot 2022-08-26 at 6 17 49 PM]

  • tf2onnx-converted ONNX model, after proposed change:

[Screenshot: Screen Shot 2022-08-26 at 6 18 57 PM]

Reverse LSTM

  • Need to change the tf.raw_ops.ReverseV2 behavior to ReverseSequence so that the LSTM input is reversed correctly:

[Image: reverse_masked_lstm]
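The difference between the two reversal behaviors can be sketched on a single post-padded row. The functions `reverse_full` and `reverse_sequence` are hypothetical pure-Python stand-ins that mimic the semantics of ReverseV2 (flip the whole time axis) and ReverseSequence (flip only the first seq_len steps); they are not the TF or ONNX ops themselves:

```python
def reverse_full(row):
    """Flip the whole time axis, padding included (ReverseV2-like)."""
    return row[::-1]


def reverse_sequence(row, seq_len):
    """Flip only the first seq_len steps; trailing padding stays in place."""
    return row[:seq_len][::-1] + row[seq_len:]


row = [5, 3, 9, 0, 0]  # assumed: 0 is padding, so seq_len is 3
seq_len = 3

full = reverse_full(row)                  # padding moves to the front
partial = reverse_sequence(row, seq_len)  # padding stays at the end
```

With a full flip the row becomes pre-padded, so a backward LSTM driven by sequence_lens would start on padding; reversing only the valid prefix keeps the input post-padded.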

Signed-off-by: Yu Cong <congyc@amazon.com>
@q-ycong-p
Contributor Author

Sorry, I will address the test failures on TF 2.9 soon.

@AndreyOrb

Hello,
Is there any progress with this issue?

@AndreyOrb

Hi,
Is there any update? Will the proposed code work if pulled?

@AndreyOrb

@xadupre Hello Xavier. I've been waiting for this PR for a few years now. Could you assist, please?
