Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADD] Implement shape of a structured matrix #81

Merged
merged 3 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion singd/optim/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def _update_preconditioner(self, module: Module):
# in `m_K, m_C`
scale = 0.5 * (1.0 - alpha1 if normalize_lr_cov else 1.0)

dim_K, dim_C = self.preconditioner_dims(module)
(dim_K,), (dim_C,) = set(K.shape), set(C.shape)
(dtype_K, dtype_C), dev = self._get_preconditioner_dtypes_and_device(module)

# step for m_K
Expand Down
12 changes: 11 additions & 1 deletion singd/structures/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch
import torch.distributed as dist
from torch import Tensor, zeros
from torch import Size, Tensor, zeros
from torch.linalg import matrix_norm

from singd.structures.utils import diag_add_, supported_eye
Expand Down Expand Up @@ -82,6 +82,16 @@ def named_tensors(self) -> Iterator[Tuple[str, Tensor]]:
for name in self._tensor_names:
yield name, getattr(self, name)

@property
def shape(self) -> Size:
    """Shape of the matrix represented by this structure.

    This generic fallback first calls ``_warn_naive_implementation`` with the
    name ``"shape"`` and then reads the shape off the result of ``to_dense()``.

    Returns:
        The shape of the dense representation.
    """
    self._warn_naive_implementation("shape")
    dense = self.to_dense()
    return dense.shape

def __matmul__(
self, other: Union[StructuredMatrix, Tensor]
) -> Union[StructuredMatrix, Tensor]:
Expand Down
18 changes: 17 additions & 1 deletion singd/structures/blockdiagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch
from einops import rearrange
from torch import Tensor, arange, cat, einsum, zeros
from torch import Size, Tensor, arange, cat, einsum, zeros
from torch.linalg import vector_norm

from singd.structures.base import StructuredMatrix
Expand Down Expand Up @@ -109,6 +109,22 @@ def __init__(self, blocks: Tensor, last: Tensor) -> None:
self._last: Tensor
self.register_tensor(last, "_last")

@property
def shape(self) -> Size:
    """Shape of the matrix represented by this structure.

    Each dimension is ``num_blocks * BLOCK_DIM`` plus the corresponding
    dimension of ``_last``, where ``num_blocks`` is the leading dimension of
    ``_blocks``.

    Returns:
        The shape of the matrix.
    """
    # Unpacking enforces a 3d `_blocks` tensor and a 2d `_last` tensor.
    n_blocks, _, _ = self._blocks.shape
    rows_last, cols_last = self._last.shape
    rows = n_blocks * self.BLOCK_DIM + rows_last
    cols = n_blocks * self.BLOCK_DIM + cols_last
    return Size((rows, cols))

@classmethod
def from_dense(cls, mat: Tensor) -> BlockDiagonalMatrixTemplate:
"""Construct from a PyTorch tensor.
Expand Down
11 changes: 10 additions & 1 deletion singd/structures/diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Union

import torch
from torch import Tensor, einsum, ones, zeros
from torch import Size, Tensor, einsum, ones, zeros
from torch.linalg import vector_norm

from singd.structures.base import StructuredMatrix
Expand Down Expand Up @@ -47,6 +47,15 @@ def __init__(self, mat_diag: Tensor) -> None:
self._mat_diag: Tensor
self.register_tensor(mat_diag, "_mat_diag")

@property
def shape(self) -> Size:
    """Shape of the matrix represented by this structure.

    The result is the shape of ``_mat_diag`` concatenated with itself.

    Returns:
        The shape of the matrix.
    """
    diag_shape = self._mat_diag.shape
    return diag_shape + diag_shape

def __matmul__(
self, other: Union[DiagonalMatrix, Tensor]
) -> Union[DiagonalMatrix, Tensor]:
Expand Down
11 changes: 10 additions & 1 deletion singd/structures/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Tuple, Union

import torch
from torch import Tensor, arange, cat, einsum, ones, zeros
from torch import Size, Tensor, arange, cat, einsum, ones, zeros
from torch.linalg import vector_norm

from singd.structures.base import StructuredMatrix
Expand Down Expand Up @@ -142,6 +142,15 @@ def __init__(self, A: Tensor, B: Tensor, C: Tensor, D: Tensor, E: Tensor):
self.E: Tensor
self.register_tensor(E, "E")

@property
def shape(self) -> Size:
    """Shape of the matrix represented by this structure.

    Both dimensions are taken from ``self.dim``.

    Returns:
        The shape of the matrix.
    """
    rows, cols = self.dim, self.dim
    return Size((rows, cols))

@classmethod
def from_dense(cls, sym_mat: Tensor) -> HierarchicalMatrixTemplate:
"""Construct from a PyTorch tensor.
Expand Down
24 changes: 23 additions & 1 deletion singd/structures/recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing import Iterator, List, Tuple, Type, Union

from torch import Tensor, block_diag
from torch import Size, Tensor, block_diag

from singd.structures.base import StructuredMatrix

Expand Down Expand Up @@ -184,6 +184,17 @@ def __init__(self, A: StructuredMatrix, B: Tensor, C: StructuredMatrix):
self.C: StructuredMatrix
self.register_substructure(C, "C")

@property
def shape(self) -> Size:
    """Shape of the matrix represented by this structure.

    Row and column counts are the sums of the respective counts of the
    sub-structures ``A`` and ``C``.

    Returns:
        The shape of the matrix.
    """
    rows_a, cols_a = self.A.shape
    rows_c, cols_c = self.C.shape
    total_rows = rows_a + rows_c
    total_cols = cols_a + cols_c
    return Size((total_rows, total_cols))

@classmethod
def from_dense(cls, sym_mat: Tensor) -> RecursiveTopRightMatrixTemplate:
"""Construct from a PyTorch tensor.
Expand Down Expand Up @@ -322,6 +333,17 @@ def __init__(self, A: StructuredMatrix, B: Tensor, C: StructuredMatrix):
self.C: StructuredMatrix
self.register_substructure(C, "C")

@property
def shape(self) -> Size:
    """Shape of the matrix represented by this structure.

    Row and column counts are the sums of the respective counts of the
    sub-structures ``A`` and ``C``.

    Returns:
        The shape of the matrix.
    """
    rows_a, cols_a = self.A.shape
    rows_c, cols_c = self.C.shape
    total_rows = rows_a + rows_c
    total_cols = cols_a + cols_c
    return Size((total_rows, total_cols))

@classmethod
def from_dense(cls, sym_mat: Tensor) -> RecursiveBottomLeftMatrixTemplate:
"""Construct from a PyTorch tensor.
Expand Down
11 changes: 10 additions & 1 deletion singd/structures/triltoeplitz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Union

import torch
from torch import Tensor, arange, cat, zeros
from torch import Size, Tensor, arange, cat, zeros
from torch.linalg import vector_norm
from torch.nn.functional import conv1d, pad

Expand Down Expand Up @@ -55,6 +55,15 @@ def __init__(self, lower_diags: Tensor) -> None:
self._lower_diags: Tensor
self.register_tensor(lower_diags, "_lower_diags")

@property
def shape(self) -> Size:
    """Shape of the matrix represented by this structure.

    The result is the shape of ``_lower_diags`` concatenated with itself.

    Returns:
        The shape of the matrix.
    """
    diags_shape = self._lower_diags.shape
    return diags_shape + diags_shape

@classmethod
def from_dense(cls, mat: Tensor) -> TrilToeplitzMatrix:
"""Construct from a PyTorch tensor.
Expand Down
11 changes: 10 additions & 1 deletion singd/structures/triutoeplitz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Union

import torch
from torch import Tensor, arange, cat, triu_indices, zeros
from torch import Size, Tensor, arange, cat, triu_indices, zeros
from torch.linalg import vector_norm
from torch.nn.functional import conv1d, pad

Expand Down Expand Up @@ -55,6 +55,15 @@ def __init__(self, upper_diags: Tensor) -> None:
self._upper_diags: Tensor
self.register_tensor(upper_diags, "_upper_diags")

@property
def shape(self) -> Size:
    """Shape of the matrix represented by this structure.

    The result is the shape of ``_upper_diags`` concatenated with itself.

    Returns:
        The shape of the matrix.
    """
    diags_shape = self._upper_diags.shape
    return diags_shape + diags_shape

@classmethod
def from_dense(cls, mat: Tensor) -> TriuToeplitzMatrix:
"""Construct from a PyTorch tensor.
Expand Down
15 changes: 15 additions & 0 deletions test/structures/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,21 @@ def test_frobenius_norm(self, dev: device, dtype: torch.dtype):
structured = self.STRUCTURED_MATRIX_CLS.from_dense(sym_mat)
report_nonclose(truth, structured.frobenius_norm())

@mark.parametrize("dtype", DTYPES, ids=DTYPE_IDS)
@mark.parametrize("dev", DEVICES, ids=DEVICE_IDS)
def test_shape(self, dev: device, dtype: torch.dtype):
    """Check that a structured matrix reports the shape of its dense source.

    For every dimension in ``self.DIMS``, a random symmetrized square matrix is
    converted with ``from_dense`` and the resulting ``shape`` is compared with
    the dense matrix's shape.

    Args:
        dev: The device on which to run the test.
        dtype: The data type of the matrices.
    """
    for dim in self.DIMS:
        # Re-seed on every iteration so each dimension is reproducible.
        manual_seed(0)
        dense = rand((dim, dim), device=dev, dtype=dtype)
        sym_mat = symmetrize(dense)
        structured = self.STRUCTURED_MATRIX_CLS.from_dense(sym_mat)
        assert sym_mat.shape == structured.shape

@mark.expensive
def test_visual(self):
"""Create pictures and animations of the structure.
Expand Down
Loading