
Include MultiheadAttention module in C# API #320

Merged
merged 5 commits into dotnet:main on Sep 2, 2021

Conversation

fwaris
Contributor

@fwaris fwaris commented Aug 4, 2021

We need the MultiheadAttention module by itself for some types of models.

Although MultiheadAttention is part of the transformer modules, it is also needed separately (i.e. outside of a transformer) for some models.
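As a rough illustration, standalone usage might look like the following minimal sketch. The constructor and `forward` signatures are taken from the code later in this PR; the namespace layout, tensor factory calls, and shapes follow PyTorch's (seq_len, batch, embed_dim) convention and are assumptions, not verified against a released build.

```csharp
// Hypothetical sketch: a MultiheadAttention layer used outside a full
// transformer, based on the signatures shown in this PR.
using TorchSharp;
using static TorchSharp.torch;

var mha = nn.MultiheadAttention(64, 4, dropout: 0.1);

// PyTorch convention: (sequence length, batch size, embedding dim)
var query = rand(10, 1, 64);
var key   = rand(10, 1, 64);
var value = rand(10, 1, 64);

var result = mha.forward(query, key, value);
var attn_output  = result.Item1;  // shape (10, 1, 64)
var attn_weights = result.Item2;  // attention weights, averaged over heads
```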

@GeorgeS2019

GeorgeS2019 commented Aug 4, 2021

It seems TorchSharp TorchText is now making critical progress. Thanks!!

@GeorgeS2019

@fwaris is there a plan to include unit tests and examples for MultiHeadAttention?

@fwaris
Contributor Author

fwaris commented Aug 4, 2021

I have a unit test that is passing. It exercises the basic functionality.

I am working on porting the Temporal Graph Network model which requires MHA. Once done, I can create an example also.

attn_mask?.Handle ?? IntPtr.Zero,
out var res1,
out var res2);
if (res1 == IntPtr.Zero) { torch.CheckForErrors(); }
Contributor

@NiklasGustafsson NiklasGustafsson Aug 10, 2021

I think it should probably check both res1 and res2, just to be safe.
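Concretely, the reviewer's suggestion might look like the sketch below, assuming `res1` and `res2` are the two native output handles produced by the interop call in the quoted diff; the wrapping into `Tensor` objects is an assumption about the surrounding code.

```csharp
// Sketch of the review suggestion: check both native handles, not just
// the first, before wrapping them as managed tensors.
if (res1 == IntPtr.Zero || res2 == IntPtr.Zero) { torch.CheckForErrors(); }
return Tuple.Create(new Tensor(res1), new Tensor(res2));
```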

/// <param name="kdim">total number of features in key</param>
/// <param name="vdim">total number of features in value</param>
/// <returns></returns>
static public MultiheadAttention MultiheadAttention(long embeded_dim, long num_heads, double dropout = 0.0, bool bias = true, bool add_bias_kv = false, bool add_zero_attn = false, long? kdim=null, long? vdim=null)
Contributor

'embeded_dim' -> 'embedded_dim'

Contributor Author

will fix. thanks

@NiklasGustafsson
Contributor

@fwaris:

This is a good addition to TorchSharp, and we really value your contribution.

We're in the middle of moving this repository to a different organization (Xamarin is not the right long-term place for it), where we will have a proper Contributor License Agreement for all contributions, internal and external.

Therefore, we will hold off on merging this PR for a little bit.

@NiklasGustafsson
Contributor

An update: this repo will be moved to the .NET Foundation organization, after which we will have the CLA and can accept the contribution.

@NiklasGustafsson
Contributor

@fwaris, if you resubmit this PR, GitHub will take you through the signing of a CLA, and we can accept your PR. Also, the text for the copyright header has changed from 'Microsoft Corp.' to '.NET Foundation.'

@dnfadmin

dnfadmin commented Sep 1, 2021

CLA assistant check
All CLA requirements met.

Contributor

@NiklasGustafsson NiklasGustafsson left a comment

Minor changes required before the PR will be accepted.

@@ -0,0 +1,79 @@
// Copyright (c) Microsoft Corporation and contributors. All Rights Reserved. See License.txt in the project root for license information.
Contributor

This should now say 'Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.'

/// <returns>attn_output, attn_output_weights</returns>

public Tuple<Tensor,Tensor> forward(Tensor query, Tensor key, Tensor value, Tensor? key_padding_mask = null, bool need_weights = true, Tensor? attn_mask = null)
//const NNModule module, const Tensor query, const Tensor key, const Tensor value, const Tensor key_padding_mask, const bool need_weights, const Tensor attn_mask, Tensor res1, Tensor res2 )
Contributor

There's commented-out code.

public Tuple<Tensor,Tensor> forward(Tensor query, Tensor key, Tensor value, Tensor? key_padding_mask = null, bool need_weights = true, Tensor? attn_mask = null)
//const NNModule module, const Tensor query, const Tensor key, const Tensor value, const Tensor key_padding_mask, const bool need_weights, const Tensor attn_mask, Tensor res1, Tensor res2 )
{
//var res1 = IntPtr.Zero;
Contributor

More commented-out code.

@@ -268,3 +268,4 @@ packages/
.ionide
*.bin
/*.png
/src/Native/out/build/x64-Debug
Contributor

Odd. Was this necessary?

@fwaris
Contributor Author

fwaris commented Sep 1, 2021

I signed the CLA, but I'm not sure if the PR is still valid. If not, I will create a new one later this week.

@NiklasGustafsson
Contributor

I signed the CLA, but I'm not sure if the PR is still valid. If not, I will create a new one later this week.

I closed and resubmitted the PR, so now the checks are running. I had some review comments.

@NiklasGustafsson
Contributor

@fwaris, I would like to get this in and then do a NuGet release with it, so I'm going to merge the PR and address the requested changes in another PR.

@NiklasGustafsson NiklasGustafsson merged commit 6c0a8f2 into dotnet:main Sep 2, 2021
@fwaris
Contributor Author

fwaris commented Sep 2, 2021

Thanks. I will rebase and make the requested changes in another PR.

@NiklasGustafsson
Contributor

Thanks. I will rebase and make the requested changes in another PR.

@fwaris, there's no need. I already made the changes and they are in main now.

@GeorgeS2019

@fwaris We are attempting to visualize the MultiheadAttention module within the transformer architecture. Your input is appreciated.

:-) I am still looking forward to your tests and samples on MultiheadAttention :-)

@fwaris
Contributor Author

fwaris commented Sep 30, 2021

My day job is keeping me very busy, but I intend to complete the TGN model port (see above) in the next few weeks. Only a few modules are left to port.

@GeorgeS2019

@fwaris take your time. Thanks for the valuable voluntary contribution :-)
