-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathbench.py
63 lines (59 loc) · 2.05 KB
/
bench.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import triton
from e_natten import natten2d
from natten.functional import natten2dqk, natten2dav
@triton.testing.perf_report([
    triton.testing.Benchmark(
        x_names=['D'],
        x_vals=[2**i for i in range(4, 10)],
        line_arg="provider",
        line_vals=['triton', 'natten'],
        line_names=['e_natten (fused)', 'natten (original)'],
        ylabel='time (ms)',
        args={'B': 4, 'N': 4, 'C': 128, 'kernel_size': 5},
        plot_name="2d-fwd",
    )
])
def bench_2d_fwd(B, N, D, C, kernel_size, provider):
    """Time the forward pass of 2D neighborhood attention for one provider.

    q, k and v are random default-dtype CUDA tensors of shape
    (B, N, D, D, C). Presumably B is batch, N is heads, D is the side of a
    square spatial grid and C the per-head channels -- TODO confirm against
    the e_natten / natten documentation.

    Args:
        B, N, D, C: Input tensor dimensions (see above); ``D`` is the
            benchmark's swept x-axis.
        kernel_size: Neighborhood window size passed to both providers.
        provider: ``'triton'`` for the fused ``e_natten.natten2d`` kernel,
            ``'natten'`` for natten's separate QK -> softmax -> AV ops.

    Returns:
        The value returned by ``triton.testing.do_bench`` (plotted as
        time in ms).

    Raises:
        ValueError: If ``provider`` is not one of the values above.
    """
    # Sample directly on the GPU: sampling on the host and calling .cuda()
    # would need ~6 GiB of host RAM (float32, D=512) plus a host->device copy.
    q, k, v = torch.randn((3, B, N, D, D, C), device="cuda")
    if provider == 'triton':
        def fn():
            return natten2d(q, k, v, kernel_size)
    elif provider == 'natten':
        def fn():
            # Unfused reference path; softmax is taken over the last dim of
            # the QK output. The trailing 1 is presumably the dilation.
            attn = torch.softmax(natten2dqk(q, k, kernel_size, 1), dim=-1)
            return natten2dav(attn, v, kernel_size, 1)
    else:
        # Without this branch `fn` would be unbound and fail obscurely below.
        raise ValueError(f"unknown provider: {provider!r}")
    warmup = 200
    rep = 1000
    # Inference-only timing: no autograd graph is recorded.
    with torch.no_grad():
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    return ms
@triton.testing.perf_report([
    triton.testing.Benchmark(
        x_names=['D'],
        x_vals=[2**i for i in range(4, 9)],
        line_arg="provider",
        line_vals=['triton', 'natten'],
        line_names=['e_natten (fused)', 'natten (original)'],
        ylabel='time (ms)',
        args={'B': 4, 'N': 4, 'C': 128, 'kernel_size': 5},
        plot_name="2d-bwd",
    )
])
def bench_2d_bwd(B, N, D, C, kernel_size, provider):
    """Time forward + backward of 2D neighborhood attention for one provider.

    Each timed call runs the forward pass, reduces the output with
    ``sum(out ** 2)`` and calls ``backward()``. The reported time therefore
    covers the forward as well as the backward pass.

    q, k and v are random default-dtype CUDA tensors of shape
    (B, N, D, D, C). Presumably B is batch, N is heads, D is the side of a
    square spatial grid and C the per-head channels -- TODO confirm.

    Args:
        B, N, D, C: Input tensor dimensions (see above); ``D`` is the
            benchmark's swept x-axis.
        kernel_size: Neighborhood window size passed to both providers.
        provider: ``'triton'`` for the fused ``e_natten.natten2d`` kernel,
            ``'natten'`` for natten's separate QK -> softmax -> AV ops.

    Returns:
        The value returned by ``triton.testing.do_bench`` (plotted as
        time in ms).

    Raises:
        ValueError: If ``provider`` is not one of the values above.
    """
    # Each of q, k, v must be its own autograd leaf on the GPU. If the leaf
    # were a CPU tensor moved with .cuda() (or one stacked tensor that is
    # unpacked), every backward would also route the gradient through
    # unbind/.cuda() back to host memory, and that transfer would be timed.
    q, k, v = (
        torch.randn((B, N, D, D, C), device="cuda", requires_grad=True)
        for _ in range(3)
    )
    if provider == 'triton':
        def fn():
            out = natten2d(q, k, v, kernel_size)
            loss = torch.sum(out ** 2)
            loss.backward()
    elif provider == 'natten':
        def fn():
            # The trailing 1 is presumably the dilation argument.
            s = natten2dqk(q, k, kernel_size, 1)
            p = torch.softmax(s, dim=-1)
            o = natten2dav(p, v, kernel_size, 1)
            loss = torch.sum(o ** 2)
            loss.backward()
    else:
        # Without this branch `fn` would be unbound and fail obscurely below.
        raise ValueError(f"unknown provider: {provider!r}")
    warmup = 100
    rep = 200
    # Reset q/k/v .grad between runs so timings exclude gradient accumulation.
    ms = triton.testing.do_bench(
        fn, warmup=warmup, rep=rep, grad_to_none=[q, k, v]
    )
    return ms
# Run the forward benchmark, then the backward one, printing each result
# table and passing "assets" as the save path for the outputs.
for _bench in (bench_2d_fwd, bench_2d_bwd):
    _bench.run(save_path="assets", print_data=True)