Batched equivariant maps basis expansion (?) #76

Open

Danfoa opened this issue Aug 24, 2023 · 4 comments


Danfoa commented Aug 24, 2023

Hi @Gabri95,

Do you see an easy way to enable a batched basis expansion of equivariant maps? Let me explain what I mean.

The construction of a linear equivariant map T from a weight vector w (whose dimension equals the dimension of the basis of T) seems to be tailored to a single weight vector producing a single equivariant linear map. This is perfectly suitable for building equivariant linear layers.

However, it is not suitable for parametrically building several equivariant linear maps from a batched collection of weights of shape (batch, dim(w)), resulting in a batch of equivariant linear operators. I tried my best to understand the code and devise a way to do this, but with the current implementation it seems rather tricky.

When using the EMLP library, this was possible by finding the nullspace projector matrix Q of shape [n*n, basis_dim], which can be used to project several weight vectors at once, T = reshape(Q w), to their corresponding equivariant linear matrices. This process has an immense memory complexity (because of the n*n factor, n being the dimension of T, assuming a square T). I understand your approach elegantly avoids this memory problem. Can you think of a way to make a batched version of your basis expansion?
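
For concreteness, here is a minimal sketch of that EMLP-style batched expansion, assuming a precomputed basis tensor (the names and random tensors below are placeholders, not EMLP's actual API):

```python
import torch

# Illustrative stand-in for the EMLP-style expansion described above (not the EMLP API):
# Q stacks the d basis matrices of the space of equivariant maps T.
n, d, batch = 32, 10, 256
Q = torch.randn(d, n, n)       # precomputed basis; materializing it costs O(d * n^2) memory
w = torch.randn(batch, d)      # one weight vector per batch element

# T[b] = sum_k w[b, k] * Q[k]  ->  (batch, n, n): one equivariant map per weight vector
T = torch.einsum('bd,dij->bij', w, Q)
```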


Gabri95 commented Aug 24, 2023

hi @Danfoa

The single-block basis expansion and sampler classes could be used for that.
That's actually what I do internally in the BlockBasisSampler class, for example.

The external interface of the library (via the conv layers) does not directly support this, though.
Could you maybe provide a more detailed example of what you'd like to do, so I can suggest something more concrete or try to write some example code?

For instance, do you need to compute a number of convolution kernels for an RdConv or do you want to run multiple RdPointConv in parallel? Or are you only interested in LinearLayers?

Best,
Gabriele


Danfoa commented Aug 24, 2023

This sounds amazing, thanks for the help!

Let me describe my application case.

TLDR: I want to construct multiple equivariant linear maps T of shape [n x n]. We know that the basis of the space of such maps has dimension d. I don't want to learn a single such map; instead:

  1. I want to learn a function T(.): X -> R^d that parameterizes the linear maps T(x) \in R^(n x n) as a function of their input x of shape (batch, |x|).
  2. The output of the network, of shape (batch, d), will be used to parameterize batch distinct equivariant maps, resulting in a tensor of shape (batch, n, n).
  3. Then, I would like to apply these linear maps to each of the corresponding input vectors (see the sketch at the end of this comment).

More details: I am learning equivariant dynamical systems via transition operators. The nice thing about this approach is that if you find the appropriate non-linear change of coordinates x = f(z), the dynamics of your system become linear, dx/dt = T(x) x with T(x) \in R^(n x n), instead of the potentially non-linear dynamics of z. Here, think of z as the state of your dynamical system (e.g., position and momentum) and x as a new "observable" state (e.g., a set of relevant functions of z, such as energy, polynomials, etc.). For equivariant systems, T(x) needs to be an equivariant linear map, and this is where I need to learn the function T(.): X -> R^d, where d is the dimension of the space of equivariant endomorphisms X -> X. This is why your basis expansion has become so useful to me.
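
To make the pipeline concrete, here is a rough, library-agnostic sketch of steps 1-3 above, where `basis` is a placeholder for the d equivariant basis matrices (however they are obtained) and `weight_net` is a hypothetical network producing the d coefficients:

```python
import torch
import torch.nn as nn

n, d = 32, 10
basis = torch.randn(d, n, n)                         # placeholder for the equivariant basis of T

# Hypothetical network T(.): X -> R^d producing basis coefficients from the observable state x
weight_net = nn.Sequential(nn.Linear(n, 64), nn.ReLU(), nn.Linear(64, d))

def dynamics(x):                                     # x: (batch, n)
    w = weight_net(x)                                # (batch, d) basis coefficients
    T = torch.einsum('bd,dij->bij', w, basis)        # (batch, n, n) equivariant maps T(x)
    return torch.bmm(T, x.unsqueeze(-1)).squeeze(-1) # dx/dt = T(x) x, shape (batch, n)

dxdt = dynamics(torch.randn(256, n))
```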


Gabri95 commented Aug 25, 2023

Hi @Danfoa

That sounds like a really cool application!

So, if you know the size of batch in advance, the simplest strategy you can use right now is to generate a linear map of shape (batch*n, n) and then reshape it into (batch, n, n).
You can just use a BlockBasisExpansion for expanding these weights.
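
A rough sketch of this workaround is below; the attribute names (`basisexpansion`, its `dimension()` method and call signature) are assumptions based on how the conv layers expose their basis expansion, so check them against your escnn version:

```python
import torch
from escnn import gspaces, nn as enn
from escnn.group import CyclicGroup

# Sketch of the (batch*n, n) -> (batch, n, n) trick suggested above; `basisexpansion`,
# `dimension()` and the call below are assumptions, verify against the actual escnn API.
G = CyclicGroup(4)
gspace = gspaces.no_base_space(G)
rep = G.regular_representation

batch = 8
in_type = enn.FieldType(gspace, [rep])              # n = rep.size input channels
out_type = enn.FieldType(gspace, [rep] * batch)     # batch copies -> batch*n output channels
n = in_type.size

linear = enn.Linear(in_type, out_type, bias=False)
bexp = linear.basisexpansion                        # block basis expansion of the whole map

# The basis of maps in_type -> out_type consists of batch copies of the d-dimensional basis
# of n x n equivariant maps, so each d-sized segment of w controls one of the batch blocks.
w = torch.randn(bexp.dimension())
T = bexp(w).reshape(batch, n, n)                    # batch distinct equivariant maps
```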

I can make something a bit more flexible to achieve exactly what you want by removing this assert and just using the last dimension of weights.
I am not sure I have time to implement it properly right now, but you could try that yourself and open a PR maybe?


Danfoa commented Aug 25, 2023

I can certainly try @Gabri95,

> So, if you know the size of batch in advance, the simplest strategy you can use right now is to generate a linear map of shape (batch*n, n) and then reshape it into (batch, n, n).
> You can just use a BlockBasisExpansion for expanding these weights.

I know the batch dimension, but I am a bit unsure about how to interact with the BlockBasisExpansion. The code is a bit hard to digest without investing a large amount of time in it. Any hints?

> I can make something a bit more flexible to achieve exactly what you want by removing this assert and just using the last dimension of weights.
> I am not sure I have time to implement it properly right now, but you could try that yourself and open a PR maybe?

I will give it a try. I think I already see the problem. It should not be difficult.
