-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest_serialization.py
60 lines (51 loc) · 1.71 KB
/
test_serialization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import io
import sten
import scipy
def test_tensor_serialization():
    """Round-trip a sparse-wrapped tensor through ``torch.save``/``torch.load``.

    Builds a CSC-backed sparse tensor from a small dense tensor via
    ``sten.random_fraction_sparsifier_dense_csc`` and checks that, after
    reloading from an in-memory buffer, the wrapper class, the wrapped
    ``CscTensor``, its scipy ``csc_matrix`` payload and the ``grad_fmt``
    tuple all survive serialization.
    """
    x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    # Gradient format handed to the sparsifier and later compared against the
    # reloaded tensor. Pass this exact tuple instead of re-spelling it inline
    # so the expected value and the constructed value cannot drift apart.
    grad_fmt = (
        sten.KeepAll(),
        torch.Tensor,
        sten.RandomFractionSparsifier(0.5),
        sten.CscTensor,
    )
    sx = sten.random_fraction_sparsifier_dense_csc(
        sten.RandomFractionSparsifier(0.5),
        x,
        grad_fmt=grad_fmt,
    )

    fp = io.BytesIO()
    torch.save(sx, fp)
    fp.seek(0)
    # Since PyTorch 2.6, torch.load defaults to weights_only=True, which
    # refuses to unpickle arbitrary classes such as the sten wrappers.
    # The buffer was written just above, so full unpickling is safe here.
    lsx = torch.load(fp, weights_only=False)

    # Exact type checks (not isinstance): serialization must restore the
    # precise classes, not merely compatible subclasses.
    assert type(lsx) is sten.SparseTensorWrapper
    assert type(lsx.wrapped_tensor) is sten.CscTensor
    assert type(lsx.wrapped_tensor.data) is scipy.sparse.csc_matrix
    # Reloaded objects are fresh copies, so this relies on value equality
    # of the grad_fmt entries rather than identity.
    assert lsx.grad_fmt == grad_fmt
def test_module_weights_serialization():
    """Round-trip a module with a sparsified weight through save/load.

    Uses ``sten.SparsityBuilder`` to turn ``linear1.weight`` of a
    ``TransformerEncoderLayer`` into a COO-backed sparse parameter. It then
    pickles the whole module with ``torch.save``, reloads it, and checks that
    the parameter is still a ``SparseParameterWrapper`` around a ``CooTensor``.
    """
    # NOTE(review): this changes a process-wide sten setting (presumably so
    # unsupported sparse ops warn instead of raising) and it is never
    # restored, so it can leak into tests that run afterwards.
    sten.set_dispatch_failure("warn")

    model = torch.nn.TransformerEncoderLayer(d_model=512, nhead=8)
    sb = sten.SparsityBuilder()
    sb.set_weight(
        name="linear1.weight",
        initial_sparsifier=sten.ScalarFractionSparsifier(0.5),
        out_format=sten.CooTensor,
    )
    sparse_model = sb.get_sparse_model(model)
    # Sanity check: sparsification took effect before serialization.
    assert isinstance(sparse_model.linear1.weight, sten.SparseParameterWrapper)
    assert isinstance(sparse_model.linear1.weight.wrapped_tensor, sten.CooTensor)

    fp = io.BytesIO()
    torch.save(sparse_model, fp)
    fp.seek(0)
    # The whole nn.Module (plus sten wrapper classes) is pickled. PyTorch
    # >= 2.6 defaults torch.load to weights_only=True, which would reject it.
    # The buffer is our own, so full unpickling is safe.
    l_sparse_model = torch.load(fp, weights_only=False)

    assert isinstance(l_sparse_model.linear1.weight, sten.SparseParameterWrapper)
    assert isinstance(l_sparse_model.linear1.weight.wrapped_tensor, sten.CooTensor)
if __name__ == "__main__":
    # Allow running this module directly (without pytest); the tests are
    # executed in file order and any failing assert stops the run.
    for _run_test in (test_tensor_serialization, test_module_weights_serialization):
        _run_test()