Skip to content

Commit f793a22

Browse files
committed
Apply review comments
1 parent 6e8746c commit f793a22

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

Diff for: src/finch/linalg/_linalg.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from numpy.core.numeric import normalize_axis_tuple
2+
13
from ..julia import jl
24
from ..tensor import Tensor
35

@@ -11,12 +13,14 @@ def vector_norm(
1113
ord: int | float = 2,
1214
) -> Tensor:
1315
if axis is not None:
14-
raise ValueError(
15-
"At the moment only `None` (vector norm of a flattened array) "
16-
"is supported. Got: {axis}."
17-
)
16+
axis = normalize_axis_tuple(axis, x.ndim)
17+
if axis != tuple(range(x.ndim)):
18+
raise ValueError(
19+
"At the moment only `None` (vector norm of a flattened array) "
20+
"is supported. Got: {axis}."
21+
)
1822

1923
result = Tensor(jl.Finch.norm(x._obj, ord))
2024
if keepdims:
21-
result = result.__getitem__(tuple(None for _ in range(x.ndim)))
25+
result = result[tuple(None for _ in range(x.ndim))]
2226
return result

Diff for: src/finch/tensor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def _order(self) -> tuple[int, ...]:
298298
@property
299299
def mT(self) -> "Tensor":
300300
axes = list(range(self.ndim))
301-
axes[-2:] = axes[:-3:-1]
301+
axes[-2], axes[-1] = axes[-1], axes[-2]
302302
axes = tuple(axes)
303303
return self.permute_dims(axes)
304304

0 commit comments

Comments
 (0)