Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

cue.IrrepsAndLayout, cue.EquivariantTensorProduct, cuex.RepArray #46

Merged
merged 24 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@
### Added

- Partial support of `torch.jit.script` and `torch.compile`
- Added `cuex.RepArray` for representing an array of any kind of representations (not only irreps like before with `IrrepsArray`).

### Changed

- `cuequivariance_torch.TensorProduct` and `cuequivariance_torch.EquivariantTensorProduct` now require lists of `torch.Tensor` as input.
- `cuex.IrrepsArray` is now an alias for `cuex.RepArray` and its `.irreps` attribute and `.segments` are not functions anymore but properties.

## Removed

- `cuex.IrrepsArray.is_simple` is replaced by `cuex.RepArray.is_irreps_array`.

### Fixed

Expand Down
2 changes: 2 additions & 0 deletions cuequivariance/cuequivariance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
IrrepsLayout,
mul_ir,
ir_mul,
IrrepsAndLayout,
get_layout_scope,
assume,
NumpyIrrepsArray,
Expand Down Expand Up @@ -71,6 +72,7 @@
"IrrepsLayout",
"mul_ir",
"ir_mul",
"IrrepsAndLayout",
"get_layout_scope",
"assume",
"NumpyIrrepsArray",
Expand Down
7 changes: 0 additions & 7 deletions cuequivariance/cuequivariance/descriptors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@
yxy_rotation,
inversion,
)
from .escn import escn_tp, escn_tp_compact
from .spherical_harmonics_ import sympy_spherical_harmonics, spherical_harmonics
from .gatr import gatr_linear, gatr_geometric_product, gatr_outer_product

__all__ = [
"transpose",
Expand All @@ -49,11 +47,6 @@
"yx_rotation",
"yxy_rotation",
"inversion",
"escn_tp",
"escn_tp_compact",
"sympy_spherical_harmonics",
"spherical_harmonics",
"gatr_linear",
"gatr_geometric_product",
"gatr_outer_product",
]
37 changes: 28 additions & 9 deletions cuequivariance/cuequivariance/descriptors/irreps_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,12 @@ def fully_connected_tensor_product(
d = d.normalize_paths_for_operand(-1)
return cue.EquivariantTensorProduct(
d,
[irreps1.new_scalars(d.operands[0].size), irreps1, irreps2, irreps3],
layout=cue.ir_mul,
[
cue.IrrepsAndLayout(irreps1.new_scalars(d.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps1, cue.ir_mul),
cue.IrrepsAndLayout(irreps2, cue.ir_mul),
cue.IrrepsAndLayout(irreps3, cue.ir_mul),
],
)


Expand Down Expand Up @@ -131,8 +135,11 @@ def full_tensor_product(
d = d.normalize_paths_for_operand(-1)
return cue.EquivariantTensorProduct(
d,
[irreps1, irreps2, irreps3],
layout=cue.ir_mul,
[
cue.IrrepsAndLayout(irreps1, cue.ir_mul),
cue.IrrepsAndLayout(irreps2, cue.ir_mul),
cue.IrrepsAndLayout(irreps3, cue.ir_mul),
],
)


Expand Down Expand Up @@ -193,8 +200,12 @@ def channelwise_tensor_product(
d = d.normalize_paths_for_operand(-1)
return cue.EquivariantTensorProduct(
d,
[irreps1.new_scalars(d.operands[0].size), irreps1, irreps2, irreps3],
layout=cue.ir_mul,
[
cue.IrrepsAndLayout(irreps1.new_scalars(d.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps1, cue.ir_mul),
cue.IrrepsAndLayout(irreps2, cue.ir_mul),
cue.IrrepsAndLayout(irreps3, cue.ir_mul),
],
)


Expand Down Expand Up @@ -273,7 +284,12 @@ def elementwise_tensor_product(
irreps3 = cue.Irreps(G, irreps3)
d = d.normalize_paths_for_operand(-1)
return cue.EquivariantTensorProduct(
d, [irreps1, irreps2, irreps3], layout=cue.ir_mul
d,
[
cue.IrrepsAndLayout(irreps1, cue.ir_mul),
cue.IrrepsAndLayout(irreps2, cue.ir_mul),
cue.IrrepsAndLayout(irreps3, cue.ir_mul),
],
)


Expand Down Expand Up @@ -308,6 +324,9 @@ def linear(

return cue.EquivariantTensorProduct(
d,
[irreps_in.new_scalars(d.operands[0].size), irreps_in, irreps_out],
layout=cue.ir_mul,
[
cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps_in, cue.ir_mul),
cue.IrrepsAndLayout(irreps_out, cue.ir_mul),
],
)
59 changes: 39 additions & 20 deletions cuequivariance/cuequivariance/descriptors/rotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,13 @@ def fixed_axis_angle_rotation(
)

d = d.flatten_coefficient_modes()
return cue.EquivariantTensorProduct(d, [irreps, irreps], layout=cue.ir_mul)
return cue.EquivariantTensorProduct(
d,
[
cue.IrrepsAndLayout(irreps, cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
],
)


def yxy_rotation(
Expand Down Expand Up @@ -70,13 +76,12 @@ def yxy_rotation(
return cue.EquivariantTensorProduct(
cbaio,
[
irreps.new_scalars(cbaio.operands[0].size),
irreps.new_scalars(cbaio.operands[1].size),
irreps.new_scalars(cbaio.operands[2].size),
irreps,
irreps,
cue.IrrepsAndLayout(irreps.new_scalars(cbaio.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps.new_scalars(cbaio.operands[1].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps.new_scalars(cbaio.operands[2].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
],
layout=cue.ir_mul,
)


Expand All @@ -95,12 +100,11 @@ def xy_rotation(
return cue.EquivariantTensorProduct(
cbio,
[
irreps.new_scalars(cbio.operands[0].size),
irreps.new_scalars(cbio.operands[1].size),
irreps,
irreps,
cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[1].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
],
layout=cue.ir_mul,
)


Expand All @@ -119,12 +123,11 @@ def yx_rotation(
return cue.EquivariantTensorProduct(
cbio,
[
irreps.new_scalars(cbio.operands[0].size),
irreps.new_scalars(cbio.operands[1].size),
irreps,
irreps,
cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[1].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
],
layout=cue.ir_mul,
)


Expand Down Expand Up @@ -188,7 +191,12 @@ def y_rotation(

d = d.flatten_coefficient_modes()
return cue.EquivariantTensorProduct(
d, [irreps.new_scalars(d.operands[0].size), irreps, irreps], layout=cue.ir_mul
d,
[
cue.IrrepsAndLayout(irreps.new_scalars(d.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
],
)


Expand All @@ -213,7 +221,12 @@ def x_rotation(
d = stp.dot(stp.dot(dy, dz90, (1, 1)), dz90, (1, 1))

return cue.EquivariantTensorProduct(
d, [irreps.new_scalars(d.operands[0].size), irreps, irreps], layout=cue.ir_mul
d,
[
cue.IrrepsAndLayout(irreps.new_scalars(d.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
],
)


Expand All @@ -228,4 +241,10 @@ def inversion(irreps: cue.Irreps) -> cue.EquivariantTensorProduct:
assert np.allclose(H @ H, np.eye(ir.dim), atol=1e-6)
d.add_path(None, None, c=H, dims={"u": mul})
d = d.flatten_coefficient_modes()
return cue.EquivariantTensorProduct(d, [irreps, irreps], layout=cue.ir_mul)
return cue.EquivariantTensorProduct(
d,
[
cue.IrrepsAndLayout(irreps, cue.ir_mul),
cue.IrrepsAndLayout(irreps, cue.ir_mul),
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@ def spherical_harmonics(
indices = poly_degrees_to_path_indices(degrees)
d.add_path(*indices, i, c=coeff)

return cue.EquivariantTensorProduct([d], [ir_vec, ir], layout=layout)
return cue.EquivariantTensorProduct(
[d],
[
cue.IrrepsAndLayout(cue.Irreps(ir_vec), cue.ir_mul),
cue.IrrepsAndLayout(cue.Irreps(ir), cue.ir_mul),
],
)


def poly_degrees_to_path_indices(degrees: tuple[int, ...]) -> tuple[int, ...]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def symmetric_contraction(
d = d.append_modes_to_all_operands("u", {"u": mul})
return cue.EquivariantTensorProduct(
[d],
[irreps_in.new_scalars(d.operands[0].size), mul * irreps_in, mul * irreps_out],
layout=cue.ir_mul,
[
cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul),
cue.IrrepsAndLayout(mul * irreps_in, cue.ir_mul),
cue.IrrepsAndLayout(mul * irreps_out, cue.ir_mul),
],
)
11 changes: 7 additions & 4 deletions cuequivariance/cuequivariance/descriptors/transposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import cuequivariance as cue
from cuequivariance.equivariant_tensor_product import Operand


def transpose(
Expand All @@ -22,12 +21,16 @@ def transpose(
"""Transpose the irreps layout of a tensor."""
d = cue.SegmentedTensorProduct(
operands=[
cue.Operand(subscripts="ui" if source == cue.mul_ir else "iu"),
cue.Operand(subscripts="ui" if target == cue.mul_ir else "iu"),
cue.segmented_tensor_product.Operand(
subscripts="ui" if source == cue.mul_ir else "iu"
),
cue.segmented_tensor_product.Operand(
subscripts="ui" if target == cue.mul_ir else "iu"
),
]
)
for mul, ir in irreps:
d.add_path(None, None, c=1, dims={"u": mul, "i": ir.dim})
return cue.EquivariantTensorProduct(
d, [Operand(irreps, source), Operand(irreps, target)]
d, [cue.IrrepsAndLayout(irreps, source), cue.IrrepsAndLayout(irreps, target)]
)
Loading
Loading