Fix typo error for lstm operations #385

Merged
merged 1 commit into from
Jun 20, 2023

Conversation

BruceDai
Contributor

@BruceDai BruceDai commented May 8, 2023

The fix is verified by the webnn-polyfill lstm and lstmCell implementations and tests: webmachinelearning/webnn-polyfill#227.
@wchao1115 @fdwr @huningxin PTAL, thanks.


Preview | Diff

index.bs Outdated
@@ -2201,7 +2201,7 @@ partial interface MLGraphBuilder {
- *options*: an optional {{MLLstmCellOptions}}. The optional parameters of the operation.
- *bias*: an {{MLOperand}}. The 1-D input bias tensor of shape [4 * hidden_size]. The ordering of the bias vectors in the first dimension of the tensor shape is specified according to the *options.layout* argument.
- *recurrentBias*: an {{MLOperand}}. The 1-D recurrent bias tensor of shape [4 * hidden_size]. The ordering of the bias vectors in the first dimension of the tensor shape is specified according to the *options.layout* argument.
-- *peepholeWeight*: an {{MLOperand}}. The 1-D weight tensor for peepholes of shape [3 * hidden_size]. The pack ordering of the weight vectors is for the *input (i)*, *output (o)*, and *forget (f)* gate respectively.
+- *peepholeWeight*: an {{MLOperand}}. The 1-D weight tensor for peepholes of shape [4 * hidden_size]. The pack ordering of the weight vectors is for the *input (i)*, *output (o)*, and *forget (f)* gate respectively.
Collaborator

@fdwr fdwr May 9, 2023


I'm not sure about TF and PT (I can't yet figure out which of their terms corresponds to the "peephole", since they evidently use a different one), but the DML API definitely shows PeepholeTensor sizes = { 1, 1, num_directions, 3 * hidden_size }, and ONNX does too: "P (optional, differentiable) : The weight tensor for peepholes. ... It has shape [num_directions, 3*hidden_size]". So 3, not 4?

That would be consistent with the wording after it, concatenating 3 tensors: The pack ordering of the weight vectors is for the *input (i)*, *output (o)*, and *forget (f)* gate respectively.
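To make the "concatenating 3 tensors" point concrete, here is a minimal sketch in plain JavaScript arrays. The helper name `splitPeepholeWeight` is hypothetical and not part of WebNN; it only assumes the packed i, o, f ordering quoted above.

```javascript
// Hypothetical helper (not a WebNN API): split a packed peephole weight
// vector of length 3 * hiddenSize into per-gate vectors, assuming the
// input (i), output (o), forget (f) pack ordering quoted from the spec.
function splitPeepholeWeight(packed, hiddenSize) {
  if (packed.length !== 3 * hiddenSize) {
    throw new RangeError(
      `expected ${3 * hiddenSize} values, got ${packed.length}`);
  }
  return {
    input: packed.slice(0, hiddenSize),
    output: packed.slice(hiddenSize, 2 * hiddenSize),
    forget: packed.slice(2 * hiddenSize, 3 * hiddenSize),
  };
}

// Example with hiddenSize = 2: six packed values, two per gate.
const gates = splitPeepholeWeight([1, 2, 3, 4, 5, 6], 2);
console.log(gates);
```

A vector of length 4 * hidden_size would leave one hidden_size-sized chunk with no gate to assign it to, which is why this sketch rejects it.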

Contributor Author

@BruceDai BruceDai May 9, 2023


Thanks @fdwr!

I've updated the commit with reference to the materials you shared above and to the arXiv paper LSTM: A Search Space Odyssey, and I've updated the lstm implementation in the WebNN-Polyfill API (webmachinelearning/webnn-polyfill@b25e817). Please take another look, thanks.

In the spec, the lstm operation describes peepholeWeight as follows:

peepholeWeight: an [MLOperand](https://webmachinelearning.github.io/webnn/#mloperand). 
The 2-D weight tensor for peepholes of shape [num_directions, 4 * hidden_size]. 
The pack ordering of the weight vectors is for the input (i), output (o), and forget (f) gate respectively.

So I originally made this modification to the lstmCell operation to align it with the second dimension of the shape above, [num_directions, 4 * hidden_size].
However, the peepholeWeight shape [num_directions, 4 * hidden_size] is itself a typo. It should be [num_directions, 3 * hidden_size], which then matches the explanation on the next line: "The pack ordering of the weight vectors is for the input (i), output (o), and forget (f) gate respectively."

The shape [num_directions, 3 * hidden_size] can be verified against the much clearer description in Section II, "Vanilla LSTM", of the paper LSTM: A Search Space Odyssey. Please see these three screenshots:

  • figure-1 (image)

  • figure-2 (image)

  • figure-3 (image)
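Since the screenshots are not reproduced here, the vanilla LSTM forward pass can be sketched as follows. This is written from my recollection of the paper's notation and should be checked against it: $W$ are input weights, $R$ recurrent weights, $b$ biases, and $p$ peephole weights, with subscripts $z$, $i$, $f$, $o$ for the block input and the input, forget, and output gates.

```latex
\begin{aligned}
z^t &= g(W_z x^t + R_z y^{t-1} + b_z) \\
i^t &= \sigma(W_i x^t + R_i y^{t-1} + p_i \odot c^{t-1} + b_i) \\
f^t &= \sigma(W_f x^t + R_f y^{t-1} + p_f \odot c^{t-1} + b_f) \\
c^t &= z^t \odot i^t + c^{t-1} \odot f^t \\
o^t &= \sigma(W_o x^t + R_o y^{t-1} + p_o \odot c^t + b_o) \\
y^t &= h(c^t) \odot o^t
\end{aligned}
```

Weights, recurrent weights, and biases each appear in four equations ($z$, $i$, $f$, $o$), while peephole vectors appear only in the three gate equations ($p_i$, $p_f$, $p_o$). That is consistent with 4 * hidden_size for the bias tensors and 3 * hidden_size for peepholeWeight.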

> but the DML API definitely shows PeepholeTensor sizes = { 1, 1, num_directions, 3 * hidden_size } and ONNX too P (optional, differentiable) : The weight tensor for peepholes. ... It has shape [num_directions, 3*hidden_size]

Earlier I had found that a peepholeWeight sized with 3 * hidden_size failed for the lstm operation with an error when computing slice, while a peepholeWeight of shape [3 * hidden_size] worked for the lstmCell operation.
The slice error occurred because I had not updated the currentPeepholeWeight setting in the for loop, as shown below:

    for (let dir = 0; dir < numDirections; ++dir) {
      .....
      currentPeepholeWeight.push(options.peepholeWeight ?
 -      (builder.squeeze(builder.slice(options.peepholeWeight, [dir, 0], [1, 4 * hidden_size]), { axes: [0] })) : null);
 +      (builder.squeeze(builder.slice(options.peepholeWeight, [dir, 0], [1, 3 * hidden_size]), { axes: [0] })) : null);
    }
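The failure mode can be reproduced without WebNN. The sketch below is an illustration only: `sliceRow` is a hypothetical stand-in for a bounds-checked 2-D slice over a row-major array, not the `builder.slice` implementation, and the sizes are made up.

```javascript
// Hypothetical stand-in for a 2-D slice starting at [dir, 0] with size
// [1, size], over a row-major buffer of the given [rows, cols] shape.
function sliceRow(data, shape, dir, size) {
  const [rows, cols] = shape;
  if (dir < 0 || dir >= rows || size > cols) {
    throw new RangeError(
      `slice at [${dir}, 0] with size [1, ${size}] exceeds shape [${rows}, ${cols}]`);
  }
  const start = dir * cols;
  return data.slice(start, start + size);
}

// Illustrative sizes (assumptions, not taken from the thread).
const numDirections = 2;
const hiddenSize = 2;
const shape = [numDirections, 3 * hiddenSize];
const peepholeWeight = Array.from(
  { length: numDirections * 3 * hiddenSize }, (_, i) => i);

// Corrected loop: slice 3 * hiddenSize values per direction.
const currentPeepholeWeight = [];
for (let dir = 0; dir < numDirections; ++dir) {
  currentPeepholeWeight.push(
    sliceRow(peepholeWeight, shape, dir, 3 * hiddenSize));
}

// Stale loop body: slicing 4 * hiddenSize from a 3 * hiddenSize row fails.
let staleSliceError = null;
try {
  sliceRow(peepholeWeight, shape, 0, 4 * hiddenSize);
} catch (e) {
  staleSliceError = e;
}
```

With the tensor shape changed to 3 * hidden_size, the slice size in the loop has to change with it. Otherwise the requested window is wider than the row it reads from.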

Collaborator


Oof, you had to dig back into the original paper. Thank you for the spec wording and pseudocode fix.

@BruceDai BruceDai force-pushed the fix_lstm_lstmCell branch from 006af72 to 1732b82 Compare May 22, 2023 14:33
@BruceDai BruceDai changed the title Fix error for lstm and lstmCell operations Fix typo error for lstm operations May 22, 2023
Collaborator

@fdwr fdwr left a comment


😎

@wchao1115 wchao1115 self-requested a review June 5, 2023 17:35
Collaborator

@wchao1115 wchao1115 left a comment


Looks right. Thanks for fixing it.

@anssiko anssiko requested a review from huningxin June 5, 2023 18:12
@anssiko
Member

anssiko commented Jun 5, 2023

@huningxin feel free to merge at will if you're happy with this fix.

Contributor

@huningxin huningxin left a comment


Looks good, thanks @BruceDai !

@wchao1115
Collaborator

@BruceDai What is the reason this PR hasn't been merged? It was approved a few weeks ago?

@BruceDai
Contributor Author

@anssiko @huningxin Would you please help merge this PR? Thanks.

@anssiko anssiko merged commit 54d57b4 into webmachinelearning:main Jun 20, 2023
@anssiko
Member

anssiko commented Jun 20, 2023

@BruceDai thanks for this contribution!

6 participants