Include MultiheadAttention module in C# API #320
Conversation
It seems TorchSharp TorchText is now making critical progress. Thanks!!
@fwaris is there a plan to include unit tests and examples for MultiHeadAttention?
I have a unit test that is passing. It exercises the basic functionality. I am working on porting the Temporal Graph Network model, which requires MHA. Once done, I can create an example as well.
attn_mask?.Handle ?? IntPtr.Zero,
out var res1,
out var res2);
if (res1 == IntPtr.Zero) { torch.CheckForErrors(); }
I think it should probably check both res1 and res2, just to be safe.
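For what it's worth, a minimal sketch of the suggested change, reusing only the names visible in the quoted snippet (`res1`, `res2`, `torch.CheckForErrors`); everything around it is assumed:

```csharp
// Treat a null handle from either native output as a failure and
// surface the pending native error before wrapping the results.
if (res1 == IntPtr.Zero || res2 == IntPtr.Zero) { torch.CheckForErrors(); }
```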
/// <param name="kdim">total number of features in key</param>
/// <param name="vdim">total number of features in value</param>
/// <returns></returns>
static public MultiheadAttention MultiheadAttention(long embeded_dim, long num_heads, double dropout = 0.0, bool bias = true, bool add_bias_kv = false, bool add_zero_attn = false, long? kdim=null, long? vdim=null)
'embeded_dim' -> 'embedded_dim'
will fix. thanks
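For reference (not part of the PR), a sketch of how this factory might be called; the argument values are illustrative, and the behavior of omitted kdim/vdim defaulting to the embedding dimension is assumed from the PyTorch module this mirrors:

```csharp
// Illustrative only: 512-dim attention with 8 heads; key/value inputs
// are projected from 256-dim features via kdim/vdim.
var mha = MultiheadAttention(512, 8, dropout: 0.1, kdim: 256, vdim: 256);
```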
This is a good addition to TorchSharp, and we really value your contribution. We're in the middle of moving this repository to a different organization (Xamarin is not the right long-term place for it), where we will have a proper Contributor License Agreement for all contributions, internal and external. Therefore, we will hold off on merging this PR for a little bit.
An update: this repo will be moved to the .NET Foundation organization, after which we will have the CLA and can accept the contribution.
@fwaris, if you resubmit this PR, GitHub will take you through the signing of a CLA, and we can accept your PR. Also, the text for the copyright header has changed from 'Microsoft Corp.' to '.NET Foundation.'
Minor changes required before the PR will be accepted.
@@ -0,0 +1,79 @@
// Copyright (c) Microsoft Corporation and contributors. All Rights Reserved. See License.txt in the project root for license information.
This should now say 'Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.'
/// <returns>attn_output, attn_ouput_weights</returns>
public Tuple<Tensor,Tensor> forward(Tensor query, Tensor key, Tensor value, Tensor? key_padding_mask = null, bool need_weights = true, Tensor? attn_mask = null)
//const NNModule module, const Tensor query, const Tensor key, const Tensor value, const Tensor key_padding_mask, const bool need_weights, const Tensor attn_mask, Tensor res1, Tensor res2 )
There's commented code here.
public Tuple<Tensor,Tensor> forward(Tensor query, Tensor key, Tensor value, Tensor? key_padding_mask = null, bool need_weights = true, Tensor? attn_mask = null)
//const NNModule module, const Tensor query, const Tensor key, const Tensor value, const Tensor key_padding_mask, const bool need_weights, const Tensor attn_mask, Tensor res1, Tensor res2 )
{
//var res1 = IntPtr.Zero;
More commented code.
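As an aside, given the `forward` signature quoted above, a call site would look roughly like this; the `torch.rand` factory and the `(seq_len, batch, embed_dim)` shape convention are assumptions carried over from the PyTorch module:

```csharp
// Illustrative self-attention call: query = key = value.
var x = torch.rand(10, 2, 512);    // (seq_len=10, batch=2, embed_dim=512)
var result = mha.forward(x, x, x); // mha from the factory sketched earlier
var attnOutput = result.Item1;     // (10, 2, 512)
var attnWeights = result.Item2;    // attention weights (need_weights=true)
```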
@@ -268,3 +268,4 @@ packages/
.ionide
*.bin
/*.png
/src/Native/out/build/x64-Debug
Odd. Was this necessary?
I signed the CLA, but I'm not sure if the PR is still valid. If not, I'll create a new one later this week.
I closed and resubmitted the PR, so now the checks are running. I had some review comments.
@fwaris, I would like to get this in and then do a NuGet release with it, so I'm going to merge the PR and address the requested changes in another PR.
thanks. I will rebase and make the requested changes as another PR
@fwaris, there's no need. I already made the changes and they are in main now.
@fwaris We are attempting to visualize the MultiheadAttention module within the transformer architecture. Your input is appreciated. :-) I'm still looking forward to your tests and samples on MultiheadAttention :-)
day job is keeping me very busy, but I intend to complete the TGN model port (see above) in the next few weeks. Only a few modules are left to port.
@fwaris take your time. Thx for the valuable voluntary contribution :-)
We need the MultiheadAttention module by itself for some types of models.
Although MultiheadAttention is part of the transformer modules, it is also needed separately (i.e., outside of a transformer) for some models.