-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdecompose.py
56 lines (38 loc) · 1.33 KB
/
decompose.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import tensorly as tl
from tensorly import unfold
from tensorly.decomposition import tucker
# Select TensorLy's PyTorch backend so that the TensorLy operations used below
# (`unfold`, `tucker`) operate on torch tensors. Note this runs at import time,
# i.e. importing this module changes TensorLy's backend setting.
tl.set_backend("pytorch")
def get_num_unfold(decomposer):
    """Return how many tensor unfoldings ``decomposer`` operates on.

    ``"svd"`` factorizes a single unfolding, so it maps to 1. Every other
    name maps to 3 (presumably one unfolding per mode of a 3-way tensor;
    confirm against callers).
    """
    if decomposer == "svd":
        return 1
    return 3
def get_u_svd(x, rank=1):
    """Return the leading ``rank`` left singular vectors of ``x``.

    Computes the reduced SVD (``full_matrices=False``) and keeps the first
    ``rank`` columns of ``U``. The singular values and ``Vh`` are discarded.
    """
    left_vectors = torch.linalg.svd(x, full_matrices=False).U
    return left_vectors[:, :rank]
def svd(x, rank=1):
    """Return rank-truncated SVD factors of ``x`` as ``[u, v]``.

    ``u`` holds the first ``rank`` left singular vectors as columns. ``v``
    holds the first ``rank`` right singular vectors as columns; it is ``Vh``
    with dims 0 and 1 swapped (a plain transpose, not a conjugate transpose).
    Singular values are discarded.
    """
    decomposition = torch.linalg.svd(x, full_matrices=False)
    u_trunc = decomposition.U[:, :rank]
    v_trunc = decomposition.Vh.transpose(0, 1)[:, :rank]
    return [u_trunc, v_trunc]
def decompose(x, decomposer="hosvd", rank=1, mode=0):
    """Return low-rank factor matrices of ``x`` computed by ``decomposer``.

    Args:
        x: input tensor. The ``"tucker"`` and ``"hosvd"`` paths use exactly
            three ranks / three unfoldings, so they presumably expect a
            3-way tensor (confirm with callers).
        decomposer: one of ``"svd"``, ``"hosvd"`` or ``"tucker"``.
        rank: number of leading columns kept in every factor matrix.
        mode: which unfolding of ``x`` to factorize; only used by ``"svd"``.

    Returns:
        A list of factor matrices:
            svd:    [u, v], truncated SVD factors of the mode-``mode``
                    unfolding of ``x`` (see ``svd``)
            hosvd:  [u1, u2, u3], leading ``rank`` left singular vectors of
                    the mode-0, mode-1 and mode-2 unfoldings of ``x``
            tucker: [u1, u2, u3], the factor matrices returned by
                    ``tensorly.decomposition.tucker`` (the core is dropped)

    Raises:
        ValueError: if ``decomposer`` is not one of the supported names.
    """
    if decomposer == "tucker":
        # Only the factor matrices are needed; the core tensor is discarded.
        _, factors = tucker(x, rank=[rank, rank, rank])
        return factors
    if decomposer == "hosvd":
        # HOSVD: truncated left singular vectors of each mode unfolding.
        num_unfold = get_num_unfold(decomposer)
        return [get_u_svd(unfold(x, i), rank=rank) for i in range(num_unfold)]
    if decomposer == "svd":
        return svd(unfold(x, mode), rank=rank)
    # Fail loudly rather than implicitly returning None for a mistyped name.
    raise ValueError(
        f"Unknown decomposer {decomposer!r}; "
        "expected one of 'svd', 'hosvd', 'tucker'"
    )