diff --git a/cuequivariance/cuequivariance/descriptors/__init__.py b/cuequivariance/cuequivariance/descriptors/__init__.py index c9592cb..af6e6e3 100644 --- a/cuequivariance/cuequivariance/descriptors/__init__.py +++ b/cuequivariance/cuequivariance/descriptors/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. from .transposition import transpose from .irreps_tp import ( + full_tensor_product, fully_connected_tensor_product, channelwise_tensor_product, elementwise_tensor_product, @@ -35,6 +36,7 @@ __all__ = [ "transpose", + "full_tensor_product", "fully_connected_tensor_product", "channelwise_tensor_product", "elementwise_tensor_product", diff --git a/cuequivariance/cuequivariance/descriptors/irreps_tp.py b/cuequivariance/cuequivariance/descriptors/irreps_tp.py index 1efdf2a..631e4c0 100644 --- a/cuequivariance/cuequivariance/descriptors/irreps_tp.py +++ b/cuequivariance/cuequivariance/descriptors/irreps_tp.py @@ -80,6 +80,62 @@ def fully_connected_tensor_product( ) +def full_tensor_product( + irreps1: cue.Irreps, + irreps2: cue.Irreps, + irreps3_filter: Optional[Sequence[cue.Irrep]] = None, +) -> cue.EquivariantTensorProduct: + """ + subscripts: ``lhs[iu],rhs[jv],output[kuv]`` + + Construct a weightless channelwise tensor product descriptor. + + .. currentmodule:: cuequivariance + + Args: + irreps1 (Irreps): Irreps of the first operand. + irreps2 (Irreps): Irreps of the second operand. + irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. + + Returns: + EquivariantTensorProduct: Descriptor of the full tensor product. + """ + G = irreps1.irrep_class + + if irreps3_filter is not None: + irreps3_filter = into_list_of_irrep(G, irreps3_filter) + + d = stp.SegmentedTensorProduct.from_subscripts("iu,jv,kuv+ijk") + + for mul, ir in irreps1: + d.add_segment(0, (ir.dim, mul)) + for mul, ir in irreps2: + d.add_segment(1, (ir.dim, mul)) + + irreps3 = [] + + for (i1, (mul1, ir1)), (i2, (mul2, ir2)) in itertools.product( + enumerate(irreps1), enumerate(irreps2) + ): + for ir3 in ir1 * ir2: + # for loop over the different solutions of the Clebsch-Gordan decomposition + for cg in cue.clebsch_gordan(ir1, ir2, ir3): + d.add_path(i1, i2, None, c=cg) + + irreps3.append((mul1 * mul2, ir3)) + + irreps3 = cue.Irreps(G, irreps3) + irreps3, perm, inv = irreps3.sort() + d = d.permute_segments(2, inv) + + d = d.normalize_paths_for_operand(-1) + return cue.EquivariantTensorProduct( + d, + [irreps1, irreps2, irreps3], + layout=cue.ir_mul, + ) + + def channelwise_tensor_product( irreps1: cue.Irreps, irreps2: cue.Irreps, diff --git a/cuequivariance/tests/equivariant_tensor_products_test.py b/cuequivariance/tests/equivariant_tensor_products_test.py index a6f00bd..86126bb 100644 --- a/cuequivariance/tests/equivariant_tensor_products_test.py +++ b/cuequivariance/tests/equivariant_tensor_products_test.py @@ -31,6 +31,12 @@ def test_commutativity_squeeze_flatten(): == d.flatten_coefficient_modes().squeeze_modes() ) + d = descriptors.full_tensor_product(irreps1, irreps2, irreps3).d + assert ( + d.squeeze_modes().flatten_coefficient_modes() + == d.flatten_coefficient_modes().squeeze_modes() + ) + d = descriptors.channelwise_tensor_product(irreps1, irreps2, irreps3).d assert ( d.squeeze_modes().flatten_coefficient_modes()