diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3ab694f..0d1d80a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,10 +3,16 @@
 ### Added
 
 - Partial support of `torch.jit.script` and `torch.compile`
+- Added `cuex.RepArray` to represent an array of any kind of representation (no longer only the irreps previously supported by `IrrepsArray`).
 
 ### Changed
 
 - `cuequivariance_torch.TensorProduct` and `cuequivariance_torch.EquivariantTensorProduct` now require lists of `torch.Tensor` as input.
+- `cuex.IrrepsArray` is now an alias for `cuex.RepArray`, and its `.irreps` and `.segments` are no longer methods but properties.
+
+### Removed
+
+- `cuex.IrrepsArray.is_simple` is replaced by `cuex.RepArray.is_irreps_array`.
 
 ### Fixed
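A minimal sketch of the API change described in the changelog entries above, assuming the `cuex.IrrepsArray` alias keeps its old constructor signature (the values are illustrative):

```python
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance_jax as cuex

x = cuex.IrrepsArray(cue.Irreps("SO3", "2x0"), jnp.array([1.0, 2.0]), cue.ir_mul)

x.irreps             # now a property (previously the method x.irreps())
x.segments           # now a property (previously the method x.segments())
x.is_irreps_array()  # replaces the removed x.is_simple()
```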
diff --git a/cuequivariance/cuequivariance/__init__.py b/cuequivariance/cuequivariance/__init__.py
index e93f04b..0fc4b28 100644
--- a/cuequivariance/cuequivariance/__init__.py
+++ b/cuequivariance/cuequivariance/__init__.py
@@ -36,6 +36,7 @@
     IrrepsLayout,
     mul_ir,
     ir_mul,
+    IrrepsAndLayout,
     get_layout_scope,
     assume,
     NumpyIrrepsArray,
@@ -71,6 +72,7 @@
     "IrrepsLayout",
     "mul_ir",
     "ir_mul",
+    "IrrepsAndLayout",
     "get_layout_scope",
     "assume",
     "NumpyIrrepsArray",
diff --git a/cuequivariance/cuequivariance/descriptors/__init__.py b/cuequivariance/cuequivariance/descriptors/__init__.py
index af6e6e3..f18c438 100644
--- a/cuequivariance/cuequivariance/descriptors/__init__.py
+++ b/cuequivariance/cuequivariance/descriptors/__init__.py
@@ -30,9 +30,7 @@
     yxy_rotation,
     inversion,
 )
-from .escn import escn_tp, escn_tp_compact
 from .spherical_harmonics_ import sympy_spherical_harmonics, spherical_harmonics
-from .gatr import gatr_linear, gatr_geometric_product, gatr_outer_product
 
 __all__ = [
     "transpose",
@@ -49,11 +47,6 @@
     "yx_rotation",
     "yxy_rotation",
     "inversion",
-    "escn_tp",
-    "escn_tp_compact",
     "sympy_spherical_harmonics",
     "spherical_harmonics",
-    "gatr_linear",
-    "gatr_geometric_product",
-    "gatr_outer_product",
 ]
diff --git a/cuequivariance/cuequivariance/descriptors/irreps_tp.py b/cuequivariance/cuequivariance/descriptors/irreps_tp.py
index 77a25da..a6d7fc4 100644
--- a/cuequivariance/cuequivariance/descriptors/irreps_tp.py
+++ b/cuequivariance/cuequivariance/descriptors/irreps_tp.py
@@ -75,8 +75,12 @@ def fully_connected_tensor_product(
     d = d.normalize_paths_for_operand(-1)
     return cue.EquivariantTensorProduct(
         d,
-        [irreps1.new_scalars(d.operands[0].size), irreps1, irreps2, irreps3],
-        layout=cue.ir_mul,
+        [
+            cue.IrrepsAndLayout(irreps1.new_scalars(d.operands[0].size), cue.ir_mul),
+            cue.IrrepsAndLayout(irreps1, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps2, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps3, cue.ir_mul),
+        ],
     )
 
 
@@ -131,8 +135,11 @@ def full_tensor_product(
     d = d.normalize_paths_for_operand(-1)
     return cue.EquivariantTensorProduct(
         d,
-        [irreps1, irreps2, irreps3],
-        layout=cue.ir_mul,
+        [
+            cue.IrrepsAndLayout(irreps1, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps2, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps3, cue.ir_mul),
+        ],
     )
 
 
@@ -193,8 +200,12 @@ def channelwise_tensor_product(
     d = d.normalize_paths_for_operand(-1)
     return cue.EquivariantTensorProduct(
         d,
-        [irreps1.new_scalars(d.operands[0].size), irreps1, irreps2, irreps3],
-        layout=cue.ir_mul,
+        [
+            cue.IrrepsAndLayout(irreps1.new_scalars(d.operands[0].size), cue.ir_mul),
+            cue.IrrepsAndLayout(irreps1, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps2, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps3, cue.ir_mul),
+        ],
     )
 
 
@@ -273,7 +284,12 @@ def elementwise_tensor_product(
     irreps3 = cue.Irreps(G, irreps3)
     d = d.normalize_paths_for_operand(-1)
     return cue.EquivariantTensorProduct(
-        d, [irreps1, irreps2, irreps3], layout=cue.ir_mul
+        d,
+        [
+            cue.IrrepsAndLayout(irreps1, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps2, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps3, cue.ir_mul),
+        ],
     )
 
 
@@ -308,6 +324,9 @@ def linear(
     return cue.EquivariantTensorProduct(
         d,
-        [irreps_in.new_scalars(d.operands[0].size), irreps_in, irreps_out],
-        layout=cue.ir_mul,
+        [
+            cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul),
+            cue.IrrepsAndLayout(irreps_in, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps_out, cue.ir_mul),
+        ],
     )
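The pattern above repeats throughout this diff: instead of a descriptor-wide `layout=` keyword, every operand is wrapped in its own `cue.IrrepsAndLayout`. A sketch of what that means for a caller (illustrative irreps; the operand accessors follow the `EquivariantTensorProduct` changes later in this diff):

```python
import cuequivariance as cue
from cuequivariance import descriptors

irreps = cue.Irreps("O3", "4x0e + 4x1o")
e = descriptors.fully_connected_tensor_product(irreps, irreps, irreps)

# Each operand now carries its own layout rather than sharing layout=cue.ir_mul.
for ope in e.operands:
    print(ope.irreps, ope.layout)
```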
diff --git a/cuequivariance/cuequivariance/descriptors/rotations.py b/cuequivariance/cuequivariance/descriptors/rotations.py
index 8bf4418..2257c1b 100644
--- a/cuequivariance/cuequivariance/descriptors/rotations.py
+++ b/cuequivariance/cuequivariance/descriptors/rotations.py
@@ -42,7 +42,13 @@ def fixed_axis_angle_rotation(
     )
     d = d.flatten_coefficient_modes()
 
-    return cue.EquivariantTensorProduct(d, [irreps, irreps], layout=cue.ir_mul)
+    return cue.EquivariantTensorProduct(
+        d,
+        [
+            cue.IrrepsAndLayout(irreps, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps, cue.ir_mul),
+        ],
+    )
 
 
 def yxy_rotation(
@@ -70,13 +76,12 @@
     return cue.EquivariantTensorProduct(
         cbaio,
         [
-            irreps.new_scalars(cbaio.operands[0].size),
-            irreps.new_scalars(cbaio.operands[1].size),
-            irreps.new_scalars(cbaio.operands[2].size),
-            irreps,
-            irreps,
+            cue.IrrepsAndLayout(irreps.new_scalars(cbaio.operands[0].size), cue.ir_mul),
+            cue.IrrepsAndLayout(irreps.new_scalars(cbaio.operands[1].size), cue.ir_mul),
+            cue.IrrepsAndLayout(irreps.new_scalars(cbaio.operands[2].size), cue.ir_mul),
+            cue.IrrepsAndLayout(irreps, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps, cue.ir_mul),
         ],
-        layout=cue.ir_mul,
     )
 
 
@@ -95,12 +100,11 @@ def xy_rotation(
     return cue.EquivariantTensorProduct(
         cbio,
         [
-            irreps.new_scalars(cbio.operands[0].size),
-            irreps.new_scalars(cbio.operands[1].size),
-            irreps,
-            irreps,
+            cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[0].size), cue.ir_mul),
+            cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[1].size), cue.ir_mul),
+            cue.IrrepsAndLayout(irreps, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps, cue.ir_mul),
         ],
-        layout=cue.ir_mul,
     )
 
 
@@ -119,12 +123,11 @@ def yx_rotation(
     return cue.EquivariantTensorProduct(
         cbio,
         [
-            irreps.new_scalars(cbio.operands[0].size),
-            irreps.new_scalars(cbio.operands[1].size),
-            irreps,
-            irreps,
+            cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[0].size), cue.ir_mul),
+            cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[1].size), cue.ir_mul),
+            cue.IrrepsAndLayout(irreps, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps, cue.ir_mul),
         ],
-        layout=cue.ir_mul,
     )
 
 
@@ -188,7 +191,12 @@ def y_rotation(
     d = d.flatten_coefficient_modes()
 
     return cue.EquivariantTensorProduct(
-        d, [irreps.new_scalars(d.operands[0].size), irreps, irreps], layout=cue.ir_mul
+        d,
+        [
+            cue.IrrepsAndLayout(irreps.new_scalars(d.operands[0].size), cue.ir_mul),
+            cue.IrrepsAndLayout(irreps, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps, cue.ir_mul),
+        ],
     )
 
 
@@ -213,7 +221,12 @@ def x_rotation(
     d = stp.dot(stp.dot(dy, dz90, (1, 1)), dz90, (1, 1))
 
     return cue.EquivariantTensorProduct(
-        d, [irreps.new_scalars(d.operands[0].size), irreps, irreps], layout=cue.ir_mul
+        d,
+        [
+            cue.IrrepsAndLayout(irreps.new_scalars(d.operands[0].size), cue.ir_mul),
+            cue.IrrepsAndLayout(irreps, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps, cue.ir_mul),
+        ],
     )
 
 
@@ -228,4 +241,10 @@ def inversion(irreps: cue.Irreps) -> cue.EquivariantTensorProduct:
         assert np.allclose(H @ H, np.eye(ir.dim), atol=1e-6)
         d.add_path(None, None, c=H, dims={"u": mul})
     d = d.flatten_coefficient_modes()
-    return cue.EquivariantTensorProduct(d, [irreps, irreps], layout=cue.ir_mul)
+    return cue.EquivariantTensorProduct(
+        d,
+        [
+            cue.IrrepsAndLayout(irreps, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps, cue.ir_mul),
+        ],
+    )
diff --git a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py
index 40d27ea..af7beb6 100644
--- a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py
+++ b/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py
@@ -54,7 +54,13 @@ def spherical_harmonics(
         indices = poly_degrees_to_path_indices(degrees)
         d.add_path(*indices, i, c=coeff)
 
-    return cue.EquivariantTensorProduct([d], [ir_vec, ir], layout=layout)
+    return cue.EquivariantTensorProduct(
+        [d],
+        [
+            cue.IrrepsAndLayout(cue.Irreps(ir_vec), cue.ir_mul),
+            cue.IrrepsAndLayout(cue.Irreps(ir), cue.ir_mul),
+        ],
+    )
 
 
 def poly_degrees_to_path_indices(degrees: tuple[int, ...]) -> tuple[int, ...]:
diff --git a/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py b/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py
index 0f558b4..48f3fab 100644
--- a/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py
+++ b/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py
@@ -104,6 +104,9 @@ def symmetric_contraction(
     d = d.append_modes_to_all_operands("u", {"u": mul})
     return cue.EquivariantTensorProduct(
         [d],
-        [irreps_in.new_scalars(d.operands[0].size), mul * irreps_in, mul * irreps_out],
-        layout=cue.ir_mul,
+        [
+            cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul),
+            cue.IrrepsAndLayout(mul * irreps_in, cue.ir_mul),
+            cue.IrrepsAndLayout(mul * irreps_out, cue.ir_mul),
+        ],
     )
diff --git a/cuequivariance/cuequivariance/descriptors/transposition.py b/cuequivariance/cuequivariance/descriptors/transposition.py
index 56cc0f5..ae67116 100644
--- a/cuequivariance/cuequivariance/descriptors/transposition.py
+++ b/cuequivariance/cuequivariance/descriptors/transposition.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import cuequivariance as cue
-from cuequivariance.equivariant_tensor_product import Operand
 
 
 def transpose(
@@ -22,12 +21,16 @@
     """Transpose the irreps layout of a tensor."""
     d = cue.SegmentedTensorProduct(
         operands=[
-            cue.Operand(subscripts="ui" if source == cue.mul_ir else "iu"),
-            cue.Operand(subscripts="ui" if target == cue.mul_ir else "iu"),
+            cue.segmented_tensor_product.Operand(
+                subscripts="ui" if source == cue.mul_ir else "iu"
+            ),
+            cue.segmented_tensor_product.Operand(
+                subscripts="ui" if target == cue.mul_ir else "iu"
+            ),
         ]
     )
     for mul, ir in irreps:
         d.add_path(None, None, c=1, dims={"u": mul, "i": ir.dim})
     return cue.EquivariantTensorProduct(
-        d, [Operand(irreps, source), Operand(irreps, target)]
+        d, [cue.IrrepsAndLayout(irreps, source), cue.IrrepsAndLayout(irreps, target)]
     )
diff --git a/cuequivariance/cuequivariance/equivariant_tensor_product.py b/cuequivariance/cuequivariance/equivariant_tensor_product.py
index 0c2d745..4d2cdd4 100644
--- a/cuequivariance/cuequivariance/equivariant_tensor_product.py
+++ b/cuequivariance/cuequivariance/equivariant_tensor_product.py
@@ -22,12 +22,6 @@
 from cuequivariance import segmented_tensor_product as stp
 
 
-@dataclasses.dataclass(init=True, frozen=True)
-class Operand:
-    irreps: cue.Irreps
-    layout: cue.IrrepsLayout
-
-
 @dataclasses.dataclass(init=False, frozen=True)
 class EquivariantTensorProduct:
     """
@@ -59,29 +53,19 @@ class EquivariantTensorProduct:
     .. rubric:: Methods
     """
 
-    operands: tuple[Operand, ...]
+    operands: tuple[cue.Rep, ...]
     ds: list[stp.SegmentedTensorProduct]
 
     def __init__(
         self,
         d: Union[stp.SegmentedTensorProduct, Sequence[stp.SegmentedTensorProduct]],
-        operands: list[Union[cue.Irreps, Operand]],
-        layout: Optional[cue.IrrepsLayout] = None,
+        operands: list[cue.Rep],
     ):
-        operands = tuple(
-            (
-                ope
-                if isinstance(ope, Operand)
-                else Operand(
-                    irreps=cue.Irreps(ope), layout=layout or cue.get_layout_scope()
-                )
-            )
-            for ope in operands
-        )
+        operands = tuple(operands)
 
         if isinstance(d, stp.SegmentedTensorProduct):
             assert len(operands) == d.num_operands
             for oid in range(d.num_operands):
-                assert operands[oid].irreps.dim == d.operands[oid].size
+                assert operands[oid].dim == d.operands[oid].size
             ds = [d]
         else:
             ds = list(d)
@@ -94,20 +78,20 @@
                 if not (i < d.num_operands - 1):
                     continue
-                if operands[i].irreps.dim != d.operands[i].size:
+                if operands[i].dim != d.operands[i].size:
                     raise ValueError(
-                        f"Input {i} size mismatch: {operands[i].irreps} vs {d.operands[i]}"
+                        f"Input {i} size mismatch: {operands[i]} vs {d.operands[i]}"
                     )
 
             # the repeated input operand is the same
             assert len(operands) >= 2
             for d_ope in d.operands[nin - 1 : -1]:
-                if operands[-2].irreps.dim != d_ope.size:
+                if operands[-2].dim != d_ope.size:
                     raise ValueError(
-                        f"Last input size mismatch: {operands[-2].irreps} vs {d_ope}"
+                        f"Last input size mismatch: {operands[-2]} vs {d_ope}"
                     )
-            if operands[-1].irreps.dim != d.operands[-1].size:
+            if operands[-1].dim != d.operands[-1].size:
                 raise ValueError(
-                    f"Output size mismatch: {operands[-1].irreps} vs {d.operands[-1]}"
+                    f"Output size mismatch: {operands[-1]} vs {d.operands[-1]}"
                 )
 
         # all non-repeated inputs have the same operands
@@ -140,11 +124,11 @@ def num_inputs(self) -> int:
         return self.num_operands - 1
 
     @property
-    def inputs(self) -> tuple[Operand, ...]:
+    def inputs(self) -> tuple[cue.Rep, ...]:
        return self.operands[:-1]
 
     @property
-    def output(self) -> Operand:
+    def output(self) -> cue.Rep:
        return self.operands[-1]
 
    def _degrees(self, i: int) -> set[int]:
@@ -278,6 +262,8 @@ def f(d: stp.SegmentedTensorProduct) -> stp.SegmentedTensorProduct:
             new_subscripts = []
             for oid, (operand, layout) in enumerate(zip(operands, layouts_)):
+                assert isinstance(operand, cue.IrrepsAndLayout)
+
                 subscripts = d.subscripts.operands[oid]
                 if operand.layout == layout:
                     new_subscripts.append(subscripts)
@@ -307,7 +293,7 @@
         return EquivariantTensorProduct(
             [f(d) for d in self.ds],
             [
-                Operand(ope.irreps, layout)
+                cue.IrrepsAndLayout(ope.irreps, layout)
                 for ope, layout in zip(self.operands, layouts)
             ],
         )
@@ -324,7 +310,7 @@ def memory_cost(
         if isinstance(itemsize, int):
             itemsize = (itemsize,) * self.num_operands
         return sum(
-            bs * operand.irreps.dim * iz
+            bs * operand.dim * iz
             for iz, bs, operand in zip(itemsize, batch_sizes, self.operands)
         )
@@ -413,27 +399,23 @@ def stack(
         """Stack multiple equivariant tensor products."""
         assert len(es) > 0
         num_operands = es[0].num_operands
-        layouts = [ope.layout for ope in es[0].operands]
+
         assert all(e.num_operands == num_operands for e in es)
-        assert all(
-            e.operands[oid].layout == layouts[oid]
-            for e in es
-            for oid in range(num_operands)
-        )
         assert len(stacked) == num_operands
 
         new_operands = []
         for oid in range(num_operands):
             if stacked[oid]:
-                new_operands.append(
-                    Operand(
-                        irreps=cue.concatenate([e.operands[oid].irreps for e in es]),
-                        layout=layouts[oid],
+                if not all(
+                    isinstance(e.operands[oid], cue.IrrepsAndLayout) for e in es
+                ):
+                    raise NotImplementedError(
+                        f"Stacking of {type(es[0].operands[oid])} is not implemented"
                     )
-                )
+                new_operands.append(cue.concatenate([e.operands[oid] for e in es]))
             else:
                 ope = es[0].operands[oid]
-                assert all(e.operands[oid].irreps == ope.irreps for e in es)
+                assert all(e.operands[oid] == ope for e in es)
                 new_operands.append(ope)
 
         new_ds: dict[int, stp.SegmentedTensorProduct] = {}
diff --git a/cuequivariance/cuequivariance/descriptors/escn.py b/cuequivariance/cuequivariance/experimental/escn.py
similarity index 92%
rename from cuequivariance/cuequivariance/descriptors/escn.py
rename to cuequivariance/cuequivariance/experimental/escn.py
index ad27b6c..c08b0f6 100644
--- a/cuequivariance/cuequivariance/descriptors/escn.py
+++ b/cuequivariance/cuequivariance/experimental/escn.py
@@ -107,8 +107,11 @@ def pr(mul_ir: cue.MulIrrep) -> bool:
     d = d.flatten_coefficient_modes()
     return cue.EquivariantTensorProduct(
         d,
-        [irreps_in.new_scalars(d.operands[0].size), irreps_in, irreps_out],
-        layout=cue.ir_mul,
+        [
+            cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul),
+            cue.IrrepsAndLayout(irreps_in, cue.ir_mul),
+            cue.IrrepsAndLayout(irreps_out, cue.ir_mul),
+        ],
     )
 
 
@@ -175,4 +178,13 @@ def escn_tp_compact(
         d.add_path(i, l_max_in - m, l_max_out + m, c=-1.0)
 
     d = d.normalize_paths_for_operand(2)
-    return d
+    return d  # TODO: return an EquivariantTensorProduct using SphericalSignal
+
+
+class SphericalSignal(cue.Rep):
+    def __init__(self, mul: int, l_max: int, m_max: int):
+        self.mul = mul
+        self.l_max = l_max
+        self.m_max = m_max
+
+    # TODO
diff --git a/cuequivariance/cuequivariance/descriptors/gatr.py b/cuequivariance/cuequivariance/experimental/gatr.py
similarity index 100%
rename from cuequivariance/cuequivariance/descriptors/gatr.py
rename to cuequivariance/cuequivariance/experimental/gatr.py
diff --git a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py b/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py
index ebfc5c7..78aee0e 100644
--- a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py
+++ b/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py
@@ -18,8 +18,8 @@
 import numpy as np
 
 import cuequivariance as cue
-from cuequivariance import descriptors
 import cuequivariance.segmented_tensor_product as stp
+from cuequivariance import descriptors
 from cuequivariance.misc.linalg import round_to_sqrt_rational, triu_array
 
 
@@ -44,7 +44,7 @@ def symmetric_contraction(
         cuex.equivariant_tensor_product(e, w, cuex.randn(jax.random.key(1), e.inputs[1]))
     """
-    assert max(degrees) > 0
+    assert min(degrees) > 0
     e1 = cue.EquivariantTensorProduct.stack(
         [
             cue.EquivariantTensorProduct.stack(
@@ -147,8 +147,11 @@ def _symmetric_contraction(
     d = d.append_modes_to_all_operands("u", {"u": mul})
     return cue.EquivariantTensorProduct(
         [d],
-        [irreps_in.new_scalars(d.operands[0].size), mul * irreps_in, mul * irreps_out],
-        layout=cue.ir_mul,
+        [
+            cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul),
+            cue.IrrepsAndLayout(mul * irreps_in, cue.ir_mul),
+            cue.IrrepsAndLayout(mul * irreps_out, cue.ir_mul),
+        ],
     )
 
 
@@ -182,6 +185,8 @@ def U_matrix_real(
 def _wigner_nj(
     irreps_in: cue.Irreps, degree: int, filter_ir_mid: Optional[frozenset[cue.Irrep]]
 ) -> list[tuple[cue.Irrep, np.ndarray]]:
+    assert degree > 0
+
     if degree == 1:
         ret = []
         e = np.eye(irreps_in.dim)
diff --git a/cuequivariance/cuequivariance/irreps_array/__init__.py b/cuequivariance/cuequivariance/irreps_array/__init__.py
index 65edb78..8733c1d 100644
--- a/cuequivariance/cuequivariance/irreps_array/__init__.py
+++ b/cuequivariance/cuequivariance/irreps_array/__init__.py
@@ -15,6 +15,7 @@
 from .context_irrep_class import get_irrep_scope
 from .irreps import MulIrrep, Irreps
 from .irreps_layout import IrrepsLayout, mul_ir, ir_mul
+from .irreps_and_layout import IrrepsAndLayout
 from .context_layout import get_layout_scope
 from .context_decorator import assume
 
@@ -33,6 +34,7 @@
     "IrrepsLayout",
     "mul_ir",
     "ir_mul",
+    "IrrepsAndLayout",
     "get_layout_scope",
     "assume",
     "NumpyIrrepsArray",
diff --git a/cuequivariance/cuequivariance/irreps_array/context_decorator.py b/cuequivariance/cuequivariance/irreps_array/context_decorator.py
index 3989905..3ee8e90 100644
--- a/cuequivariance/cuequivariance/irreps_array/context_decorator.py
+++ b/cuequivariance/cuequivariance/irreps_array/context_decorator.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import wraps
-from typing import Optional, Union, Type
+from typing import Optional, Type, Union
 
 import cuequivariance as cue
 import cuequivariance.irreps_array as irreps_array
@@ -28,19 +28,25 @@
 
 
 class assume:
-    """Context manager / decorator to assume the irrep class and layout for a block of code.
+    """
+    ``assume`` is a context manager or decorator to assume the irrep class and layout for a block of code or a function.
 
     Examples:
-        ```
-        with cue.assume(irrep_class="SU2", layout=cue.mul_ir):
-            ...
-        ```
+        As a context manager:
 
+        >>> with cue.assume(cue.SO3, cue.mul_ir):
+        ...     rep = cue.IrrepsAndLayout("2x1")
+        >>> rep.irreps
+        2x1
+        >>> rep.layout
+        (mul,irrep)
 
-        ```
-        @cue.assume(irrep_class="SU2", layout=cue.mul_ir)
-        def my_function():
-            ...
-        ```
+        As a decorator:
+
+        >>> @cue.assume(cue.SO3, cue.mul_ir)
+        ... def foo():
+        ...     return cue.IrrepsAndLayout("2x1")
+        >>> assert foo() == rep
     """
 
     def __init__(
diff --git a/cuequivariance/cuequivariance/irreps_array/irreps_and_layout.py b/cuequivariance/cuequivariance/irreps_array/irreps_and_layout.py
new file mode 100644
index 0000000..b92b778
--- /dev/null
+++ b/cuequivariance/cuequivariance/irreps_array/irreps_and_layout.py
@@ -0,0 +1,163 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+
+import numpy as np
+
+import cuequivariance as cue
+
+
+@dataclass(init=False, frozen=True)
+class IrrepsAndLayout(cue.Rep):
+    r"""
+    A group representation (:class:`Rep`) made from the combination of :class:`Irreps` and :class:`IrrepsLayout` into a single object.
+
+    This class inherits from :class:`Rep`::
+
+        Rep  <--- Base class for all representations
+        ├── Irrep  <--- Base class for all irreducible representations
+        │   ├── SU2
+        │   ├── SO3
+        │   └── O3
+        └── IrrepsAndLayout  <--- This class
+
+        IrrepsLayout  <--- Enum class with two values: mul_ir and ir_mul
+
+        Irreps  <--- Collection of Irrep with multiplicities
+
+    Args:
+        irreps (Irreps or str): Irreducible representations and their multiplicities.
+        layout (optional, IrrepsLayout): The data layout (``mul_ir`` or ``ir_mul``).
+
+    Examples:
+        Let's create rotation matrices for a 2x1 representation of SO(3) using two different layouts:
+
+        >>> angles = np.array([np.pi, 0, 0])
+
+        Here we use the ``ir_mul`` layout:
+
+        >>> with cue.assume("SO3", cue.ir_mul):
+        ...     rep = cue.IrrepsAndLayout("2x1")
+        >>> R_ir_mul = rep.exp_map(angles, np.array([]))
+
+        Here we use the ``mul_ir`` layout:
+
+        >>> with cue.assume("SO3", cue.mul_ir):
+        ...     rep = cue.IrrepsAndLayout("2x1")
+        >>> R_mul_ir = rep.exp_map(angles, np.array([]))
+
+        Let's see the difference between the two layouts:
+
+        >>> R_ir_mul.round(1) + 0.0
+        array([[ 1.,  0.,  0.,  0.,  0.,  0.],
+               [ 0.,  1.,  0.,  0.,  0.,  0.],
+               [ 0.,  0., -1.,  0.,  0.,  0.],
+               [ 0.,  0.,  0., -1.,  0.,  0.],
+               [ 0.,  0.,  0.,  0., -1.,  0.],
+               [ 0.,  0.,  0.,  0.,  0., -1.]])
+
+        >>> R_mul_ir.round(1) + 0.0
+        array([[ 1.,  0.,  0.,  0.,  0.,  0.],
+               [ 0., -1.,  0.,  0.,  0.,  0.],
+               [ 0.,  0., -1.,  0.,  0.,  0.],
+               [ 0.,  0.,  0.,  1.,  0.,  0.],
+               [ 0.,  0.,  0.,  0., -1.,  0.],
+               [ 0.,  0.,  0.,  0.,  0., -1.]])
+    """
+
+    irreps: cue.Irreps = field()
+    layout: cue.IrrepsLayout = field()
+
+    def __init__(
+        self, irreps: cue.Irreps | str, layout: cue.IrrepsLayout | None = None
+    ):
+        irreps = cue.Irreps(irreps)
+        if layout is None:
+            layout = cue.get_layout_scope()
+
+        object.__setattr__(self, "irreps", irreps)
+        object.__setattr__(self, "layout", layout)
+
+    def __repr__(self):
+        return f"{self.irreps}"
+
+    def _dim(self) -> int:
+        return self.irreps.dim
+
+    def algebra(self) -> np.ndarray:
+        return self.irreps.irrep_class.algebra()
+
+    def continuous_generators(self) -> np.ndarray:
+        if self.layout == cue.mul_ir:
+            return block_diag(
+                [np.kron(np.eye(mul), ir.X) for mul, ir in self.irreps], (self.lie_dim,)
+            )
+        if self.layout == cue.ir_mul:
+            return block_diag(
+                [np.kron(ir.X, np.eye(mul)) for mul, ir in self.irreps], (self.lie_dim,)
+            )
+
+    def discrete_generators(self) -> np.ndarray:
+        num_H = self.irreps.irrep_class.trivial().H.shape[0]
+
+        if self.layout == cue.mul_ir:
+            return block_diag(
+                [np.kron(np.eye(mul), ir.H) for mul, ir in self.irreps], (num_H,)
+            )
+        if self.layout == cue.ir_mul:
+            return block_diag(
+                [np.kron(ir.H, np.eye(mul)) for mul, ir in self.irreps], (num_H,)
+            )
+
+    def trivial(self) -> cue.Rep:
+        ir = self.irreps.irrep_class.trivial()
+        return IrrepsAndLayout(
+            cue.Irreps(self.irreps.irrep_class, [ir]),
+            self.layout,
+        )
+
+    def is_scalar(self) -> bool:
+        return self.irreps.is_scalar()
+
+    def __eq__(self, other: cue.Rep) -> bool:
+        if isinstance(other, IrrepsAndLayout):
+            return self.irreps == other.irreps and (
+                self.irreps.layout_insensitive() or self.layout == other.layout
+            )
+        return cue.Rep.__eq__(self, other)
+
+
+def block_diag(entries: list[np.ndarray], leading_shape: tuple[int, ...]) -> np.ndarray:
+    if len(entries) == 0:
+        return np.zeros(leading_shape + (0, 0))
+
+    A = entries[0]
+    assert A.shape[:-2] == leading_shape, (A.shape, leading_shape)
+
+    if len(entries) == 1:
+        return A
+
+    B = entries[1]
+    assert B.shape[:-2] == leading_shape
+
+    i, m = A.shape[-2:]
+    j, n = B.shape[-2:]
+
+    C = np.block(
+        [[A, np.zeros(leading_shape + (i, n))], [np.zeros(leading_shape + (j, m)), B]]
+    )
+    return block_diag([C] + entries[2:], leading_shape)
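A sketch of the equality semantics defined by the `__eq__` above: two `IrrepsAndLayout` objects with different layouts compare equal exactly when the irreps are layout-insensitive, which I assume `Irreps.layout_insensitive()` reports for scalar-only irreps:

```python
import cuequivariance as cue

# Scalars: the data layout cannot matter, so these compare equal.
a = cue.IrrepsAndLayout(cue.Irreps("SO3", "3x0"), cue.mul_ir)
b = cue.IrrepsAndLayout(cue.Irreps("SO3", "3x0"), cue.ir_mul)
assert a == b

# Non-scalar irreps: the layout is significant, so these differ.
c = cue.IrrepsAndLayout(cue.Irreps("SO3", "2x1"), cue.mul_ir)
d = cue.IrrepsAndLayout(cue.Irreps("SO3", "2x1"), cue.ir_mul)
assert c != d
```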
diff --git a/cuequivariance/cuequivariance/irreps_array/numpy_irreps_array.py b/cuequivariance/cuequivariance/irreps_array/numpy_irreps_array.py
index 9341dc1..2c51c62 100644
--- a/cuequivariance/cuequivariance/irreps_array/numpy_irreps_array.py
+++ b/cuequivariance/cuequivariance/irreps_array/numpy_irreps_array.py
@@ -252,11 +252,21 @@ def from_segments(
 
 
 def concatenate(
-    arrays: Union[list[cue.Irreps], list[NumpyIrrepsArray]],
-) -> NumpyIrrepsArray:
+    arrays: Union[
+        list[cue.Irreps],
+        list[Union[cue.IrrepsAndLayout]],
+        list[NumpyIrrepsArray],
+    ],
+) -> Union[cue.Irreps, cue.IrrepsAndLayout, NumpyIrrepsArray]:
     if len(arrays) == 0:
         raise ValueError("Expected at least one input")
 
+    if all(isinstance(array, cue.IrrepsAndLayout) for array in arrays):
+        assert len({x.layout for x in arrays}) == 1
+        return cue.IrrepsAndLayout(
+            concatenate([x.irreps for x in arrays]), arrays[0].layout
+        )
+
     if all(isinstance(array, cue.Irreps) for array in arrays):
         return sum(arrays, cue.Irreps(arrays[0].irrep_class, []))
diff --git a/cuequivariance/cuequivariance/misc/linalg.py b/cuequivariance/cuequivariance/misc/linalg.py
index 65d3ac9..aaa86be 100644
--- a/cuequivariance/cuequivariance/misc/linalg.py
+++ b/cuequivariance/cuequivariance/misc/linalg.py
@@ -398,26 +398,26 @@ def sparsify_matrix(
         for i0, i1 in graph.edges:
             result = sparsify_rows(x[i0], x[i1], round_fn)
 
-            if isinstance(result, DisjointRows):
-                pass
-            elif isinstance(result, ReplaceRow):
-                a0, a1, row, which = result.a0, result.a1, result.row, result.which
-                hope = True
-                which = i0 if which == 0 else i1
-                x[which] = row
-                q[which] = a0 * q[i0] + a1 * q[i1]
-
-                next_graph.add_edge(i0, i1)
-                for i in graph.neighbors(i0):
-                    if i != i1:
-                        next_graph.add_edge(i, i1)
-                for i in graph.neighbors(i1):
-                    if i != i0:
-                        next_graph.add_edge(i, i0)
-            elif isinstance(result, AlreadySparse):
-                next_graph.add_edge(i0, i1)
-            else:
-                raise ValueError(f"Unknown result type: {result}")
+            match result:
+                case DisjointRows():
+                    pass
+                case ReplaceRow(a0=a0, a1=a1, row=row, which=which):
+                    hope = True
+                    which = i0 if which == 0 else i1
+                    x[which] = row
+                    q[which] = a0 * q[i0] + a1 * q[i1]
+
+                    next_graph.add_edge(i0, i1)
+                    for i in graph.neighbors(i0):
+                        if i != i1:
+                            next_graph.add_edge(i, i1)
+                    for i in graph.neighbors(i1):
+                        if i != i0:
+                            next_graph.add_edge(i, i0)
+                case AlreadySparse():
+                    next_graph.add_edge(i0, i1)
+                case _:
+                    raise ValueError(f"Unknown result type: {result}")
 
         iterations += 1
         graph = next_graph
diff --git a/cuequivariance/cuequivariance/representation/rep.py b/cuequivariance/cuequivariance/representation/rep.py
index 7fca347..1b186c7 100644
--- a/cuequivariance/cuequivariance/representation/rep.py
+++ b/cuequivariance/cuequivariance/representation/rep.py
@@ -43,6 +43,9 @@ def dim(self) -> int:
         Returns:
             int: The dimension of the representation.
         """
+        return self._dim()
+
+    def _dim(self) -> int:
         X = self.continuous_generators()
         d = X.shape[1]
         return d
@@ -156,5 +159,13 @@ def is_trivial(self) -> bool:
         """Check if the representation is trivial (scalar of dimension 1)"""
         return self.dim == 1 and self.is_scalar()
 
+    def __eq__(self, other: Rep) -> bool:
+        return (
+            self.dim == other.dim
+            and np.allclose(self.A, other.A)
+            and np.allclose(self.H, other.H)
+            and np.allclose(self.X, other.X)
+        )
+
     def __repr__(self) -> str:
         return f"Rep(dim={self.dim}, lie_dim={self.lie_dim}, len(H)={len(self.H)})"
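The new `concatenate` branch above is what lets `EquivariantTensorProduct.stack` (earlier in this diff) operate on `IrrepsAndLayout` operands: it requires a single shared layout and sums the irreps. A small sketch:

```python
import cuequivariance as cue

a = cue.IrrepsAndLayout(cue.Irreps("SO3", "2x0"), cue.ir_mul)
b = cue.IrrepsAndLayout(cue.Irreps("SO3", "1x1"), cue.ir_mul)

# Both inputs share cue.ir_mul, so this yields an IrrepsAndLayout for 2x0+1
# with that same layout; mixed layouts would trip the assert above.
ab = cue.concatenate([a, b])
```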
diff --git a/cuequivariance/tests/experimental/escn_test.py b/cuequivariance/tests/experimental/escn_test.py
new file mode 100644
index 0000000..2b770f4
--- /dev/null
+++ b/cuequivariance/tests/experimental/escn_test.py
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import cuequivariance as cue
+from cuequivariance.experimental.escn import escn_tp, escn_tp_compact
+
+
+def test_escn():
+    escn_tp(
+        cue.Irreps("O3", "32x0e + 32x1o"),
+        cue.Irreps("O3", "32x0e + 32x1o"),
+        m_max=2,
+    )
+
+
+def test_escn_compact():
+    escn_tp_compact(
+        cue.Irreps("SO3", "32x0 + 32x1"),
+        cue.Irreps("SO3", "32x0 + 32x1"),
+        m_max=2,
+    )
diff --git a/cuequivariance/tests/experimental/gatr_test.py b/cuequivariance/tests/experimental/gatr_test.py
new file mode 100644
index 0000000..cef872d
--- /dev/null
+++ b/cuequivariance/tests/experimental/gatr_test.py
@@ -0,0 +1,31 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from cuequivariance.experimental.gatr import (
+    gatr_geometric_product,
+    gatr_linear,
+    gatr_outer_product,
+)
+
+
+def test_geometric_product():
+    gatr_geometric_product()
+
+
+def test_outer_product():
+    gatr_outer_product()
+
+
+def test_linear():
+    gatr_linear(32, 32)
diff --git a/cuequivariance/tests/experimental/mace_test.py b/cuequivariance/tests/experimental/mace_test.py
new file mode 100644
index 0000000..fec5d86
--- /dev/null
+++ b/cuequivariance/tests/experimental/mace_test.py
@@ -0,0 +1,24 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import cuequivariance as cue
+from cuequivariance.experimental.mace import symmetric_contraction
+
+
+def test_symmetric_contraction():
+    e, p = symmetric_contraction(
+        cue.Irreps("O3", "32x0e + 32x1o"),
+        cue.Irreps("O3", "32x0e + 32x1o"),
+        degrees=[1, 2, 3],
+    )
diff --git a/cuequivariance_jax/cuequivariance_jax/__init__.py b/cuequivariance_jax/cuequivariance_jax/__init__.py
index 4c9aa03..0aa886e 100644
--- a/cuequivariance_jax/cuequivariance_jax/__init__.py
+++ b/cuequivariance_jax/cuequivariance_jax/__init__.py
@@ -19,12 +19,9 @@
 )
 
 
-from .irreps_array.jax_irreps_array import (
-    IrrepsArray,
-    from_segments,
-    vmap,
-)
-from .irreps_array.utils import concatenate, randn, as_irreps_array
+from .rep_array.jax_rep_array import RepArray, from_segments, IrrepsArray
+from .rep_array.vmap import vmap
+from .rep_array.utils import concatenate, randn, as_irreps_array, clebsch_gordan
 
 from .primitives.tensor_product import tensor_product
 from .primitives.symmetric_tensor_product import symmetric_tensor_product
@@ -41,12 +38,14 @@
     from cuequivariance_jax import flax_linen
 
 __all__ = [
-    "IrrepsArray",
+    "RepArray",
     "from_segments",
-    "as_irreps_array",
+    "IrrepsArray",
     "vmap",
     "concatenate",
     "randn",
+    "as_irreps_array",
+    "clebsch_gordan",
     "tensor_product",
     "symmetric_tensor_product",
     "equivariant_tensor_product",
diff --git a/cuequivariance_jax/cuequivariance_jax/experimental/utils.py b/cuequivariance_jax/cuequivariance_jax/experimental/utils.py
index 92b2c69..f26cd5e 100644
--- a/cuequivariance_jax/cuequivariance_jax/experimental/utils.py
+++ b/cuequivariance_jax/cuequivariance_jax/experimental/utils.py
@@ -132,8 +132,9 @@ def smooth_bump(x: jax.Array) -> jax.Array:
 
 
 def gather(
-    i: jax.Array, x: cuex.IrrepsArray, n: int, indices_are_sorted: bool = False
-) -> cuex.IrrepsArray:
+    i: jax.Array, x: cuex.RepArray, n: int, indices_are_sorted: bool = False
+) -> cuex.RepArray:
+    assert 0 not in x.reps
     y = jnp.zeros((n,) + x.shape[1:], dtype=x.dtype)
     y = y.at[i].add(x.array, indices_are_sorted=indices_are_sorted)
-    return cuex.IrrepsArray(x.irreps(), y, x.layout)
+    return cuex.RepArray(x.reps, y)
diff --git a/cuequivariance_jax/cuequivariance_jax/flax_linen/layer_norm.py b/cuequivariance_jax/cuequivariance_jax/flax_linen/layer_norm.py
index 54bea42..8c168b8 100644
--- a/cuequivariance_jax/cuequivariance_jax/flax_linen/layer_norm.py
+++ b/cuequivariance_jax/cuequivariance_jax/flax_linen/layer_norm.py
@@ -36,8 +36,8 @@ class LayerNorm(nn.Module):
     epsilon: float = 0.01
 
     @nn.compact
-    def __call__(self, input: cuex.IrrepsArray) -> cuex.IrrepsArray:
-        assert input.is_simple()
+    def __call__(self, input: cuex.RepArray) -> cuex.RepArray:
+        assert input.is_irreps_array()
 
         def rms(v: jax.Array) -> jax.Array:
             # v [..., ir, mul] or [..., mul, ir]
@@ -54,8 +54,8 @@ def rms(v: jax.Array) -> jax.Array:
             return rmsn
 
         return cuex.from_segments(
-            input.irreps(),
-            [x / (rms(x) + self.epsilon) for x in input.segments()],
+            input.irreps,
+            [x / (rms(x) + self.epsilon) for x in input.segments],
             input.shape,
             input.layout,
             input.dtype,
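Despite its name, the experimental `gather` above is an index-add: rows of `x` are summed into `n` output rows, and the leading axis must not carry a representation (hence `assert 0 not in x.reps`). A usage sketch, assuming `cuex.RepArray` accepts an irreps string for the last axis the way the deleted `IrrepsArray` did:

```python
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance_jax as cuex
from cuequivariance_jax.experimental.utils import gather

with cue.assume(cue.SO3, cue.ir_mul):
    x = cuex.RepArray("2x0", jnp.ones((4, 2)))  # 4 rows of two scalars

i = jnp.array([0, 0, 1, 2])  # rows 0 and 1 of x are both added into output row 0
y = gather(i, x, 3)          # RepArray with shape (3, 2)
```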
raise ValueError(f"input must be of type IrrepsArray, got {type(input)}") - - assert input.is_simple() + self, input: cuex.RepArray, algorithm: str = "sliced" + ) -> cuex.RepArray: + assert input.is_irreps_array() irreps_out = cue.Irreps(self.irreps_out) layout_out = cue.IrrepsLayout.as_layout(self.layout) - assert_same_group(input.irreps(), irreps_out) + assert_same_group(input.irreps, irreps_out) if not self.force: - irreps_out = irreps_out.filter(keep=input.irreps()) + irreps_out = irreps_out.filter(keep=input.irreps) - e = descriptors.linear(input.irreps(), irreps_out) + e = descriptors.linear(input.irreps, irreps_out) e = e.change_layout([cue.ir_mul, input.layout, layout_out]) # Flattening mode i does slow down the computation a bit if algorithm != "sliced": e = e.flatten_modes("i") - w = self.param("w", self.kernel_init, (e.operands[0].irreps.dim,), input.dtype) + w = self.param("w", self.kernel_init, (e.operands[0].dim,), input.dtype) return cuex.equivariant_tensor_product( e, w, input, precision=jax.lax.Precision.HIGH, algorithm=algorithm diff --git a/cuequivariance_jax/cuequivariance_jax/irreps_array/jax_irreps_array.py b/cuequivariance_jax/cuequivariance_jax/irreps_array/jax_irreps_array.py deleted file mode 100644 index 50e7ac2..0000000 --- a/cuequivariance_jax/cuequivariance_jax/irreps_array/jax_irreps_array.py +++ /dev/null @@ -1,736 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Any, Callable, Sequence - -import jax -import jax.numpy as jnp -import numpy as np - -import cuequivariance as cue -import cuequivariance_jax as cuex # noqa: F401 - - -def _check_args( - dirreps: Any, layout: Any, ndim: int | None -) -> tuple[dict[int, cue.Irreps], cue.IrrepsLayout]: - if isinstance(dirreps, (cue.Irreps, cue.Irrep, str)): - dirreps = {-1: cue.Irreps(dirreps)} - - if not isinstance(dirreps, dict): - raise ValueError( - f"IrrepsArray: dirreps must be a dict of int -> Irreps, not {dirreps}" - ) - - dirreps = {k: cue.Irreps(v) for k, v in dirreps.items()} - - if not all( - isinstance(k, int) and isinstance(v, cue.Irreps) for k, v in dirreps.items() - ): - raise ValueError( - f"IrrepsArray: dirreps must be a dict of int -> Irreps, not {dirreps}" - ) - - layout = cue.IrrepsLayout.as_layout(layout) - - if ndim is not None: - dirreps = {k + ndim if k < 0 else k: v for k, v in dirreps.items()} - - if any(k < 0 for k in dirreps.keys()): - raise ValueError( - f"IrrepsArray: dirreps keys must be non-negative, not {dirreps}" - ) - - return dirreps, layout - - -@dataclass(frozen=True, init=False, repr=False) -class IrrepsArray: - """ - Wrapper around a jax array with a dict of Irreps for the non-trivial axes. - - .. rubric:: Creation - - >>> cuex.IrrepsArray( - ... {-1: cue.Irreps("SO3", "2x0")}, jnp.array([1.0, 2.0]), cue.ir_mul - ... ) - {0: 2x0} [1. 2.] 
diff --git a/cuequivariance_jax/cuequivariance_jax/irreps_array/jax_irreps_array.py b/cuequivariance_jax/cuequivariance_jax/irreps_array/jax_irreps_array.py
deleted file mode 100644
index 50e7ac2..0000000
--- a/cuequivariance_jax/cuequivariance_jax/irreps_array/jax_irreps_array.py
+++ /dev/null
@@ -1,736 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from __future__ import annotations
-
-from dataclasses import dataclass, field
-from typing import Any, Callable, Sequence
-
-import jax
-import jax.numpy as jnp
-import numpy as np
-
-import cuequivariance as cue
-import cuequivariance_jax as cuex  # noqa: F401
-
-
-def _check_args(
-    dirreps: Any, layout: Any, ndim: int | None
-) -> tuple[dict[int, cue.Irreps], cue.IrrepsLayout]:
-    if isinstance(dirreps, (cue.Irreps, cue.Irrep, str)):
-        dirreps = {-1: cue.Irreps(dirreps)}
-
-    if not isinstance(dirreps, dict):
-        raise ValueError(
-            f"IrrepsArray: dirreps must be a dict of int -> Irreps, not {dirreps}"
-        )
-
-    dirreps = {k: cue.Irreps(v) for k, v in dirreps.items()}
-
-    if not all(
-        isinstance(k, int) and isinstance(v, cue.Irreps) for k, v in dirreps.items()
-    ):
-        raise ValueError(
-            f"IrrepsArray: dirreps must be a dict of int -> Irreps, not {dirreps}"
-        )
-
-    layout = cue.IrrepsLayout.as_layout(layout)
-
-    if ndim is not None:
-        dirreps = {k + ndim if k < 0 else k: v for k, v in dirreps.items()}
-
-    if any(k < 0 for k in dirreps.keys()):
-        raise ValueError(
-            f"IrrepsArray: dirreps keys must be non-negative, not {dirreps}"
-        )
-
-    return dirreps, layout
-
-
-@dataclass(frozen=True, init=False, repr=False)
-class IrrepsArray:
-    """
-    Wrapper around a jax array with a dict of Irreps for the non-trivial axes.
-
-    .. rubric:: Creation
-
-    >>> cuex.IrrepsArray(
-    ...     {-1: cue.Irreps("SO3", "2x0")}, jnp.array([1.0, 2.0]), cue.ir_mul
-    ... )
-    {0: 2x0} [1. 2.]
-
-    If you don't specify the axis it will default to the last axis:
-
-    >>> cuex.IrrepsArray(
-    ...     cue.Irreps("SO3", "2x0"), jnp.array([1.0, 2.0]), cue.ir_mul
-    ... )
-    {0: 2x0} [1. 2.]
-
-    You can use a default group and layout:
-
-    >>> with cue.assume(cue.SO3, cue.ir_mul):
-    ...     cuex.IrrepsArray("2x0", jnp.array([1.0, 2.0]))
-    {0: 2x0} [1. 2.]
-
-    .. rubric:: Arithmetic
-
-    Basic arithmetic operations are supported, as long as they are equivariant:
-
-    >>> with cue.assume(cue.SO3, cue.ir_mul):
-    ...     x = cuex.IrrepsArray("2x0", jnp.array([1.0, 2.0]))
-    ...     y = cuex.IrrepsArray("2x0", jnp.array([3.0, 4.0]))
-    ...     x + y
-    {0: 2x0} [4. 6.]
-
-    >>> 3.0 * x
-    {0: 2x0} [3. 6.]
-
-    .. rubric:: Attributes
-
-    Attributes:
-        dirreps: Irreps for the non-trivial axes, see also :func:`irreps() ` below.
-        array: JAX array
-        layout: Data layout
-        shape: Shape of the array
-        ndim: Number of dimensions of the array
-        dtype: Data type of the array
-
-    .. rubric:: Methods
-    """
-
-    layout: cue.IrrepsLayout = field()
-    dirreps: dict[int, cue.Irreps] = field()
-    array: jax.Array = field()
-
-    def __init__(
-        self,
-        irreps: cue.Irreps | str | dict[int, cue.Irreps | str],
-        array: jax.Array,
-        layout: cue.IrrepsLayout | None = None,
-    ):
-        dirreps, layout = _check_args(irreps, layout, getattr(array, "ndim", None))
-
-        if (
-            hasattr(array, "shape")
-            and isinstance(array.shape, tuple)
-            and len(array.shape) > 0
-        ):
-            for axis, irreps_ in dirreps.items():
-                if len(array.shape) <= axis or array.shape[axis] != irreps_.dim:
-                    raise ValueError(
-                        f"IrrepsArray: Array shape {array.shape} incompatible with irreps {irreps_}.\n"
-                        "If you are trying to use jax.vmap, use cuex.vmap instead."
-                    )
-
-        object.__setattr__(self, "dirreps", dirreps)
-        object.__setattr__(self, "array", array)
-        object.__setattr__(self, "layout", layout)
-
-    @property
-    def shape(self) -> tuple[int, ...]:
-        return self.array.shape
-
-    @property
-    def ndim(self) -> int:
-        return self.array.ndim
-
-    @property
-    def dtype(self) -> jax.numpy.dtype:
-        return self.array.dtype
-
-    def is_simple(self) -> bool:
-        """Return True if the IrrepsArray has only the last axis non-trivial.
-
-        Examples:
-
-        >>> cuex.IrrepsArray(
-        ...     cue.Irreps("SO3", "2x0"), jnp.array([1.0, 2.0]), cue.ir_mul
-        ... ).is_simple()
-        True
-        """
-        if len(self.dirreps) != 1:
-            return False
-        axis = next(iter(self.dirreps.keys()))
-        return axis == self.ndim - 1
-
-    def irreps(self, axis: int = -1) -> cue.Irreps:
-        """Return the Irreps for a given axis.
-
-        Examples:
-
-        >>> cuex.IrrepsArray(
-        ...     cue.Irreps("SO3", "2x0"), jnp.array([1.0, 2.0]), cue.ir_mul
-        ... ).irreps()
-        2x0
-        """
-        axis = axis if axis >= 0 else axis + self.ndim
-        if axis not in self.dirreps:
-            raise ValueError(f"No Irreps for axis {axis}")
-        return self.dirreps[axis]
-
-    def __repr__(self):
-        r = str(self.array)
-        if "\n" in r:
-            return f"{self.dirreps}\n{r}"
-        return f"{self.dirreps} {r}"
-
-    def __getitem__(self, key: Any) -> IrrepsArray:
-        # self[None]
-        if key is None:
-            return IrrepsArray(
-                {k + 1: irreps for k, irreps in self.dirreps.items()},
-                self.array[None],
-                self.layout,
-            )
-
-        # self[jnp.array([0, 1, 2])]
-        assert isinstance(key, jax.Array)
-        assert 0 not in self.dirreps
-        return IrrepsArray(
-            {k + key.ndim - 1: irreps for k, irreps in self.dirreps.items()},
-            self.array[key],
-            self.layout,
-        )
-
-    def slice_by_mul(self, axis: int = -1) -> _MulIndexSliceHelper:
-        r"""Return the slice with respect to the multiplicities.
-
-        Examples:
-
-        >>> x = cuex.IrrepsArray(
-        ...     cue.Irreps("SO3", "2x0 + 1"),
-        ...     jnp.array([1.0, 2.0, 0.0, 0.0, 0.0]), cue.ir_mul
-        ... )
-        >>> x.slice_by_mul()[1:4]
-        {0: 0+1} [2. 0. 0. 0.]
-        """
-        return _MulIndexSliceHelper(self, axis)
-
-    def __neg__(self) -> IrrepsArray:
-        return IrrepsArray(self.dirreps, -self.array, self.layout)
-
-    def __add__(self, other: IrrepsArray | int | float) -> IrrepsArray:
-        if isinstance(other, (int, float)):
-            assert other == 0
-            return self
-
-        if self.dirreps != other.dirreps:
-            raise ValueError(
-                f"Cannot add IrrepsArrays with different dirreps: {self.dirreps} != {other.dirreps}"
-            )
-        if self.layout != other.layout:
-            raise ValueError(
-                f"Cannot add IrrepsArrays with different layouts: {self.layout} != {other.layout}"
-            )
-        return IrrepsArray(self.dirreps, self.array + other.array, self.layout)
-
-    def __radd__(self, other: IrrepsArray) -> IrrepsArray:
-        return self + other
-
-    def __sub__(self, other: IrrepsArray | int | float) -> IrrepsArray:
-        return self + (-other)
-
-    def __rsub__(self, other: IrrepsArray | int | float) -> IrrepsArray:
-        return -self + other
-
-    def __mul__(self, other: jax.Array) -> IrrepsArray:
-        other = jnp.asarray(other)
-        other = jnp.expand_dims(other, tuple(range(self.ndim - other.ndim)))
-        for axis, irreps in self.dirreps.items():
-            assert other.shape[axis] == 1
-        return IrrepsArray(self.dirreps, self.array * other, self.layout)
-
-    def __truediv__(self, other: jax.Array) -> IrrepsArray:
-        other = jnp.asarray(other)
-        other = jnp.expand_dims(other, tuple(range(self.ndim - other.ndim)))
-        for axis, irreps in self.dirreps.items():
-            assert other.shape[axis] == 1
-        return IrrepsArray(self.dirreps, self.array / other, self.layout)
-
-    def __rmul__(self, other: jax.Array) -> IrrepsArray:
-        return self * other
-
-    def filter(
-        self,
-        *,
-        keep: str | Sequence[cue.Irrep] | Callable[[cue.MulIrrep], bool] | None = None,
-        drop: str | Sequence[cue.Irrep] | Callable[[cue.MulIrrep], bool] | None = None,
-        mask: Sequence[bool] | None = None,
-        axis: int = -1,
-    ) -> IrrepsArray:
-        """Filter the irreps.
-
-        Args:
-            keep: Irreps to keep.
-            drop: Irreps to drop.
-            mask: Boolean mask for segments to keep.
-            axis: Axis to filter.
-
-        Examples:
-
-        >>> x = cuex.IrrepsArray(
-        ...     cue.Irreps("SO3", "2x0 + 1"),
-        ...     jnp.array([1.0, 2.0, 0.0, 0.0, 0.0]), cue.ir_mul
-        ... )
-        >>> x.filter(keep="0")
-        {0: 2x0} [1. 2.]
-        >>> x.filter(drop="0")
-        {0: 1} [0. 0. 0.]
-        >>> x.filter(mask=[True, False])
-        {0: 2x0} [1. 2.]
-        """
-        if mask is None:
-            mask = self.irreps(axis).filter_mask(keep=keep, drop=drop)
-
-        if all(mask):
-            return self
-
-        if not any(mask):
-            shape = list(self.shape)
-            shape[axis] = 0
-            return IrrepsArray(
-                self.dirreps | {axis: cue.Irreps(self.irreps(axis).irrep_class, "")},
-                jnp.zeros(shape, dtype=self.dtype),
-                self.layout,
-            )
-
-        return IrrepsArray(
-            self.dirreps | {axis: self.irreps(axis).filter(mask=mask)},
-            jnp.concatenate(
-                [
-                    take_slice(self.array, s, axis)
-                    for s, m in zip(self.irreps(axis).slices(), mask)
-                    if m
-                ],
-                axis=axis,
-            ),
-            self.layout,
-        )
-
-    def sort(self, axis: int = -1) -> IrrepsArray:
-        """Sort the irreps.
-
-        Examples:
-
-        >>> x = cuex.IrrepsArray(
-        ...     cue.Irreps("SO3", "1 + 2x0"),
-        ...     jnp.array([1.0, 1.0, 1.0, 2.0, 3.0]), cue.ir_mul
-        ... )
-        >>> x.sort()
-        {0: 2x0+1} [2. 3. 1. 1. 1.]
-        """
-        if axis < 0:
-            axis += self.ndim
-
-        irreps = self.irreps(axis)
-        r = irreps.sort()
-
-        segments = self.segments(axis)
-        return from_segments(
-            self.dirreps | {axis: r.irreps},
-            [segments[i] for i in r.inv],
-            self.shape,
-            self.layout,
-            self.dtype,
-            axis,
-        )
-
-    def simplify(self, axis: int = -1) -> IrrepsArray:
-        if axis < 0:
-            axis += self.ndim
-
-        dirreps = self.dirreps | {axis: self.irreps(axis).simplify()}
-
-        if self.layout == cue.mul_ir:
-            return IrrepsArray(dirreps, self.array, self.layout)
-
-        assert self.is_simple()
-        segments = []
-        last_ir = None
-        for x, (mul, ir) in zip(self.segments(), self.irreps()):
-            if last_ir is None or last_ir != ir:
-                segments.append(x)
-                last_ir = ir
-            else:
-                segments[-1] = jnp.concatenate([segments[-1], x], axis=-1)
-
-        return from_segments(
-            self.irreps().simplify(),
-            segments,
-            self.shape,
-            cue.ir_mul,
-            self.dtype,
-        )
-
-    def regroup(self, axis: int = -1) -> IrrepsArray:
-        """Clean up the irreps.
-
-        Examples:
-
-        >>> x = cuex.IrrepsArray(
-        ...     cue.Irreps("SO3", "0 + 1 + 0"), jnp.array([0., 1., 2., 3., -1.]),
-        ...     cue.ir_mul
-        ... )
-        >>> x.regroup()
-        {0: 2x0+1} [ 0. -1.  1.  2.  3.]
-        """
-        return self.sort(axis).simplify(axis)
-
-    def segments(self, axis: int = -1) -> list[jax.Array]:
-        """Split the array into segments.
-
-        Examples:
-
-        >>> x = cuex.IrrepsArray(
-        ...     cue.Irreps("SO3", "2x0 + 1"), jnp.array([1.0, 2.0, 0.0, 0.0, 0.0]),
-        ...     cue.ir_mul
-        ... )
-        >>> x.segments()
-        [Array(...), Array(...)]
-
-        Note:
-
-            See also :func:`cuex.from_segments `.
-        """
-        irreps = self.irreps(axis)
-        return [
-            take_slice(self.array, s, axis).reshape(
-                expanded_shape(self.shape, mul_ir, axis, self.layout)
-            )
-            for s, mul_ir in zip(irreps.slices(), irreps)
-        ]
-
-    def change_layout(self, layout: cue.IrrepsLayout | None = None) -> IrrepsArray:
-        assert self.is_simple()
-
-        if layout is None:
-            layout = cue.get_layout_scope()
-
-        if self.layout == layout:
-            return self
-
-        return from_segments(
-            self.dirreps,
-            [jnp.moveaxis(x, -2, -1) for x in self.segments()],
-            self.shape,
-            layout,
-            self.dtype,
-        )
-
-    def move_axis_to_mul(self, axis: int) -> IrrepsArray:
-        assert self.is_simple()
-        assert self.layout == cue.ir_mul
-        if axis < 0:
-            axis += self.ndim
-
-        mul = self.shape[axis]
-        array = jnp.moveaxis(self.array, axis, -1)
-        array = jnp.reshape(array, array.shape[:-2] + (mul * self.irreps().dim,))
-        return IrrepsArray(mul * self.irreps(), array, cue.ir_mul)
-
-    def transform(self, v: jax.Array) -> IrrepsArray:
-        assert self.is_simple()
-
-        def f(segment: jax.Array, mul: int, ir: cue.Irrep) -> jax.Array:
-            X = ir.X
-            assert np.allclose(X, -X.conj().T)  # TODO: support other types of X
-
-            X = jnp.asarray(X, dtype=v.dtype)
-            iX = 1j * jnp.einsum("a,aij->ij", v, X)
-            m, V = jnp.linalg.eigh(iX)
-            # np.testing.assert_allclose(V @ np.diag(m) @ V.T.conj(), iX, atol=1e-10)
-
-            phase = jnp.exp(-1j * m)
-            R = V @ jnp.diag(phase) @ V.T.conj()
-            R = R.real
-
-            match self.layout:
-                case cue.mul_ir:
-                    return jnp.einsum("ij,...uj->...ui", R, segment)
-                case cue.ir_mul:
-                    return jnp.einsum("ij,...ju->...iu", R, segment)
-
-        return from_segments(
-            self.dirreps,
-            [f(x, mul, ir) for x, (mul, ir) in zip(self.segments(), self.irreps())],
-            self.shape,
-            self.layout,
-            self.dtype,
-        )
-
-
-def expanded_shape(
-    shape: tuple[int, ...], mul_ir: cue.MulIrrep, axis: int, layout: cue.IrrepsLayout
-) -> tuple[int, ...]:
-    if axis < 0:
-        axis += len(shape)
-    return shape[:axis] + layout.shape(mul_ir) + shape[axis + 1 :]
-
-
-def from_segments(
-    dirreps: cue.Irreps | str | dict[int, cue.Irreps | str],
-    segments: Sequence[jax.Array],
-    shape: tuple[int, ...],
-    layout: cue.IrrepsLayout | None = None,
-    dtype: jnp.dtype | None = None,
-    axis: int = -1,
-) -> IrrepsArray:
-    """Construct an :class:`cuex.IrrepsArrays ` from a list of segments.
-
-    Args:
-        dirreps: final Irreps.
-        segments: list of segments.
-        shape: shape of the final array.
-        layout: layout of the final array.
-        dtype: data type
-        axis: axis to concatenate the segments.
-
-    Returns:
-        IrrepsArray: IrrepsArray.
-
-    Examples:
-
-    >>> cuex.from_segments(
-    ...     cue.Irreps("SO3", "2x0 + 1"),
-    ...     [jnp.array([[1.0], [2.0]]), jnp.array([[0.0], [0.0], [0.0]])],
-    ...     (-1,), cue.ir_mul)
-    {0: 2x0+1} [1. 2. 0. 0. 0.]
-
-    Note:
-
-        See also :func:`cuex.IrrepsArray.segments `.
-    """
-    ndim = len(shape)
-    dirreps, layout = _check_args(dirreps, layout, ndim)
-    if axis < 0:
-        axis += ndim
-
-    shape = list(shape)
-    for iaxis, irreps in dirreps.items():
-        shape[iaxis] = irreps.dim
-
-    if not all(x.ndim == len(shape) + 1 for x in segments):
-        raise ValueError(
-            "from_segments: segments must have ndim equal to len(shape) + 1"
-        )
-
-    if len(segments) != len(dirreps[axis]):
-        raise ValueError(
-            f"from_segments: the number of segments {len(segments)} must match the number of irreps {len(dirreps[axis])}"
-        )
-
-    if dtype is not None:
-        segments = [segment.astype(dtype) for segment in segments]
-    segments = [
-        segment.reshape(
-            segment.shape[:axis] + (mul * ir.dim,) + segment.shape[axis + 2 :]
-        )
-        for (mul, ir), segment in zip(dirreps[axis], segments)
-    ]
-
-    if len(segments) > 0:
-        array = jnp.concatenate(segments, axis=axis)
-    else:
-        array = jnp.zeros(shape, dtype=dtype)
-
-    return IrrepsArray(dirreps, array, layout)
-
-
-def take_slice(x: jax.Array, s: slice, axis: int) -> jax.Array:
-    slices = [slice(None)] * x.ndim
-    slices[axis] = s
-    return x[tuple(slices)]
-
-
-def encode_irreps_array(x: IrrepsArray) -> tuple:
-    data = (x.array,)
-    static = (x.layout, x.dirreps)
-    return data, static
-
-
-def decode_irreps_array(static, data) -> IrrepsArray:
-    layout, dirreps = static
-    (array,) = data
-    return IrrepsArray(dirreps, array, layout)
-
-
-jax.tree_util.register_pytree_node(
-    IrrepsArray, encode_irreps_array, decode_irreps_array
-)
-
-
-def remove_axis(dirreps: dict[int, cue.Irreps], axis: int):
-    assert axis >= 0
-    if axis in dirreps:
-        raise ValueError(
-            f"Cannot vmap over an Irreps axis. {axis} has Irreps {dirreps[axis]}."
-        )
-    return {
-        a - 1 if a > axis else a: irreps for a, irreps in dirreps.items() if a != axis
-    }
-
-
-def add_axis(dirreps: dict[int, cue.Irreps], axis: int):
-    return {a + 1 if a >= axis else a: irreps for a, irreps in dirreps.items()}
-
-
-def vmap(
-    fun: Callable[..., Any],
-    in_axes: int | tuple[int, ...] = 0,
-    out_axes: int = 0,
-) -> Callable[..., Any]:
-    """
-    Like jax.vmap, but for IrrepsArray.
-
-    Args:
-        fun: Callable[..., Any]: Function to vectorize. Can take `IrrepsArray` as input and output.
-        in_axes: int | tuple[int, ...]: Axes to vectorize over.
-        out_axes: int: Axes to vectorize over.
-
-    Returns:
-        Callable[..., Any]: Vectorized function.
-    """
-
-    def inside_fun(*args, **kwargs):
-        args, kwargs = jax.tree.map(
-            lambda x: (
-                IrrepsArray(x.dirreps, x.array, x.layout)
-                if isinstance(x, _wrapper)
-                else x
-            ),
-            (args, kwargs),
-            is_leaf=lambda x: isinstance(x, _wrapper),
-        )
-        out = fun(*args, **kwargs)
-        return jax.tree.map(
-            lambda x: (
-                _wrapper(x.layout, add_axis(x.dirreps, out_axes), x.array)
-                if isinstance(x, IrrepsArray)
-                else x
-            ),
-            out,
-            is_leaf=lambda x: isinstance(x, IrrepsArray),
-        )
-
-    def outside_fun(*args, **kwargs):
-        if isinstance(in_axes, int):
-            in_axes_ = (in_axes,) * len(args)
-        else:
-            in_axes_ = in_axes
-
-        args = [
-            jax.tree.map(
-                lambda x: (
-                    _wrapper(
-                        x.layout,
-                        remove_axis(x.dirreps, axis if axis >= 0 else axis + x.ndim),
-                        x.array,
-                    )
-                    if isinstance(x, IrrepsArray)
-                    else x
-                ),
-                arg,
-                is_leaf=lambda x: isinstance(x, IrrepsArray),
-            )
-            for axis, arg in zip(in_axes_, args)
-        ]
-        kwargs = jax.tree.map(
-            lambda x: (
-                _wrapper(x.layout, remove_axis(x.dirreps, 0), x.array)
-                if isinstance(x, IrrepsArray)
-                else x
-            ),
-            kwargs,
-            is_leaf=lambda x: isinstance(x, IrrepsArray),
-        )
-        out = jax.vmap(inside_fun, in_axes, out_axes)(*args, **kwargs)
-        return jax.tree.map(
-            lambda x: (
-                IrrepsArray(x.dirreps, x.array, x.layout)
-                if isinstance(x, _wrapper)
-                else x
-            ),
-            out,
-            is_leaf=lambda x: isinstance(x, _wrapper),
-        )
-
-    return outside_fun
-
-
-@dataclass(frozen=True)
-class _wrapper:
-    layout: cue.IrrepsLayout = field()
-    dirreps: dict[int, cue.Irreps] = field()
-    array: jax.Array = field()
-
-
-jax.tree_util.register_pytree_node(
-    _wrapper,
-    lambda x: ((x.array,), (x.layout, x.dirreps)),
-    lambda static, data: _wrapper(static[0], static[1], data[0]),
-)
-
-
-class _MulIndexSliceHelper:
-    irreps_array: IrrepsArray
-    axis: int
-
-    def __init__(self, irreps_array: IrrepsArray, axis: int):
-        self.irreps_array = irreps_array
-        self.axis = axis if axis >= 0 else axis + irreps_array.ndim
-
-    def __getitem__(self, index: slice) -> IrrepsArray:
-        if not isinstance(index, slice):
-            raise IndexError(
-                "IrrepsArray.slice_by_mul only supports one slices (like IrrepsArray.slice_by_mul[2:4])."
-            )
-
-        input_irreps = self.irreps_array.irreps(self.axis)
-        start, stop, stride = index.indices(input_irreps.num_irreps)
-        if stride != 1:
-            raise NotImplementedError(
-                "IrrepsArray.slice_by_mul does not support strides."
-            )
-
-        mul_axis = {
-            cue.mul_ir: self.axis,
-            cue.ir_mul: self.axis + 1,
-        }[self.irreps_array.layout]
-
-        output_irreps = []
-        segments = []
-        i = 0
-        for (mul, ir), x in zip(input_irreps, self.irreps_array.segments(self.axis)):
-            if start <= i and i + mul <= stop:
-                output_irreps.append((mul, ir))
-                segments.append(x)
-            elif start < i + mul and i < stop:
-                output_irreps.append((min(stop, i + mul) - max(start, i), ir))
-                segments.append(
-                    take_slice(
-                        x, slice(max(start, i) - i, min(stop, i + mul) - i), mul_axis
-                    )
-                )
-
-            i += mul
-
-        return from_segments(
-            self.irreps_array.dirreps
-            | {self.axis: cue.Irreps(input_irreps.irrep_class, output_irreps)},
-            segments,
-            self.irreps_array.shape,
-            self.irreps_array.layout,
-            self.irreps_array.dtype,
-            self.axis,
-        )
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Sequence - -import jax -import jax.numpy as jnp - -import cuequivariance as cue -import cuequivariance_jax as cuex -from cuequivariance.equivariant_tensor_product import Operand -from cuequivariance.irreps_array.misc_ui import assert_same_group - - -def concatenate(arrays: list[cuex.IrrepsArray], axis: int = -1) -> cuex.IrrepsArray: - """Concatenate a list of :class:`cuex.IrrepsArray ` - - Args: - arrays (list of IrrepsArray): List of arrays to concatenate. - axis (int, optional): Axis along which to concatenate. Defaults to -1. - - Example: - - >>> with cue.assume(cue.SO3, cue.ir_mul): - ... x = cuex.IrrepsArray("3x0", jnp.array([1.0, 2.0, 3.0])) - ... y = cuex.IrrepsArray("1x1", jnp.array([0.0, 0.0, 0.0])) - >>> cuex.concatenate([x, y]) - {0: 3x0+1} [1. 2. 3. 0. 0. 0.] - """ - if len(arrays) == 0: - raise ValueError( - "Must provide at least one array to concatenate" - ) # pragma: no cover - if not all(a.layout == arrays[0].layout for a in arrays): - raise ValueError("All arrays must have the same layout") # pragma: no cover - if not all(a.ndim == arrays[0].ndim for a in arrays): - raise ValueError( - "All arrays must have the same number of dimensions" - ) # pragma: no cover - assert_same_group(*[a.irreps(axis) for a in arrays]) - - if axis < 0: - axis += arrays[0].ndim - - irreps = sum( - (a.irreps(axis) for a in arrays), cue.Irreps(arrays[0].irreps(axis), []) - ) - list_dirreps = [a.dirreps | {axis: irreps} for a in arrays] - if not all(d == list_dirreps[0] for d in list_dirreps): - raise ValueError("All arrays must have the same dirreps") # pragma: no cover - - return cuex.IrrepsArray( - list_dirreps[0], - jnp.concatenate([a.array for a in arrays], axis=axis), - arrays[0].layout, - ) - - -def randn( - key: jax.Array, - irreps: cue.Irreps | Operand, - leading_shape: tuple[int, ...] = (), - layout: cue.IrrepsLayout | None = None, - dtype: jnp.dtype | None = None, -) -> cuex.IrrepsArray: - r"""Generate a random :class:`cuex.IrrepsArrays `. - - Args: - key (jax.Array): Random key. - irreps (Irreps): Irreps of the array. - leading_shape (tuple[int, ...], optional): Leading shape of the array. Defaults to (). - layout (IrrepsLayout): Layout of the array. - dtype (jnp.dtype): Data type of the array. - - Returns: - IrrepsArray: Random IrrepsArray. - - Example: - - >>> key = jax.random.key(0) - >>> irreps = cue.Irreps("O3", "2x1o") - >>> cuex.randn(key, irreps, (), cue.ir_mul) - {0: 2x1o} [...] 
- """ - if isinstance(irreps, Operand): - assert layout is None - irreps, layout = irreps.irreps, irreps.layout - - irreps = cue.Irreps(irreps) - leading_shape = tuple(leading_shape) - - return cuex.IrrepsArray( - irreps, - jax.random.normal(key, leading_shape + (irreps.dim,), dtype=dtype), - layout, - ) - - -def as_irreps_array( - input: Any, - layout: cue.IrrepsLayout | None = None, - axis: int | Sequence[int] = -1, - like: cuex.IrrepsArray | None = None, -) -> cuex.IrrepsArray: - """Converts input to an IrrepsArray. Arrays are assumed to be scalars. - - Examples: - - >>> with cue.assume(cue.O3): - ... cuex.as_irreps_array([1.0], layout=cue.ir_mul) - {0: 0e} [1.] - """ - # We need first to define axes and layout - if like is not None: - assert layout is None - assert axis == -1 - layout = like.layout - axes = { - axis - like.ndim: irreps.irrep_class.trivial() - for axis, irreps in like.dirreps.items() - } - else: - if isinstance(input, cuex.IrrepsArray): - axes = { - axis: input.irreps(axis).irrep_class.trivial() - for axis in (axis if isinstance(axis, Sequence) else [axis]) - } - else: - ir = cue.get_irrep_scope().trivial() - axes = { - axis: ir for axis in (axis if isinstance(axis, Sequence) else [axis]) - } - if layout is None: - if isinstance(input, cuex.IrrepsArray): - layout = input.layout - else: - layout = cue.get_layout_scope() - del like, axis - - if isinstance(input, cuex.IrrepsArray): - if input.layout != layout: - raise ValueError( - f"as_irreps_array: layout mismatch {input.layout} != {layout}" - ) - for axis, ir in axes.items(): - if input.irreps(axis).irrep_class is not type(ir): - raise ValueError( - f"as_irreps_array: irrep mismatch {input.irreps(axis).irrep_class} != {type(ir)}" - ) - return input - - input: jax.Array = jnp.asarray(input) - # if max(axes.keys()) >= input.ndim: - # raise ValueError( - # f"as_irreps_array: input has {input.ndim} dimensions, but axes are {axes.keys()}" - # ) - - dirreps = { - axis: cue.Irreps(type(ir), [(input.shape[axis], ir)]) - for axis, ir in axes.items() - } - return cuex.IrrepsArray(dirreps, input, layout) diff --git a/cuequivariance_jax/cuequivariance_jax/operations/activation.py b/cuequivariance_jax/cuequivariance_jax/operations/activation.py index 29eb304..a171ce9 100644 --- a/cuequivariance_jax/cuequivariance_jax/operations/activation.py +++ b/cuequivariance_jax/cuequivariance_jax/operations/activation.py @@ -87,29 +87,27 @@ def function_parity(phi: ActFn) -> int: def scalar_activation( - input: cuex.IrrepsArray, + input: cuex.RepArray, acts: ActFn | list[ActFn | None] | dict[cue.Irrep, ActFn], *, normalize_act: bool = True, -) -> cuex.IrrepsArray: - r"""Apply activation functions to the scalars of an `IrrepsArray`. +) -> cuex.RepArray: + r"""Apply activation functions to the scalars of an `RepArray`. The activation functions are by default normalized. 
""" input = cuex.as_irreps_array(input) - assert isinstance(input, cuex.IrrepsArray) - assert input.is_simple() if isinstance(acts, dict): - acts = [acts.get(ir, None) for mul, ir in input.irreps()] + acts = [acts.get(ir, None) for mul, ir in input.irreps] if callable(acts): - acts = [acts] * len(input.irreps()) + acts = [acts] * len(input.irreps) - assert len(input.irreps()) == len(acts), (input.irreps(), acts) + assert len(input.irreps) == len(acts), (input.irreps, acts) segments = [] irreps_out = [] - for (mul, ir), x, act in zip(input.irreps(), input.segments(), acts): + for (mul, ir), x, act in zip(input.irreps, input.segments, acts): mul: int ir: cue.Irrep x: jax.Array @@ -144,7 +142,7 @@ def scalar_activation( irreps_out.append((mul, ir)) segments.append(x) - irreps_out = cue.Irreps(input.irreps(), irreps_out) + irreps_out = cue.Irreps(input.irreps, irreps_out) return cuex.from_segments( irreps_out, segments, input.shape, input.layout, input.dtype ) diff --git a/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py b/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py index f9fff23..1bb0510 100644 --- a/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py +++ b/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py @@ -23,19 +23,28 @@ def spherical_harmonics( ls: list[int], - vector: cuex.IrrepsArray, + vector: cuex.RepArray, normalize: bool = True, algorithm: str = "stacked", -) -> cuex.IrrepsArray: +) -> cuex.RepArray: + """Compute the spherical harmonics of a vector. + + Args: + ls (list of int): List of spherical harmonic degrees. + vector (RepArray): Input vector(s). + normalize (bool): Whether to normalize the vector before computing the spherical harmonics. + algorithm (str): Algorithm to use for the tensor product. See :class:`cuex.tensor_product ` for more information. + + Returns: + RepArray: Spherical harmonics of the vector. 
+ """ ls = list(ls) - assert isinstance(vector, cuex.IrrepsArray) - assert vector.is_simple() - irreps = vector.irreps() + assert vector.is_irreps_array() + irreps = vector.irreps assert len(irreps) == 1 mul, ir = irreps[0] assert mul == 1 assert ir.dim == 3 - assert max(ls) > 0 assert min(ls) >= 0 if normalize: @@ -48,8 +57,8 @@ def spherical_harmonics( ) -def normalize(array: cuex.IrrepsArray) -> cuex.IrrepsArray: - assert array.is_simple() +def normalize(array: cuex.RepArray, epsilon: float = 0.0) -> cuex.RepArray: + assert array.is_irreps_array() match array.layout: case cue.ir_mul: @@ -59,13 +68,14 @@ def normalize(array: cuex.IrrepsArray) -> cuex.IrrepsArray: def f(x: jax.Array) -> jax.Array: sn = jnp.sum(jnp.conj(x) * x, axis=axis_ir, keepdims=True) - sn_safe = jnp.where(sn == 0.0, 1.0, sn) - rsn_safe = jnp.sqrt(sn_safe) - return x / rsn_safe + sn += epsilon + if epsilon == 0.0: + sn = jnp.where(sn == 0.0, 1.0, sn) + return x / jnp.sqrt(sn) return cuex.from_segments( - array.irreps(), - [f(x) for x in array.segments()], + array.irreps, + [f(x) for x in array.segments], array.shape, array.layout, array.dtype, @@ -75,9 +85,9 @@ def f(x: jax.Array) -> jax.Array: _normalize = normalize -def norm(array: cuex.IrrepsArray, *, squared: bool = False) -> cuex.IrrepsArray: - """Norm of IrrepsArray.""" - assert array.is_simple() +def norm(array: cuex.RepArray, *, squared: bool = False) -> cuex.RepArray: + """Norm of `RepArray`.""" + assert array.is_irreps_array() match array.layout: case cue.ir_mul: @@ -97,8 +107,8 @@ def f(x: jax.Array) -> jax.Array: return rsn return cuex.from_segments( - cue.Irreps(array.irreps(), [(mul, ir.trivial()) for mul, ir in array.irreps()]), - [f(x) for x in array.segments()], + cue.Irreps(array.irreps, [(mul, ir.trivial()) for mul, ir in array.irreps]), + [f(x) for x in array.segments], array.shape, array.layout, array.dtype, diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py index 0554e58..b93ffdc 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py @@ -22,36 +22,52 @@ def equivariant_tensor_product( e: cue.EquivariantTensorProduct, - *inputs: cuex.IrrepsArray | jax.Array, + *inputs: cuex.RepArray | jax.Array, dtype_output: jnp.dtype | None = None, dtype_math: jnp.dtype | None = None, precision: jax.lax.Precision = jax.lax.Precision.HIGHEST, algorithm: str = "sliced", use_custom_primitive: bool = True, use_custom_kernels: bool = False, -): +) -> cuex.RepArray: """Compute the equivariant tensor product of the input arrays. Args: e (EquivariantTensorProduct): The equivariant tensor product descriptor. - *inputs (IrrepsArray or jax.Array): The input arrays. + *inputs (RepArray or jax.Array): The input arrays. dtype_output (jnp.dtype, optional): The data type for the output array. Defaults to None. dtype_math (jnp.dtype, optional): The data type for computational operations. Defaults to None. - precision (jax.lax.Precision, optional): The precision for the computation. Defaults to jax.lax.Precision.HIGHEST. - algorithm (str, optional): One of "sliced", "stacked", "compact_stacked", "indexed_compact", "indexed_vmap", "indexed_for_loop". Defaults to "sliced". + precision (jax.lax.Precision, optional): The precision for the computation. Defaults to ``jax.lax.Precision.HIGHEST``. 
+ algorithm (str, optional): One of "sliced", "stacked", "compact_stacked", "indexed_compact", "indexed_vmap", "indexed_for_loop". Defaults to "sliced". See :class:`cuex.tensor_product ` for more information. use_custom_primitive (bool, optional): Whether to use custom JVP rules. Defaults to True. use_custom_kernels (bool, optional): Whether to use custom kernels. Defaults to True. Returns: - IrrepsArray: The result of the equivariant tensor product. + RepArray: The result of the equivariant tensor product. Examples: + + Let's create a descriptor for the spherical harmonics of degree 0, 1, and 2. + >>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]) - >>> x = cuex.IrrepsArray(cue.Irreps("SO3", "1"), jnp.array([0.0, 1.0, 0.0]), cue.ir_mul) + >>> e + EquivariantTensorProduct((1)^(0..2) -> 0+1+2) + + We need some input data. + + >>> with cue.assume(cue.SO3, cue.ir_mul): + ... x = cuex.RepArray("1", jnp.array([0.0, 1.0, 0.0])) + >>> x + {0: 1} [0. 1. 0.] + + Now we can execute the equivariant tensor product. + >>> cuex.equivariant_tensor_product(e, x) {0: 0+1+2} [1. ... ] """ + assert e.num_inputs > 0 + if len(inputs) == 0: return lambda *inputs: equivariant_tensor_product( e, @@ -64,25 +80,24 @@ def equivariant_tensor_product( use_custom_kernels=use_custom_kernels, ) - if len(inputs) != len(e.inputs): + if len(inputs) != e.num_inputs: raise ValueError( - f"Unexpected number of inputs. Expected {len(e.inputs)}, got {len(inputs)}." + f"Unexpected number of inputs. Expected {e.num_inputs}, got {len(inputs)}." ) - for x, ope in zip(inputs, e.inputs): - if isinstance(x, cuex.IrrepsArray): - assert x.is_simple() - assert x.irreps() == ope.irreps - assert x.layout == ope.layout + for x, rep in zip(inputs, e.inputs): + if isinstance(x, cuex.RepArray): + assert x.rep(-1) == rep else: assert x.ndim >= 1 - assert x.shape[-1] == ope.irreps.dim - if not ope.irreps.is_scalar(): + assert x.shape[-1] == rep.dim + if not rep.is_scalar(): raise ValueError( - f"Inputs should be IrrepsArray unless the input is scalar. Got {type(x)} for {ope.irreps}." + f"Inputs should be RepArray unless the input is scalar. Got {type(x)} for {rep}." ) - inputs = [x.array if isinstance(x, cuex.IrrepsArray) else x for x in inputs] + inputs: list[jax.Array] = [getattr(x, "array", x) for x in inputs] + x = cuex.symmetric_tensor_product( e.ds, *inputs, @@ -94,4 +109,4 @@ def equivariant_tensor_product( use_custom_kernels=use_custom_kernels, ) - return cuex.IrrepsArray(e.output.irreps, x, e.output.layout) + return cuex.RepArray(e.output, x) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py index f3ab07e..bc52ea6 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py @@ -46,39 +46,40 @@ def tensor_product( use_custom_kernels: bool = False, ) -> jax.Array: """ - Compute the last operand of a segmented tensor product. + Compute the last operand of a `SegmentedTensorProduct`. Args: - d (SegmentedTensorProduct): The segmented tensor product descriptor. - *inputs (jax.Array): The input arrays for the tensor product. - dtype_output (jnp.dtype, optional): The data type for the output. Defaults to None. + d (SegmentedTensorProduct): The descriptor of the operation. + *inputs (jax.Array): The input arrays for each operand except the last one. + dtype_output (jnp.dtype, optional): The data type for the output. 
dtype_math (jnp.dtype, optional): The data type for mathematical operations. - precision (jax.lax.Precision, optional): The precision for the computation. Defaults to jax.lax.Precision.HIGHEST. + precision (jax.lax.Precision, optional): The precision for the computation. Defaults to ``jax.lax.Precision.HIGHEST``. algorithm (str, optional): The algorithm to use for the computation. Defaults to "sliced". See table below for available algorithms. - use_custom_primitive (bool, optional): Whether to use custom JVP/transpose rules. Defaults to True. - use_custom_kernels (bool, optional): Whether to use custom kernels. Defaults to True. + use_custom_primitive (bool, optional): Whether to use custom JVP/transpose rules. + use_custom_kernels (bool, optional): Whether to use custom kernels. Returns: - jax.Array: The result of the tensor product computation. - - See Also: - :class:`cuequivariance.SegmentedTensorProduct` - - +---------------------+--------------------------+-----------------+----------------------------+ - | Algorithms | Needs Identical Segments | Compilation | Execution | - +=====================+==========================+=================+============================+ - |``sliced`` | No | Several minutes | It depends | - +---------------------+--------------------------+-----------------+----------------------------+ - |``stacked`` | Yes | Several minutes | It depends | - +---------------------+--------------------------+-----------------+----------------------------+ - |``compact_stacked`` | Yes | Few seconds | It depends | - +---------------------+--------------------------+-----------------+----------------------------+ - |``indexed_compact`` | Yes | Few seconds | It depends | - +---------------------+--------------------------+-----------------+----------------------------+ - |``indexed_vmap`` | Yes | Few seconds | Probably the second slowest| - +---------------------+--------------------------+-----------------+----------------------------+ - |``indexed_for_loop`` | Yes | Few seconds | Probably the slowest | - +---------------------+--------------------------+-----------------+----------------------------+ + jax.Array: The result of the tensor product. The last operand of the `SegmentedTensorProduct`. + + .. 
table:: Available algorithms for the tensor product + :align: center + :class: longtable + + +---------------------+--------------------------+------------------+----------------------------+ + | Algorithms | Needs Identical Segments | Compilation Time | Execution Time | + +=====================+==========================+==================+============================+ + |``sliced`` | No | Several minutes | It depends | + +---------------------+--------------------------+------------------+----------------------------+ + |``stacked`` | Yes | Several minutes | It depends | + +---------------------+--------------------------+------------------+----------------------------+ + |``compact_stacked`` | Yes | Few seconds | It depends | + +---------------------+--------------------------+------------------+----------------------------+ + |``indexed_compact`` | Yes | Few seconds | It depends | + +---------------------+--------------------------+------------------+----------------------------+ + |``indexed_vmap`` | Yes | Few seconds | Probably the second slowest| + +---------------------+--------------------------+------------------+----------------------------+ + |``indexed_for_loop`` | Yes | Few seconds | Probably the slowest | + +---------------------+--------------------------+------------------+----------------------------+ """ if isinstance(precision, str): precision = jax.lax.Precision[precision] diff --git a/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py b/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py new file mode 100644 index 0000000..11a4b4f --- /dev/null +++ b/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py @@ -0,0 +1,692 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Sequence + +import jax +import jax.numpy as jnp +import numpy as np + +import cuequivariance as cue +import cuequivariance_jax as cuex # noqa: F401 + + +@dataclass(frozen=True, init=False, repr=False) +class RepArray: + """ + A `jax.Array` decorated with a dict of `Rep` for the axes transforming under a group representation. + + Example: + + You can create a `RepArray` by specifying the `Rep` for each axis: + + >>> cuex.RepArray({0: cue.SO3(1), 1: cue.SO3(1)}, jnp.eye(3)) + {0: 1, 1: 1} + [[1. 0. 0.] + [0. 1. 0.] + [0. 0. 1.]] + + By default, arguments that are not `Rep` will be automatically converted into `IrrepsAndLayout`: + + >>> with cue.assume(cue.SO3, cue.ir_mul): + ... x = cuex.RepArray({0: "1", 1: "2"}, jnp.ones((3, 5))) + >>> x + {0: 1, 1: 2} + [[1. 1. 1. 1. 1.] + [1. 1. 1. 1. 1.] + [1. 1. 1. 1. 1.]] + >>> x.rep(0).irreps, x.rep(0).layout + (1, (irrep,mul)) + + .. rubric:: IrrepsArray + + An ``IrrepsArray`` is just a special case of a ``RepArray`` where the last axis is a `IrrepsAndLayout`: + + >>> x = cuex.RepArray( + ... cue.Irreps("SO3", "2x0"), jnp.zeros((3, 2)), cue.ir_mul + ... ) + >>> x + {1: 2x0} + [[0. 0.] + [0. 0.] + [0. 0.]] + + >>> x.is_irreps_array() + True + + You can use a default group and layout: + + >>> with cue.assume(cue.SO3, cue.ir_mul): + ... cuex.RepArray("2x0", jnp.array([1.0, 2.0])) + {0: 2x0} [1. 2.] + + .. rubric:: Arithmetic + + Basic arithmetic operations are supported, as long as they are equivariant: + + >>> with cue.assume(cue.SO3, cue.ir_mul): + ... x = cuex.RepArray("2x0", jnp.array([1.0, 2.0])) + ... y = cuex.RepArray("2x0", jnp.array([3.0, 4.0])) + ... x + y + {0: 2x0} [4. 6.] + + >>> 3.0 * x + {0: 2x0} [3. 6.] 
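The equivariance guard rails on this arithmetic are worth spelling out; a hedged sketch of what the operators described above accept and reject (per the `__add__` checks later in this file):

```python
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance_jax as cuex

with cue.assume(cue.SO3, cue.ir_mul):
    x = cuex.RepArray("2x0 + 1", jnp.array([1.0, 2.0, 0.0, 1.0, 0.0]))
    y = cuex.RepArray("2x0 + 1", jnp.array([1.0, 0.0, 1.0, 0.0, 0.0]))
    z = cuex.RepArray("1 + 2x0", jnp.zeros(5))

    _ = x + y      # fine: identical reps on the decorated axis
    _ = 2.0 * x    # fine: scalar multiplication commutes with the group action
    try:
        _ = x + z  # rejected: "2x0 + 1" and "1 + 2x0" are different reps
    except ValueError:
        pass
```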
+ """ + + reps: dict[int, cue.Rep] = field() + array: jax.Array = field() + + def __init__( + self, + reps: dict[int, cue.Rep] + | cue.Rep + | cue.Irreps + | str + | dict[int, cue.Irreps] + | dict[int, str], + array: jax.Array, + layout: cue.IrrepsLayout | None = None, + ): + if not isinstance(reps, dict): + reps = {-1: reps} + + # Remaining cases: dict[int, cue.Rep] | dict[int, cue.Irreps] | dict[int, str] + + reps = { + axis: cue.Irreps(rep) if isinstance(rep, str) else rep + for axis, rep in reps.items() + } + + # Remaining cases: dict[int, cue.Rep] | dict[int, cue.Irreps] + + reps = { + axis: cue.IrrepsAndLayout(rep, layout) + if isinstance(rep, cue.Irreps) + else rep + for axis, rep in reps.items() + } + + del layout + assert isinstance(reps, dict) + assert all(isinstance(k, int) for k in reps) + assert all(isinstance(v, cue.Rep) for v in reps.values()) + + ndim = getattr(array, "ndim", None) + if ndim is not None: + reps = {k + ndim if k < 0 else k: v for k, v in reps.items()} + + assert all( + isinstance(k, int) and isinstance(v, cue.Rep) for k, v in reps.items() + ) + assert all(k >= 0 for k in reps) + + if ( + hasattr(array, "shape") + and isinstance(array.shape, tuple) + and len(array.shape) > 0 + ): + for axis, rep_ in reps.items(): + if len(array.shape) <= axis or array.shape[axis] != rep_.dim: + raise ValueError( + f"RepArray: Array shape {array.shape} incompatible with irreps {rep_}.\n" + "If you are trying to use jax.vmap, use cuex.vmap instead." + ) + + object.__setattr__(self, "reps", reps) + object.__setattr__(self, "array", array) + + @property + def shape(self) -> tuple[int, ...]: + """Shape of the array.""" + return self.array.shape + + @property + def ndim(self) -> int: + """Number of dimensions of the array.""" + return self.array.ndim + + @property + def dtype(self) -> jax.numpy.dtype: + """Data type of the array.""" + return self.array.dtype + + def is_irreps_array(self) -> bool: + """Check if the RepArray is an ``IrrepsArray``. + + An ``IrrepsArray`` is a `RepArray` where the last axis is an `IrrepsAndLayout`. + """ + if len(self.reps) != 1: + return False + axis = next(iter(self.reps.keys())) + if axis != self.ndim - 1: + return False + rep = self.rep(-1) + return isinstance(rep, cue.IrrepsAndLayout) + + def rep(self, axis: int) -> cue.Rep: + """Return the Rep for a given axis.""" + axis = axis if axis >= 0 else axis + self.ndim + if axis not in self.reps: + raise ValueError(f"No Rep for axis {axis}") + return self.reps[axis] + + @property + def irreps(self) -> cue.Irreps: + """Return the `Irreps` of the ``IrrepsArray``. + + Note: + + This method is only available for ``IrrepsArray``. + See :func:`is_irreps_array `. + """ + assert self.is_irreps_array() + return self.rep(-1).irreps + + @property + def layout(self) -> cue.IrrepsLayout: + """Return the layout of the ``IrrepsArray``. + + Note: + + This method is only available for ``IrrepsArray``. + See :func:`is_irreps_array `. 
+ """ + assert self.is_irreps_array() + return self.rep(-1).layout + + def __repr__(self): + r = str(self.array) + if "\n" in r: + return f"{self.reps}\n{r}" + return f"{self.reps} {r}" + + def __getitem__(self, key: Any) -> RepArray: + # self[None] + if key is None: + return RepArray( + {k + 1: rep for k, rep in self.reps.items()}, + self.array[None], + ) + + # self[jnp.array([0, 1, 2])] + assert isinstance(key, jax.Array) + assert 0 not in self.reps + return RepArray( + {k + key.ndim - 1: irreps for k, irreps in self.reps.items()}, + self.array[key], + ) + + @property + def slice_by_mul(self) -> _MulIndexSliceHelper: + r"""Return the slice with respect to the multiplicities. + + Examples: + + >>> x = cuex.RepArray( + ... cue.Irreps("SO3", "2x0 + 1"), + ... jnp.array([1.0, 2.0, 0.0, 0.0, 0.0]), + ... cue.ir_mul + ... ) + >>> x.slice_by_mul[1:4] + {0: 0+1} [2. 0. 0. 0.] + + Note: + + This method is only available for ``IrrepsArray``. + See :func:`is_irreps_array `. + """ + assert self.is_irreps_array() + return _MulIndexSliceHelper(self) + + def __neg__(self) -> RepArray: + return RepArray(self.reps, -self.array) + + def __add__(self, other: RepArray | int | float) -> RepArray: + if isinstance(other, (int, float)): + assert other == 0 + return self + + if not isinstance(other, RepArray): + raise ValueError( + f"Try to add a RepArray with something that is not a RepArray: {other}" + ) + + if self.reps != other.reps: + raise ValueError( + f"Cannot add RepArray with different reps: {self.reps} != {other.reps}" + ) + + return RepArray(self.reps, self.array + other.array) + + def __radd__(self, other: RepArray) -> RepArray: + return self + other + + def __sub__(self, other: RepArray | int | float) -> RepArray: + return self + (-other) + + def __rsub__(self, other: RepArray | int | float) -> RepArray: + return -self + other + + def __mul__(self, other: jax.Array) -> RepArray: + other = jnp.asarray(other) + other = jnp.expand_dims(other, tuple(range(self.ndim - other.ndim))) + for axis, _ in self.reps.items(): + assert other.shape[axis] == 1 + return RepArray(self.reps, self.array * other) + + def __truediv__(self, other: jax.Array) -> RepArray: + other = jnp.asarray(other) + other = jnp.expand_dims(other, tuple(range(self.ndim - other.ndim))) + for axis, _ in self.reps.items(): + assert other.shape[axis] == 1 + return RepArray(self.reps, self.array / other) + + def __rmul__(self, other: jax.Array) -> RepArray: + return self * other + + def transform(self, v: jax.Array) -> RepArray: + """Transform the array according to the representation. + + Args: + v: Vector of angles. + + Examples: + + >>> x = cuex.RepArray( + ... {0: cue.SO3(1), 1: cue.SO3(1)}, jnp.ones((3, 3)) + ... ) + >>> x + {0: 1, 1: 1} + [[1. 1. 1.] + [1. 1. 1.] + [1. 1. 1.]] + >>> x.transform(jnp.array([np.pi, 0.0, 0.0])).array.round(1) + Array([[ 1., -1., -1.], + [-1., 1., 1.], + [-1., 1., 1.]]...) 
+ """ + + def matrix(rep: cue.Rep) -> jax.Array: + X = rep.X + assert np.allclose( + X, -X.conj().transpose((0, 2, 1)) + ) # TODO: support other types of X + + X = jnp.asarray(X, dtype=v.dtype) + iX = 1j * jnp.einsum("a,aij->ij", v, X) + m, V = jnp.linalg.eigh(iX) + # np.testing.assert_allclose(V @ np.diag(m) @ V.T.conj(), iX, atol=1e-10) + + phase = jnp.exp(-1j * m) + R = V @ jnp.diag(phase) @ V.T.conj() + R = jnp.real(R) + return R + + if self.is_irreps_array(): + + def f(segment: jax.Array, ir: cue.Irrep) -> jax.Array: + R = matrix(ir) + match self.layout: + case cue.mul_ir: + return jnp.einsum("ij,...uj->...ui", R, segment) + case cue.ir_mul: + return jnp.einsum("ij,...ju->...iu", R, segment) + + return from_segments( + self.irreps, + [f(x, ir) for x, (_, ir) in zip(self.segments, self.irreps)], + self.shape, + self.layout, + self.dtype, + ) + + a = self.array + for axis, rep in self.reps.items(): + a = jnp.moveaxis(a, axis, 0) + R = matrix(rep) + a = jnp.einsum("ij,j...->i...", R, a) + a = jnp.moveaxis(a, 0, axis) + + return RepArray(self.reps, a) + + @property + def segments(self) -> list[jax.Array]: + """Split the array into segments. + + Examples: + + >>> x = cuex.RepArray( + ... cue.Irreps("SO3", "2x0 + 1"), jnp.array([1.0, 2.0, 0.0, 0.0, 0.0]), + ... cue.ir_mul + ... ) + >>> x.segments + [Array(...), Array(...)] + + Note: + + This method is only available for ``IrrepsArray``. + See :func:`is_irreps_array `. + """ + assert self.is_irreps_array() + return [ + jnp.reshape(self.array[..., s], self.shape[:-1] + self.layout.shape(mulir)) + for s, mulir in zip(self.irreps.slices(), self.irreps) + ] + + def filter( + self, + *, + keep: str | Sequence[cue.Irrep] | Callable[[cue.MulIrrep], bool] | None = None, + drop: str | Sequence[cue.Irrep] | Callable[[cue.MulIrrep], bool] | None = None, + mask: Sequence[bool] | None = None, + ) -> RepArray: + """Filter the irreps. + + Args: + keep: Irreps to keep. + drop: Irreps to drop. + mask: Boolean mask for segments to keep. + axis: Axis to filter. + + Examples: + + >>> x = cuex.RepArray( + ... cue.Irreps("SO3", "2x0 + 1"), + ... jnp.array([1.0, 2.0, 0.0, 0.0, 0.0]), cue.ir_mul + ... ) + >>> x.filter(keep="0") + {0: 2x0} [1. 2.] + >>> x.filter(drop="0") + {0: 1} [0. 0. 0.] + >>> x.filter(mask=[True, False]) + {0: 2x0} [1. 2.] + + Note: + + This method is only available for ``IrrepsArray``. + See :func:`is_irreps_array `. + """ + assert self.is_irreps_array() + + if mask is None: + mask = self.irreps.filter_mask(keep=keep, drop=drop) + + if all(mask): + return self + + if not any(mask): + shape = list(self.shape) + shape[-1] = 0 + return RepArray( + cue.Irreps(self.irreps.irrep_class, ""), + jnp.zeros(shape, dtype=self.dtype), + self.layout, + ) + + return RepArray( + self.irreps.filter(mask=mask), + jnp.concatenate( + [self.array[..., s] for s, m in zip(self.irreps.slices(), mask) if m], + axis=-1, + ), + self.layout, + ) + + def sort(self) -> RepArray: + """Sort the irreps. + + Examples: + + >>> x = cuex.RepArray( + ... cue.Irreps("SO3", "1 + 2x0"), + ... jnp.array([1.0, 1.0, 1.0, 2.0, 3.0]), cue.ir_mul + ... ) + >>> x.sort() + {0: 2x0+1} [2. 3. 1. 1. 1.] + + Note: + + This method is only available for ``IrrepsArray``. + See :func:`is_irreps_array `. 
+ """ + assert self.is_irreps_array() + + irreps = self.irreps + r = irreps.sort() + + segments = self.segments + return from_segments( + r.irreps, + [segments[i] for i in r.inv], + self.shape, + self.layout, + self.dtype, + ) + + def simplify(self) -> RepArray: + assert self.is_irreps_array() + + simplified_irreps = self.irreps.simplify() + + if self.layout == cue.mul_ir: + return RepArray(simplified_irreps, self.array, self.layout) + + segments = [] + last_ir = None + for x, (_mul, ir) in zip(self.segments, self.irreps): + if last_ir is None or last_ir != ir: + segments.append(x) + last_ir = ir + else: + segments[-1] = jnp.concatenate([segments[-1], x], axis=-1) + + return from_segments( + simplified_irreps, + segments, + self.shape, + cue.ir_mul, + self.dtype, + ) + + def regroup(self) -> RepArray: + """Clean up the irreps. + + Examples: + + >>> x = cuex.RepArray( + ... cue.Irreps("SO3", "0 + 1 + 0"), jnp.array([0., 1., 2., 3., -1.]), + ... cue.ir_mul + ... ) + >>> x.regroup() + {0: 2x0+1} [ 0. -1. 1. 2. 3.] + + Note: + + This method is only available for ``IrrepsArray``. + See :func:`is_irreps_array `. + """ + return self.sort().simplify() + + def change_layout(self, layout: cue.IrrepsLayout) -> RepArray: + """Change the layout of the ``IrrepsArray``. + + Note: + + This method is only available for ``IrrepsArray``. + See :func:`is_irreps_array `. + """ + assert self.is_irreps_array() + if self.layout == layout: + return self + + return from_segments( + self.irreps, + [jnp.moveaxis(x, -2, -1) for x in self.segments], + self.shape, + layout, + self.dtype, + ) + + def move_axis_to_mul(self, axis: int) -> RepArray: + """Move an axis to the multiplicities. + + Note: + + This method is only available for ``IrrepsArray``. + See :func:`is_irreps_array `. + """ + assert self.is_irreps_array() + + if axis < 0: + axis += self.ndim + assert axis < self.ndim - 1 + + mul = self.shape[axis] + + match self.layout: + case cue.ir_mul: + array = jnp.moveaxis(self.array, axis, -1) + array = jnp.reshape(array, array.shape[:-2] + (self.irreps.dim * mul,)) + return RepArray(mul * self.irreps, array, cue.ir_mul) + case cue.mul_ir: + + def f(x): + x = jnp.moveaxis(x, axis, -3) + return jnp.reshape( + x, x.shape[:-3] + (mul * x.shape[-2], x.shape[-1]) + ) + + shape = list(self.shape) + del shape[axis] + shape[-1] = mul * shape[-1] + + return from_segments( + mul * self.irreps, + [f(x) for x in self.segments], + shape, + self.layout, + self.dtype, + ) + + +def encode_rep_array(x: RepArray) -> tuple: + data = (x.array,) + static = (x.reps,) + return data, static + + +def decode_rep_array(static, data) -> RepArray: + (reps,) = static + (array,) = data + return RepArray(reps, array) + + +jax.tree_util.register_pytree_node(RepArray, encode_rep_array, decode_rep_array) + +IrrepsArray = RepArray # TODO: do we deprecate IrrepsArray? + + +def from_segments( + irreps: cue.Irreps | str, + segments: Sequence[jax.Array], + shape: tuple[int, ...], + layout: cue.IrrepsLayout | None = None, + dtype: jnp.dtype | None = None, +) -> RepArray: + """Construct a `RepArray` from segments. + + Args: + irreps (Irreps): irreps. + segments (list of jax.Array): segments. + shape (tuple of int): shape of the final array. + layout (IrrepsLayout): data layout. + dtype: data type + + Returns: + RepArray: the RepArray. + + Examples: + + >>> cuex.from_segments( + ... cue.Irreps("SO3", "2x0 + 1"), + ... [jnp.array([[1.0], [2.0]]), jnp.array([[0.0], [0.0], [0.0]])], + ... (-1,), cue.ir_mul) + {0: 2x0+1} [1. 2. 0. 0. 0.] 
+ """ + irreps = cue.Irreps(irreps) + shape = list(shape) + shape[-1] = irreps.dim + + if not all(x.ndim == len(shape) + 1 for x in segments): + raise ValueError( + "from_segments: segments must have ndim equal to len(shape) + 1" + ) + + if len(segments) != len(irreps): + raise ValueError( + f"from_segments: the number of segments {len(segments)} must match the number of irreps {len(irreps)}" + ) + + if dtype is not None: + segments = [segment.astype(dtype) for segment in segments] + + segments = [ + segment.reshape(segment.shape[:-2] + (mul * ir.dim,)) + for (mul, ir), segment in zip(irreps, segments) + ] + + if len(segments) > 0: + array = jnp.concatenate(segments, axis=-1) + else: + array = jnp.zeros(shape, dtype=dtype) + + return RepArray(irreps, array, layout) + + +class _MulIndexSliceHelper: + irreps_array: RepArray + + def __init__(self, irreps_array: RepArray): + assert irreps_array.is_irreps_array() + self.irreps_array = irreps_array + + def __getitem__(self, index: slice) -> RepArray: + if not isinstance(index, slice): + raise IndexError( + "RepArray.slice_by_mul only supports one slices (like RepArray.slice_by_mul[2:4])." + ) + + input_irreps = self.irreps_array.irreps + start, stop, stride = index.indices(input_irreps.num_irreps) + if stride != 1: + raise NotImplementedError("RepArray.slice_by_mul does not support strides.") + + output_irreps = [] + segments = [] + i = 0 + for (mul, ir), x in zip(input_irreps, self.irreps_array.segments): + if start <= i and i + mul <= stop: + output_irreps.append((mul, ir)) + segments.append(x) + elif start < i + mul and i < stop: + output_irreps.append((min(stop, i + mul) - max(start, i), ir)) + match self.irreps_array.layout: + case cue.mul_ir: + segments.append( + x[..., slice(max(start, i) - i, min(stop, i + mul) - i), :] + ) + case cue.ir_mul: + segments.append( + x[..., slice(max(start, i) - i, min(stop, i + mul) - i)] + ) + + i += mul + + return from_segments( + cue.Irreps(input_irreps.irrep_class, output_irreps), + segments, + self.irreps_array.shape, + self.irreps_array.layout, + self.irreps_array.dtype, + ) diff --git a/cuequivariance_jax/cuequivariance_jax/rep_array/utils.py b/cuequivariance_jax/cuequivariance_jax/rep_array/utils.py new file mode 100644 index 0000000..f7ce2e8 --- /dev/null +++ b/cuequivariance_jax/cuequivariance_jax/rep_array/utils.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +import jax +import jax.numpy as jnp +import numpy as np # noqa: F401 + +import cuequivariance as cue +import cuequivariance_jax as cuex +from cuequivariance.irreps_array.misc_ui import assert_same_group + + +def concatenate(arrays: list[cuex.RepArray]) -> cuex.RepArray: + """Concatenate a list of :class:`cuex.RepArray ` + + Args: + arrays (list of RepArray): List of arrays to concatenate. + axis (int, optional): Axis along which to concatenate. Defaults to -1. 
+ + Example: + + >>> with cue.assume(cue.SO3, cue.ir_mul): + ... x = cuex.RepArray("3x0", jnp.array([1.0, 2.0, 3.0])) + ... y = cuex.RepArray("1x1", jnp.array([0.0, 0.0, 0.0])) + >>> cuex.concatenate([x, y]) + {0: 3x0+1} [1. 2. 3. 0. 0. 0.] + """ + if len(arrays) == 0: + raise ValueError( + "Must provide at least one array to concatenate" + ) # pragma: no cover + if not all(a.layout == arrays[0].layout for a in arrays): + raise ValueError("All arrays must have the same layout") # pragma: no cover + if not all(a.ndim == arrays[0].ndim for a in arrays): + raise ValueError( + "All arrays must have the same number of dimensions" + ) # pragma: no cover + assert_same_group(*[a.irreps for a in arrays]) + + irreps = sum( + (a.irreps for a in arrays), cue.Irreps(arrays[0].irreps.irrep_class, []) + ) + return cuex.RepArray( + irreps, + jnp.concatenate([a.array for a in arrays], axis=-1), + arrays[0].layout, + ) + + +def randn( + key: jax.Array, + rep: cue.Rep, + leading_shape: tuple[int, ...] = (), + dtype: jnp.dtype | None = None, +) -> cuex.RepArray: + r"""Generate a random :class:`cuex.RepArray `. + + Args: + key (jax.Array): Random key. + rep (Rep): representation. + leading_shape (tuple[int, ...], optional): Leading shape of the array. Defaults to (). + dtype (jnp.dtype): Data type of the array. + + Returns: + RepArray: Random RepArray. + + Example: + + >>> key = jax.random.key(0) + >>> rep = cue.IrrepsAndLayout(cue.Irreps("O3", "2x1o"), cue.ir_mul) + >>> cuex.randn(key, rep, ()) + {0: 2x1o} [...] + """ + return cuex.RepArray( + rep, jax.random.normal(key, leading_shape + (rep.dim,), dtype=dtype) + ) + + +def as_irreps_array( + input: Any, + layout: cue.IrrepsLayout | None = None, + like: cuex.RepArray | None = None, +) -> cuex.RepArray: + """Converts input to a `RepArray`. Arrays are assumed to be scalars. + + Examples: + + >>> with cue.assume(cue.O3): + ... cuex.as_irreps_array([1.0], layout=cue.ir_mul) + {0: 0e} [1.] + """ + ir = None + + if like is not None: + assert layout is None + assert like.is_irreps_array() + + layout = like.layout + ir = like.irreps.irrep_class.trivial() + del like + + if layout is None: + layout = cue.get_layout_scope() + if ir is None: + ir = cue.get_irrep_scope().trivial() + + if isinstance(input, cuex.RepArray): + assert input.is_irreps_array() + + if input.layout != layout: + raise ValueError( + f"as_irreps_array: layout mismatch {input.layout} != {layout}" + ) + + return input + + input: jax.Array = jnp.asarray(input) + irreps = cue.Irreps(type(ir), [(input.shape[-1], ir)]) + return cuex.RepArray(irreps, input, layout) + + +def clebsch_gordan(rep1: cue.Irrep, rep2: cue.Irrep, rep3: cue.Irrep) -> cuex.RepArray: + r""" + Compute the Clebsch-Gordan coefficients. + + The Clebsch-Gordan coefficients are used to decompose the tensor product of two irreducible representations + into a direct sum of irreducible representations. This method computes the Clebsch-Gordan coefficients + for the given input representations and returns an array of shape ``(num_solutions, dim1, dim2, dim3)``, + where num_solutions is the number of solutions, ``dim1`` is the dimension of ``rep1``, ``dim2`` is the + dimension of ``rep2``, and ``dim3`` is the dimension of ``rep3``. + + The Clebsch-Gordan coefficients satisfy the following equation: + + .. math:: + + C_{ljk} X^1_{li} + C_{ilk} X^2_{lj} = X^3_{kl} C_{ijl} + + Args: + rep1 (Irrep): The first irreducible representation (input). + rep2 (Irrep): The second irreducible representation (input). 
+ rep3 (Irrep): The third irreducible representation (output). + + Returns: + RepArray: An array of shape ``(num_solutions, dim1, dim2, dim3)``. + + Examples: + >>> rep1 = cue.SO3(1) + >>> rep2 = cue.SO3(1) + >>> rep3 = cue.SO3(2) + >>> C1 = cuex.clebsch_gordan(rep1, rep2, rep3) + >>> C1.shape + (1, 3, 3, 5) + + According to the definition of the Clebsch-Gordan coefficients, the following transformation should be the identity: + >>> C2 = C1.transform(jnp.array([0.1, -0.3, 0.4])) + >>> np.testing.assert_allclose(C1.array, C2.array, atol=1e-3) + """ + return cuex.RepArray( + {1: rep1, 2: rep2, 3: rep3}, cue.clebsch_gordan(rep1, rep2, rep3) + ) diff --git a/cuequivariance_jax/cuequivariance_jax/rep_array/vmap.py b/cuequivariance_jax/cuequivariance_jax/rep_array/vmap.py new file mode 100644 index 0000000..0e32af8 --- /dev/null +++ b/cuequivariance_jax/cuequivariance_jax/rep_array/vmap.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, TypeVar + +import jax + +import cuequivariance as cue +import cuequivariance_jax as cuex + +T = TypeVar("T") + + +def remove_axis(dirreps: dict[int, T], axis: int) -> dict[int, T]: + assert axis >= 0 + if axis in dirreps: + raise ValueError( + f"Cannot vmap over an Irreps axis. {axis} has Irreps {dirreps[axis]}." + ) + return { + a - 1 if a > axis else a: irreps for a, irreps in dirreps.items() if a != axis + } + + +def add_axis(dirreps: dict[int, T], axis: int) -> dict[int, T]: + return {a + 1 if a >= axis else a: irreps for a, irreps in dirreps.items()} + + +def vmap( + fun: Callable[..., Any], + in_axes: int | tuple[int, ...] = 0, + out_axes: int = 0, +) -> Callable[..., Any]: + """ + Like jax.vmap, but for RepArray. + + Args: + fun: Callable[..., Any]: Function to vectorize. Can take `RepArray` as input and output. + in_axes: int | tuple[int, ...]: Axes of the inputs to map over. + out_axes: int: Axis of the output along which the mapped results are stacked. + + Returns: + Callable[..., Any]: Vectorized function.
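`cuex.vmap` wraps `jax.vmap` so the per-axis `Rep` bookkeeping follows the mapped axis; a sketch mirroring the updated `vmap_test` below (hedged: it assumes the scope-based defaults used in that test):

```python
import jax
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance_jax as cuex

with cue.assume(cue.SO3, cue.ir_mul):
    # Axis 0 carries the irrep "1" (dim 3); axis 1 is a plain batch axis.
    x = cuex.RepArray({0: "1"}, jnp.zeros((3, 2)))
    y = jax.jit(cuex.vmap(lambda a: a, in_axes=1, out_axes=0))(x)
    assert y.shape == (2, 3)
    assert y.reps == {1: cue.IrrepsAndLayout("1")}
```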
+ """ + + def inside_fun(*args, **kwargs): + args, kwargs = jax.tree.map( + lambda x: x.to_array() if isinstance(x, _wrapper) else x, + (args, kwargs), + is_leaf=lambda x: isinstance(x, _wrapper), + ) + out = fun(*args, **kwargs) + return jax.tree.map( + lambda x: ( + _wrapper.from_array_add_axis(x, out_axes) if _is_array(x) else x + ), + out, + is_leaf=_is_array, + ) + + def outside_fun(*args, **kwargs): + if isinstance(in_axes, int): + in_axes_ = (in_axes,) * len(args) + else: + in_axes_ = in_axes + + args = [ + jax.tree.map( + lambda x: ( + _wrapper.from_array_remove_axis(x, axis) if _is_array(x) else x + ), + arg, + is_leaf=_is_array, + ) + for axis, arg in zip(in_axes_, args) + ] + kwargs = jax.tree.map( + lambda x: (_wrapper.from_array_remove_axis(x, 0) if _is_array(x) else x), + kwargs, + is_leaf=_is_array, + ) + out = jax.vmap(inside_fun, in_axes, out_axes)(*args, **kwargs) + return jax.tree.map( + lambda x: x.to_array() if isinstance(x, _wrapper) else x, + out, + is_leaf=lambda x: isinstance(x, _wrapper), + ) + + return outside_fun + + +def _is_array(x): + return isinstance(x, cuex.RepArray) + + +@dataclass(frozen=True) +class _wrapper: + reps: dict[int, cue.Rep] = field() + array: jax.Array = field() + + def to_array(self): + return cuex.RepArray(self.reps, self.array) + + @classmethod + def from_array_add_axis(cls, x: cuex.RepArray, axis: int) -> _wrapper: + return _wrapper(add_axis(x.reps, axis), x.array) + + @classmethod + def from_array_remove_axis(cls, x: cuex.RepArray, axis: int) -> _wrapper: + return _wrapper( + remove_axis(x.reps, axis if axis >= 0 else axis + x.ndim), + x.array, + ) + + +jax.tree_util.register_pytree_node( + _wrapper, + lambda x: ((x.array,), (x.reps,)), + lambda static, data: _wrapper(static[0], data[0]), +) diff --git a/cuequivariance_jax/tests/flax_linen/linear_test.py b/cuequivariance_jax/tests/flax_linen/linear_test.py index 0d31d7f..73acb66 100644 --- a/cuequivariance_jax/tests/flax_linen/linear_test.py +++ b/cuequivariance_jax/tests/flax_linen/linear_test.py @@ -28,12 +28,12 @@ def test_explicit_linear(layout_in, layout_out): except ImportError: pytest.skip("flax not installed") - x = cuex.IrrepsArray(cue.Irreps("SO3", "3x0 + 2x1"), jnp.ones((16, 9)), layout_in) + x = cuex.RepArray(cue.Irreps("SO3", "3x0 + 2x1"), jnp.ones((16, 9)), layout_in) linear = cuex.flax_linen.Linear(cue.Irreps("SO3", "2x0 + 1"), layout_out) w = linear.init(jax.random.key(0), x) - y: cuex.IrrepsArray = linear.apply(w, x) + y: cuex.RepArray = linear.apply(w, x) assert y.shape == (16, 5) - assert y.irreps() == cue.Irreps("SO3", "2x0 + 1") + assert y.irreps == cue.Irreps("SO3", "2x0 + 1") assert y.layout == layout_out @@ -44,9 +44,9 @@ def test_implicit_linear(): except ImportError: pytest.skip("flax not installed") - x = cuex.IrrepsArray("3x0 + 2x1", jnp.ones((16, 9))) + x = cuex.RepArray("3x0 + 2x1", jnp.ones((16, 9))) linear = cuex.flax_linen.Linear("2x0 + 1") w = linear.init(jax.random.key(0), x) y = linear.apply(w, x) assert y.shape == (16, 5) - assert y.irreps() == "2x0 + 1" + assert y.irreps == "2x0 + 1" diff --git a/cuequivariance_jax/tests/irreps_array/jax_irreps_array_test.py b/cuequivariance_jax/tests/irreps_array/jax_irreps_array_test.py index 59bc2c8..b8de74f 100644 --- a/cuequivariance_jax/tests/irreps_array/jax_irreps_array_test.py +++ b/cuequivariance_jax/tests/irreps_array/jax_irreps_array_test.py @@ -20,20 +20,20 @@ @cue.assume("SO3", cue.ir_mul) def test_segments(): - x = cuex.IrrepsArray("2x0 + 1", jnp.array([1.0, 1.0, 0.0, 0.0, 0.0])) - x0, x1 = 
x.segments() + x = cuex.RepArray("2x0 + 1", jnp.array([1.0, 1.0, 0.0, 0.0, 0.0])) + x0, x1 = x.segments assert x0.shape == (1, 2) assert x1.shape == (3, 1) y = cuex.from_segments("2x0 + 1", [x0, x1], x.shape) - assert x.dirreps == y.dirreps + assert x.irreps == y.irreps assert x.layout == y.layout assert jnp.allclose(x.array, y.array) @cue.assume("SO3", cue.ir_mul) def test_slice_by_mul(): - x = cuex.IrrepsArray("2x0 + 1", jnp.array([1.0, 1.0, 0.0, 0.0, 0.0])) - x = x.slice_by_mul()[1:] - assert x.dirreps == {0: cue.Irreps("0 + 1")} + x = cuex.RepArray("2x0 + 1", jnp.array([1.0, 1.0, 0.0, 0.0, 0.0])) + x = x.slice_by_mul[1:] + assert x.irreps == cue.Irreps("0 + 1") assert x.layout == cue.ir_mul assert jnp.allclose(x.array, jnp.array([1.0, 0.0, 0.0, 0.0])) diff --git a/cuequivariance_jax/tests/operations/spherical_harmonics_test.py b/cuequivariance_jax/tests/operations/spherical_harmonics_test.py index fa95572..903fecd 100644 --- a/cuequivariance_jax/tests/operations/spherical_harmonics_test.py +++ b/cuequivariance_jax/tests/operations/spherical_harmonics_test.py @@ -25,16 +25,7 @@ @pytest.mark.parametrize("shape", [(2, 3), ()]) def test_spherical_harmonics(shape): - x = cuex.IrrepsArray( - cue.Irreps(cue.O3, "1o"), np.random.randn(*shape, 3), cue.ir_mul - ) + x = cuex.RepArray(cue.Irreps(cue.O3, "1o"), np.random.randn(*shape, 3), cue.ir_mul) y = cuex.spherical_harmonics([0, 1, 2], x) assert y.shape == shape + (9,) - assert y.irreps() == cue.Irreps(cue.O3, "0e + 1o + 2e") - - -# def test_edge_case(): -# x = cuex.IrrepsArray(cue.Irreps(cue.O3, "1o"), np.random.randn(2, 2, 3), cue.ir_mul) -# y = cuex.spherical_harmonics([0], x) -# assert y.shape == (2, 2, 1) -# assert y.irreps() == cue.Irreps(cue.O3, "0e") + assert y.irreps == cue.Irreps(cue.O3, "0e + 1o + 2e") diff --git a/cuequivariance_jax/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_jax/tests/primitives/equivariant_tensor_product_test.py index 766cffd..020e96b 100644 --- a/cuequivariance_jax/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_jax/tests/primitives/equivariant_tensor_product_test.py @@ -30,7 +30,7 @@ def test_special_double_backward(): h1 = lambda w, x: jax.grad(h0, 1)(w, x).array.sum() ** 2 # noqa w = jax.random.normal(jax.random.key(0), (1, irreps_w.dim)) - x = cuex.IrrepsArray( + x = cuex.RepArray( irreps_x, jax.random.normal(jax.random.key(1), (3, irreps_x.dim)), cue.ir_mul ) jax.grad(h1, 0)(w, x) diff --git a/cuequivariance_jax/tests/primitives/symmetric_tensor_product_test.py b/cuequivariance_jax/tests/primitives/symmetric_tensor_product_test.py index 6cbb453..05b43ad 100644 --- a/cuequivariance_jax/tests/primitives/symmetric_tensor_product_test.py +++ b/cuequivariance_jax/tests/primitives/symmetric_tensor_product_test.py @@ -30,8 +30,8 @@ def test_custom_jvp(): 3 * cue.Irreps(cue.O3, "0e"), [0, 1, 2, 3, 4], ) - w = np.random.randn(2, e.inputs[0].irreps.dim) - x = np.random.randn(2, e.inputs[1].irreps.dim) + w = np.random.randn(2, e.inputs[0].dim) + x = np.random.randn(2, e.inputs[1].dim) A = jax.grad( lambda x: jnp.sum( diff --git a/cuequivariance_jax/tests/vmap_test.py b/cuequivariance_jax/tests/vmap_test.py index a521a41..6e4452b 100644 --- a/cuequivariance_jax/tests/vmap_test.py +++ b/cuequivariance_jax/tests/vmap_test.py @@ -24,7 +24,7 @@ def test_vmap(): def f(x): return x - x = cuex.IrrepsArray({0: "1"}, jnp.zeros((3, 2))) + x = cuex.RepArray({0: "1"}, jnp.zeros((3, 2))) y = jax.jit(cuex.vmap(f, 1, 0))(x) assert y.shape == (2, 3) - assert y.dirreps == {1: 
cue.Irreps("1")} + assert y.reps == {1: cue.IrrepsAndLayout("1")} diff --git a/cuequivariance_torch/cuequivariance_torch/operations/linear.py b/cuequivariance_torch/cuequivariance_torch/operations/linear.py index 41c7e89..33788c9 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/linear.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/linear.py @@ -66,7 +66,7 @@ def __init__( self.irreps_in = irreps_in self.irreps_out = irreps_out - self.weight_numel = e.inputs[0].irreps.dim + self.weight_numel = e.inputs[0].dim self.shared_weights = shared_weights self.internal_weights = ( diff --git a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py index a91a72a..7f41775 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py @@ -138,7 +138,7 @@ def __init__( self.weight_shape = (p.shape[0], mul) else: self.projection = None - self.weight_shape = (self.etp.inputs[0].irreps.dim // mul, mul) + self.weight_shape = (self.etp.inputs[0].dim // mul, mul) self.num_elements = num_elements self.weight = torch.nn.Parameter( diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index 123c044..5e99746 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -128,16 +128,16 @@ class EquivariantTensorProduct(torch.nn.Module): >>> e = cue.descriptors.fully_connected_tensor_product( ... cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1") ... ) - >>> w = torch.ones(e.inputs[0].irreps.dim, device=device) - >>> x1 = torch.ones(17, e.inputs[1].irreps.dim, device=device) - >>> x2 = torch.ones(17, e.inputs[2].irreps.dim, device=device) + >>> w = torch.ones(e.inputs[0].dim, device=device) + >>> x1 = torch.ones(17, e.inputs[1].dim, device=device) + >>> x2 = torch.ones(17, e.inputs[2].dim, device=device) >>> tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=device) >>> tp([w, x1, x2]) tensor([[0., 0., 0., 0., 0., 0.],...) You can optionally index the first input tensor: - >>> w = torch.ones(3, e.inputs[0].irreps.dim, device=device) + >>> w = torch.ones(3, e.inputs[0].dim, device=device) >>> indices = torch.randint(3, (17,)) >>> tp([w, x1, x2], indices=indices) tensor([[0., 0., 0., 0., 0., 0.],...) 
@@ -164,36 +164,44 @@ def __init__( raise ValueError( f"Expected {e.num_inputs} input layouts, got {len(layout_in)}" ) - layout_in = tuple(ell or layout for ell in layout_in) + layout_in = tuple(lay or layout for lay in layout_in) layout_out = layout_out or layout del layout self.etp = e - self.layout_in = layout_in = tuple(map(default_layout, layout_in)) - self.layout_out = layout_out = default_layout(layout_out) transpose_in = torch.nn.ModuleList() for layout_used, input_expected in zip(layout_in, e.inputs): - transpose_in.append( - cuet.TransposeIrrepsLayout( - input_expected.irreps, - source=layout_used, - target=input_expected.layout, - device=device, - use_fallback=use_fallback, + if isinstance(input_expected, cue.IrrepsAndLayout): + layout_used = default_layout(layout_used) + transpose_in.append( + cuet.TransposeIrrepsLayout( + input_expected.irreps, + source=layout_used, + target=input_expected.layout, + device=device, + use_fallback=use_fallback, + ) ) - ) + else: + assert layout_used is None + transpose_in.append(torch.nn.Identity()) # script() requires literal addressing and fails to eliminate dead branches self.transpose_in = TRANSPOSE_DISPATCHERS[len(transpose_in) - 1](transpose_in) - self.transpose_out = cuet.TransposeIrrepsLayout( - e.output.irreps, - source=e.output.layout, - target=layout_out, - device=device, - use_fallback=use_fallback, - ) + if isinstance(e.output, cue.IrrepsAndLayout): + layout_out = default_layout(layout_out) + self.transpose_out = cuet.TransposeIrrepsLayout( + e.output.irreps, + source=e.output.layout, + target=layout_out, + device=device, + use_fallback=use_fallback, + ) + else: + assert layout_out is None + self.transpose_out = torch.nn.Identity() if ( len(e.ds) > 1 @@ -235,7 +243,7 @@ def __init__( ) ) - self.operands_dims = [op.irreps.dim for op in e.operands] + self.operands_dims = [op.dim for op in e.operands] def extra_repr(self) -> str: return str(self.etp) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 043e387..071cb2a 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -283,8 +283,7 @@ def test_performance_cuda_vs_fx( ) inputs = [ - torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) - for inp in e.inputs + torch.randn((1024, inp.dim), device=device, dtype=dtype) for inp in e.inputs ] for _ in range(10): @@ -333,8 +332,7 @@ def test_precision_cuda_vs_fx( pytest.skip("CUDA is not available") inputs = [ - torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) - for inp in e.inputs + torch.randn((1024, inp.dim), device=device, dtype=dtype) for inp in e.inputs ] m = cuet.EquivariantTensorProduct( e, layout=cue.ir_mul, device=device, math_dtype=math_dtype, use_fallback=False @@ -371,8 +369,7 @@ def test_compile( e, layout=cue.mul_ir, use_fallback=False, device=device, math_dtype=math_dtype ) inputs = [ - torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) - for inp in e.inputs + torch.randn((1024, inp.dim), device=device, dtype=dtype) for inp in e.inputs ] res = m(inputs) m_compile = torch.compile(m, fullgraph=True) @@ -396,8 +393,7 @@ def test_script( e, layout=cue.mul_ir, use_fallback=False, device=device, math_dtype=math_dtype ) inputs = [ - torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) - for inp in e.inputs + torch.randn((1024, inp.dim), device=device, 
dtype=dtype) for inp in e.inputs ] res = m(inputs) m_script = torch.jit.script(m) @@ -428,8 +424,7 @@ def test_export( e, layout=cue.mul_ir, math_dtype=math_dtype, use_fallback=False, device=device ) inputs = [ - torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) - for inp in e.inputs + torch.randn((1024, inp.dim), device=device, dtype=dtype) for inp in e.inputs ] res = m(inputs) m_script = module_with_mode(mode, m, [inputs], math_dtype, tmp_path) diff --git a/docs/api/cuequivariance.rst b/docs/api/cuequivariance.rst index 7aba3b0..9a297e8 100644 --- a/docs/api/cuequivariance.rst +++ b/docs/api/cuequivariance.rst @@ -49,6 +49,7 @@ These classes represent tensor products. Irreps IrrepsLayout + IrrepsAndLayout SegmentedTensorProduct EquivariantTensorProduct @@ -61,3 +62,13 @@ Descriptors :hidden: cuequivariance.descriptors + +Utilities +--------- + +.. autosummary:: + :toctree: generated/ + :template: class_template.rst + + assume + \ No newline at end of file diff --git a/docs/api/cuequivariance_jax.rst b/docs/api/cuequivariance_jax.rst index 1dfd0b4..181c6e3 100644 --- a/docs/api/cuequivariance_jax.rst +++ b/docs/api/cuequivariance_jax.rst @@ -19,14 +19,14 @@ cuequivariance-jax ================== -IrrepsArray ------------ +RepArray +-------- .. autosummary:: :toctree: generated/ :template: class_template.rst - IrrepsArray + RepArray .. autosummary:: :toctree: generated/ @@ -56,3 +56,9 @@ Extra Modules flax_linen.Linear flax_linen.LayerNorm + +.. autosummary:: + :toctree: generated/ + :template: function_template.rst + + spherical_harmonics diff --git a/docs/tutorials/etp.rst b/docs/tutorials/etp.rst index 529eda8..957cfd2 100644 --- a/docs/tutorials/etp.rst +++ b/docs/tutorials/etp.rst @@ -70,9 +70,9 @@ Execution on JAX cuex.equivariant_tensor_product(e, w, x) -The function :func:`cuex.randn ` generates random :class:`cuex.IrrepsArray ` objects. +The function :func:`cuex.randn ` generates random :class:`cuex.RepArray ` objects. The function :func:`cuex.equivariant_tensor_product ` executes the tensor product. -The output is a :class:`cuex.IrrepsArray ` object. +The output is a :class:`cuex.RepArray ` object. Execution on PyTorch @@ -89,10 +89,10 @@ We can execute an :class:`cuequivariance.EquivariantTensorProduct` with PyTorch. cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "8x0e + 4x1o") ) - module = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul) + module = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, use_fallback=True) - w = torch.randn(e.inputs[0].irreps.dim) - x = torch.randn(e.inputs[1].irreps.dim) + w = torch.randn(e.inputs[0].dim) + x = torch.randn(e.inputs[1].dim) module([w, x]) diff --git a/docs/tutorials/stp.rst b/docs/tutorials/stp.rst index b9516ee..c76fb4f 100644 --- a/docs/tutorials/stp.rst +++ b/docs/tutorials/stp.rst @@ -94,7 +94,7 @@ Now we can create a tensor product from the descriptor and execute it. In PyTorc .. jupyter-execute:: - linear_torch = cuet.TensorProduct(d) + linear_torch = cuet.TensorProduct(d, use_fallback=True) linear_torch
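A hedged follow-up to the tutorial snippet above, sketching how the constructed module might be called under this release's list-of-tensors convention (assumes `d` is the linear descriptor built earlier in the tutorial, with operand 0 holding the weights, operand 1 the input, and operand 2 the output):

```python
import torch

# Hypothetical shapes, read off the descriptor's operand sizes.
w = torch.randn(1, d.operands[0].size)   # weights
x = torch.randn(1, d.operands[1].size)   # input irreps
y = linear_torch([w, x])                 # inputs are passed as a list of torch.Tensor
assert y.shape == (1, d.operands[2].size)
```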