-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest_training.py
177 lines (138 loc) · 5.33 KB
/
test_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# Following example https://pytorch.org/blog/accelerating-neural-network-training/
import torch
from torchao.sparsity.training import (
SemiSparseLinear,
swap_linear_with_semi_sparse_linear,
)
import sten
from transformers import AutoModel
class TrainingSparsifier:
def __init__(self):
pass
class TrainingTensor:
def __init__(self, weight):
self.weight = weight
def to_dense(self):
return self.weight
def add_(self, other, alpha=1):
self.weight.add_(other, alpha=alpha)
@property
def shape(self):
return self.weight.shape
@sten.register_sparsifier_implementation(
sparsifier=TrainingSparsifier, inp=torch.Tensor, out=TrainingTensor
)
def sparsifier_impl(sparsifier, tensor, grad_fmt=None):
return sten.SparseTensorWrapper.wrapped_from_dense(
TrainingTensor(tensor),
tensor,
grad_fmt,
)
@torch.compiler.allow_in_graph
def sparse_semi_structured_tile_wrapper(tensor):
return torch._sparse_semi_structured_tile(tensor, algorithm='', use_cutlass=False)
@sten.register_fwd_op_impl(
operator=torch.nn.functional.linear,
inp=(torch.Tensor, TrainingTensor, torch.Tensor),
out=None, # default (dense)
)
def operator_impl(ctx, inputs, output_sparsifiers):
x, w, bias = inputs
dense_weight = w.wrapped_tensor.weight
(packed, meta, packed_t, meta_t, bitmask) = sparse_semi_structured_tile_wrapper(dense_weight)
x_shape = x.shape
x2d = x.view(-1, x.shape[-1])
row, col = x2d.shape
x_padded = torch.sparse.semi_structured.SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(x2d)
# y = x * w.t -> y = (w * x.t).t
# grad_x = grad_y * w -> grad_x = (w.t * grad_y.t).t
result_padded = torch._cslt_sparse_mm(
packed,
x_padded.t(),
bias=bias,
transpose_result=True,#False,
alg_id=torch.sparse.semi_structured.SparseSemiStructuredTensor._DEFAULT_ALG_ID,
)#.t()
result = result_padded[:row, :]
ctx.save_for_backward(x, packed_t)
return result.view(*x_shape[:-1], -1)
@sten.register_bwd_op_impl(
operator=torch.nn.functional.linear,
grad_out=None, # default (dense)
grad_inp=None, # default (dense)
inp=(torch.Tensor, TrainingTensor, torch.Tensor),
)
def my_operator(ctx, grad_outputs, input_sparsifiers):
x, packed_t = ctx.saved_tensors
[grad_y] = grad_outputs
grad_y_shape = grad_y.shape
grad_y2d = grad_y.reshape(-1, grad_y.shape[-1])
x2d = x.view(-1, x.shape[-1])
grad_w = torch.mm(grad_y2d.t(), x2d)
# grad_x = w.t @ grad_y2d
row, col = grad_y2d.shape
grad_y2d_padded = torch.sparse.semi_structured.SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(grad_y2d)
grad_x2d_padded = torch._cslt_sparse_mm(
packed_t,
grad_y2d_padded.t(),
bias=None,
transpose_result=True,#False,
alg_id=torch.sparse.semi_structured.SparseSemiStructuredTensor._DEFAULT_ALG_ID,
)#.t()
grad_x2d = grad_x2d_padded[:row, :]
grad_x = grad_x2d.view(*grad_y_shape[:-1], -1)
grad_b = grad_y.sum(dim=list(range(grad_y.ndim - 1)))
grad_inputs = (grad_x, grad_w, grad_b)
return grad_inputs
def run_parametrized(enable_compile, backend):
assert torch.cuda.is_available()
device = torch.device("cuda")
model = AutoModel.from_pretrained('facebook/dinov2-large').to(device).half()
model.train()
batch = 32
channels = 3
height = 224
width = 224
repeats = 10
def train_iteration(model, optimizer):
inputs = torch.randn((batch, channels, height, width), dtype=torch.half, device=device)
outputs = model(inputs)
loss = outputs.last_hidden_state.mean()
loss = loss * loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
assert backend in ['dense', 'torchao', 'sten']
if backend == 'dense':
pass
elif backend == 'torchao':
sparse_config = {}
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
sparse_config[name] = SemiSparseLinear
swap_linear_with_semi_sparse_linear(model, sparse_config)
elif backend == 'sten':
sb = sten.SparsityBuilder()
for module_name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
weight = module_name + ".weight"
sb.set_weight(
name=weight,
initial_sparsifier=TrainingSparsifier(),
out_format=TrainingTensor,
)
sb.sparsify_model_inplace(model)
if enable_compile:
model = torch.compile(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
mean, std, times = sten.time_prof(repeats, lambda: train_iteration(model, optimizer), sync=torch.cuda.synchronize, warmup=0.3)
print(f"Runtime ({backend}, compile={enable_compile}) [ms] mean: {mean * 1e3:.3f}, std: {std * 1e3:.3f}, repeats: {len(times)}")
def test_training():
run_parametrized(False, 'dense')
run_parametrized(False, 'torchao')
run_parametrized(False, 'sten')
run_parametrized(True, 'dense')
run_parametrized(True, 'torchao')
run_parametrized(True, 'sten')
if __name__ == "__main__":
test_training()