-
Notifications
You must be signed in to change notification settings - Fork 2
/
gemm_i8.cuh
executable file
·485 lines (391 loc) · 18.7 KB
/
gemm_i8.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
#pragma once
#include <cuda.h>
#include <cuda/barrier>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <mma.h>
#include <limits>
// Operand-layout flags passed as the NoTransA / NoTransB / RowMajorC template
// arguments of the GEMMI8TCU kernels in this header ("T" = true, "N" = false).
constexpr bool GEMM_OP_N = false;
constexpr bool GEMM_OP_T = !GEMM_OP_N;
// Short names used throughout this header: cg:: for cooperative groups, and an
// unqualified wmma:: that resolves to nvcuda::wmma.
namespace cg = cooperative_groups;
using namespace nvcuda;
namespace wmma_kernel{

// INT8 x INT8 -> INT32 tensor-core GEMM (WMMA m16n16k16): C = A * B, with C
// written as int32.
//
// Operand layouts (pinned by the static_asserts below):
//   A: row-major M x K, leading dimension K.
//   B: N rows of K bytes (leading dimension K), consumed by WMMA as col-major K x N.
//   C: row-major M x N, leading dimension N.
//
// Launch contract (NOT checked at runtime):
//   * blockIdx.x selects a BLOCK_SIZE_N-wide column tile of C, blockIdx.y a
//     BLOCK_SIZE_M-tall row tile;
//   * the block must hold exactly (BLOCK_SIZE_M/WARP_SIZE_M)*(BLOCK_SIZE_N/WARP_SIZE_N)
//     warps; each warp owns one WARP_SIZE_M x WARP_SIZE_N sub-tile of C;
//   * there are no bounds checks: M and N must be multiples of the block tile,
//     and only the first (K / BLOCK_SIZE_K) * BLOCK_SIZE_K columns of K are
//     accumulated;
//   * every 16-byte cp.async source must be 16-byte aligned (A/B base pointers
//     16-byte aligned and K % 16 == 0).
// Uses cp.async (SM80+). Shared memory: max(STAGE, 2) * BLOCK_SIZE_K *
// (BLOCK_SIZE_M + BLOCK_SIZE_N) bytes, of which two buffers per operand are used.
template <int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K, int WARP_SIZE_M, int WARP_SIZE_N, int STAGE, bool NoTransA, bool NoTransB, bool RowMajorC>
__global__ void GEMMI8TCU(const int8_t* A, const int8_t* B, int* C, int M, int N, int K)
{
    auto block = cg::this_thread_block();
    cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(block);
    int warp_id = tile32.meta_group_rank();
    int lane_id = tile32.thread_rank();
    constexpr int TC_SIZE = 16;                                  // WMMA tile edge
    constexpr int CP_SIZE_BYTES = 16;                            // bytes moved by one cp.async
    constexpr int WARP_SIZE_X = 2;                               // lanes sharing one tile row
    constexpr int ROW_CHUNK_BYTES = CP_SIZE_BYTES * WARP_SIZE_X; // row bytes covered per warp pass
    constexpr int WAPR_NUM_N = BLOCK_SIZE_N / WARP_SIZE_N;
    constexpr int WAPR_NUM_M = BLOCK_SIZE_M / WARP_SIZE_M;
    constexpr int WAPR_NUM = WAPR_NUM_M * WAPR_NUM_N;
    // The K loop always uses two buffers per operand; with STAGE < 2 the
    // second A buffer would alias the first B buffer.
    constexpr int SMEM_STAGES = STAGE < 2 ? 2 : STAGE;
    static_assert(NoTransA == GEMM_OP_T, "NoTransA == GEMM_OP_T");
    static_assert(NoTransB == GEMM_OP_N, "NoTransB == GEMM_OP_N");
    static_assert(RowMajorC == GEMM_OP_T, "RowMajorC == GEMM_OP_T");
    static_assert(BLOCK_SIZE_M % WARP_SIZE_M == 0 && BLOCK_SIZE_N % WARP_SIZE_N == 0,
                  "block tile must be a whole number of warp tiles");
    static_assert(WARP_SIZE_M % TC_SIZE == 0 && WARP_SIZE_N % TC_SIZE == 0,
                  "warp tile must be a whole number of 16x16 WMMA tiles");
    static_assert(BLOCK_SIZE_K % ROW_CHUNK_BYTES == 0,
                  "BLOCK_SIZE_K must be a multiple of 32");
    // Explicit alignment: int8_t alone only guarantees 1-byte alignment, while
    // the cp.async destinations and WMMA tile pointers below need more.
    __shared__ __align__(128) int8_t SLB[SMEM_STAGES * (BLOCK_SIZE_K*BLOCK_SIZE_M + BLOCK_SIZE_K*BLOCK_SIZE_N)];
    int8_t* smem_a[2];
    int8_t* smem_b[2];
    smem_a[0] = SLB;
    smem_a[1] = SLB + BLOCK_SIZE_K*BLOCK_SIZE_M;
    smem_b[0] = SLB + SMEM_STAGES*BLOCK_SIZE_K*BLOCK_SIZE_M;
    smem_b[1] = smem_b[0] + BLOCK_SIZE_K*BLOCK_SIZE_N;
    const int BCM = BLOCK_SIZE_M * blockIdx.y;
    const int BCN = BLOCK_SIZE_N * blockIdx.x;
    const int LDA = NoTransA ? K : M;
    const int LDB = NoTransB ? N : K;
    const int LDC = RowMajorC ? N : M;
    const int WCM = warp_id / WAPR_NUM_N;
    const int WCN = warp_id % WAPR_NUM_N;
    const int BLOCK_K_LOOP = K / BLOCK_SIZE_K;
    const int8_t* BA = A + BCM * LDA;
    const int8_t* BB = B + BCN * LDB;
    int* BC = C + BCM * LDC + BCN;
    int* BWC = BC + WCM * WARP_SIZE_M * LDC + WCN * WARP_SIZE_N;
    constexpr int WARP_M_LOOP = WARP_SIZE_M / TC_SIZE;
    constexpr int WARP_N_LOOP = WARP_SIZE_N / TC_SIZE;
    constexpr int WARP_K_LOOP = BLOCK_SIZE_K / TC_SIZE;
    wmma::fragment<wmma::matrix_a, TC_SIZE, TC_SIZE, TC_SIZE, int8_t, wmma::row_major> frag_a[WARP_M_LOOP][WARP_K_LOOP];
    wmma::fragment<wmma::matrix_b, TC_SIZE, TC_SIZE, TC_SIZE, int8_t, wmma::col_major> frag_b[WARP_K_LOOP][WARP_N_LOOP];
    wmma::fragment<wmma::accumulator, TC_SIZE, TC_SIZE, TC_SIZE, int> frag_c[WARP_M_LOOP][WARP_N_LOOP];
    #pragma unroll
    for (int i = 0; i < WARP_M_LOOP; i++) {
        #pragma unroll
        for (int j = 0; j < WARP_N_LOOP; j++) {
            wmma::fill_fragment(frag_c[i][j], 0);
        }
    }
    // Copy mapping: 32 lanes = 16 rows x 2 lanes, each lane moving 16 bytes.
    const int lane_id_x = lane_id % WARP_SIZE_X; // [0,2)
    const int lane_id_y = lane_id / WARP_SIZE_X; // [0,16)
    for(int k=0; k<BLOCK_K_LOOP; k++){
        int8_t* tile_a = smem_a[k%2];
        int8_t* tile_b = smem_b[k%2];
        // Stage K-tile k. Warps stride over the 16-row slabs of the block tile,
        // so every slab is copied exactly once whatever the warp count.
        #pragma unroll
        for(int j = 0; j < BLOCK_SIZE_K/ROW_CHUNK_BYTES; j++){
            const int col = j*ROW_CHUNK_BYTES + lane_id_x*CP_SIZE_BYTES;
            #pragma unroll
            for(int i=warp_id; i<(BLOCK_SIZE_M/TC_SIZE); i+=WAPR_NUM){
                const int8_t* load_gmem_addr_a = BA + (i*TC_SIZE + lane_id_y) * LDA + k*BLOCK_SIZE_K + col;
                int store_smem_addr_a = __cvta_generic_to_shared(tile_a + (i*TC_SIZE + lane_id_y)*BLOCK_SIZE_K + col);
                asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" :: "r"(store_smem_addr_a), "l"(load_gmem_addr_a), "n"(CP_SIZE_BYTES));
            }
            #pragma unroll
            for(int i=warp_id; i<(BLOCK_SIZE_N/TC_SIZE); i+=WAPR_NUM){
                const int8_t* load_gmem_addr_b = BB + (i*TC_SIZE + lane_id_y) * LDB + k*BLOCK_SIZE_K + col;
                int store_smem_addr_b = __cvta_generic_to_shared(tile_b + (i*TC_SIZE + lane_id_y)*BLOCK_SIZE_K + col);
                asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" :: "r"(store_smem_addr_b), "l"(load_gmem_addr_b), "n"(CP_SIZE_BYTES));
            }
        }
        // volatile: NVCC may otherwise drop or move asm that has no outputs.
        asm volatile("cp.async.commit_group;\n" ::);
        asm volatile("cp.async.wait_group 0;\n" ::: "memory");
        // Publish tile k to all warps. Because of double buffering, the buffer
        // written in iteration k+1 was last read in iteration k-1, which every
        // warp finished before reaching this barrier in iteration k.
        __syncthreads();
        for(int ki=0; ki<WARP_K_LOOP; ki++)
            for(int yi=0; yi<WARP_M_LOOP; yi++){
                wmma::load_matrix_sync(frag_a[yi][ki], &tile_a[(WCM*WARP_SIZE_M+yi*TC_SIZE)*BLOCK_SIZE_K+ki*TC_SIZE], BLOCK_SIZE_K);
                for(int xi=0; xi<WARP_N_LOOP; xi++){
                    wmma::load_matrix_sync(frag_b[ki][xi], &tile_b[(WCN*WARP_SIZE_N+xi*TC_SIZE)*BLOCK_SIZE_K+ki*TC_SIZE], BLOCK_SIZE_K);
                    wmma::mma_sync(frag_c[yi][xi], frag_a[yi][ki], frag_b[ki][xi], frag_c[yi][xi]);
                }
            }
    }
    #pragma unroll
    for(int yi=0; yi<WARP_M_LOOP; yi++){
        #pragma unroll
        for(int xi=0; xi<WARP_N_LOOP; xi++){
            wmma::store_matrix_sync(BWC + (yi*TC_SIZE)*LDC + xi*TC_SIZE, frag_c[yi][xi], LDC, wmma::mem_row_major);
        }
    }
}

// INT8 x INT8 -> INT8 tensor-core GEMM: same operand layouts, launch contract
// and shared-memory footprint as the int32-output overload above. The int32
// accumulators are narrowed to int8 on store.
//
// The copy of K-tile k is issued before the MMAs on tile k-1, so copy and
// compute overlap (two buffers per operand).
// Extra preconditions of the register-direct epilogue: N even and C 2-byte
// aligned (char2 stores).
template <int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K, int WARP_SIZE_M, int WARP_SIZE_N, int STAGE, bool NoTransA, bool NoTransB, bool RowMajorC>
__global__ void GEMMI8TCU(const int8_t* A, const int8_t* B, int8_t* C, int M, int N, int K)
{
    constexpr int WARP_SIZE = 32;
    constexpr int TC_SIZE = 16;                                  // WMMA tile edge
    constexpr int CP_SIZE_BYTES = 16;                            // bytes moved by one cp.async
    constexpr int WARP_SIZE_X = 2;                               // lanes sharing one tile row
    constexpr int ROW_CHUNK_BYTES = CP_SIZE_BYTES * WARP_SIZE_X; // row bytes covered per warp pass
    constexpr int WAPR_NUM_N = BLOCK_SIZE_N / WARP_SIZE_N;
    constexpr int WAPR_NUM_M = BLOCK_SIZE_M / WARP_SIZE_M;
    constexpr int WAPR_NUM = WAPR_NUM_M * WAPR_NUM_N;
    // Two buffers per operand are always used; with STAGE < 2 the second A
    // buffer would alias the first B buffer.
    constexpr int SMEM_STAGES = STAGE < 2 ? 2 : STAGE;
    static_assert(NoTransA == GEMM_OP_T, "NoTransA == GEMM_OP_T");
    static_assert(NoTransB == GEMM_OP_N, "NoTransB == GEMM_OP_N");
    static_assert(RowMajorC == GEMM_OP_T, "RowMajorC == GEMM_OP_T");
    static_assert(BLOCK_SIZE_M % WARP_SIZE_M == 0 && BLOCK_SIZE_N % WARP_SIZE_N == 0,
                  "block tile must be a whole number of warp tiles");
    static_assert(WARP_SIZE_M % TC_SIZE == 0 && WARP_SIZE_N % TC_SIZE == 0,
                  "warp tile must be a whole number of 16x16 WMMA tiles");
    static_assert(BLOCK_SIZE_K % ROW_CHUNK_BYTES == 0,
                  "BLOCK_SIZE_K must be a multiple of 32");
    int warp_id = threadIdx.x/WARP_SIZE;
    int lane_id = threadIdx.x%WARP_SIZE;
    // Explicit alignment for the cp.async destinations and WMMA tile pointers.
    __shared__ __align__(128) int8_t SLB[SMEM_STAGES * (BLOCK_SIZE_K*BLOCK_SIZE_M + BLOCK_SIZE_K*BLOCK_SIZE_N)];
    int8_t* smem_a[2];
    int8_t* smem_b[2];
    smem_a[0] = SLB;
    smem_a[1] = SLB + BLOCK_SIZE_K*BLOCK_SIZE_M;
    smem_b[0] = SLB + SMEM_STAGES*BLOCK_SIZE_K*BLOCK_SIZE_M;
    smem_b[1] = smem_b[0] + BLOCK_SIZE_K*BLOCK_SIZE_N;
    const int BCM = BLOCK_SIZE_M * blockIdx.y;
    const int BCN = BLOCK_SIZE_N * blockIdx.x;
    const int LDA = NoTransA ? K : M;
    const int LDB = NoTransB ? N : K;
    const int LDC = RowMajorC ? N : M;
    const int WCM = warp_id / WAPR_NUM_N;
    const int WCN = warp_id % WAPR_NUM_N;
    const int BLOCK_K_LOOP = K / BLOCK_SIZE_K;
    const int8_t* BA = A + BCM * LDA;
    const int8_t* BB = B + BCN * LDB;
    int8_t* BC = C + BCM * LDC + BCN;
    int8_t* BWC = BC + WCM * WARP_SIZE_M * LDC + WCN * WARP_SIZE_N;
    constexpr int WARP_M_LOOP = WARP_SIZE_M / TC_SIZE;
    constexpr int WARP_N_LOOP = WARP_SIZE_N / TC_SIZE;
    constexpr int WARP_K_LOOP = BLOCK_SIZE_K / TC_SIZE;
    wmma::fragment<wmma::matrix_a, TC_SIZE, TC_SIZE, TC_SIZE, int8_t, wmma::row_major> frag_a[WARP_M_LOOP][WARP_K_LOOP];
    wmma::fragment<wmma::matrix_b, TC_SIZE, TC_SIZE, TC_SIZE, int8_t, wmma::col_major> frag_b[WARP_K_LOOP][WARP_N_LOOP];
    wmma::fragment<wmma::accumulator, TC_SIZE, TC_SIZE, TC_SIZE, int> frag_c[WARP_M_LOOP][WARP_N_LOOP];
    #pragma unroll
    for (int i = 0; i < WARP_M_LOOP; i++) {
        #pragma unroll
        for (int j = 0; j < WARP_N_LOOP; j++) {
            wmma::fill_fragment(frag_c[i][j], 0);
        }
    }
    // Copy mapping: 32 lanes = 16 rows x 2 lanes, each lane moving 16 bytes.
    const int lane_id_x = lane_id % WARP_SIZE_X; // [0,2)
    const int lane_id_y = lane_id / WARP_SIZE_X; // [0,16)
    const int8_t* load_gmem_addr_a, *load_gmem_addr_b;
    int store_smem_addr_a, store_smem_addr_b;
    // BLOCK_K_LOOP depends only on K, so it is uniform across the block and the
    // barriers inside this branch are reached by all threads or by none.
    // Skipping the pipeline when K < BLOCK_SIZE_K avoids indexing
    // smem_a[(0 - 1) % 2] and reading past the rows of A/B; the accumulators then
    // stay zero, consistent with K being truncated to whole K-tiles.
    if (BLOCK_K_LOOP > 0) {
        int k = 0;
        // Prologue: start copying K-tile 0 into buffer 0.
        #pragma unroll
        for(int j = 0; j < BLOCK_SIZE_K/ROW_CHUNK_BYTES; j++){
            #pragma unroll
            for(int i=warp_id; i<(BLOCK_SIZE_M/TC_SIZE); i+=WAPR_NUM)
            {
                load_gmem_addr_a = BA + (i*TC_SIZE + lane_id_y) * LDA + k*BLOCK_SIZE_K + j*ROW_CHUNK_BYTES + lane_id_x*CP_SIZE_BYTES;
                store_smem_addr_a = __cvta_generic_to_shared(smem_a[k%2] + (i*TC_SIZE + lane_id_y)*BLOCK_SIZE_K + j*ROW_CHUNK_BYTES + lane_id_x*CP_SIZE_BYTES);
                asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" :: "r"(store_smem_addr_a), "l"(load_gmem_addr_a), "n"(CP_SIZE_BYTES));
            }
            #pragma unroll
            for(int i=warp_id; i<(BLOCK_SIZE_N/TC_SIZE); i+=WAPR_NUM)
            {
                load_gmem_addr_b = BB + (i*TC_SIZE + lane_id_y) * LDB + k*BLOCK_SIZE_K + j*ROW_CHUNK_BYTES + lane_id_x*CP_SIZE_BYTES;
                store_smem_addr_b = __cvta_generic_to_shared(smem_b[k%2] + (i*TC_SIZE + lane_id_y)*BLOCK_SIZE_K + j*ROW_CHUNK_BYTES + lane_id_x*CP_SIZE_BYTES);
                asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" :: "r"(store_smem_addr_b), "l"(load_gmem_addr_b), "n"(CP_SIZE_BYTES));
            }
        }
        // Steady state: wait for tile k-1, issue the copy of tile k into the
        // other buffer, then run the MMAs on tile k-1 while that copy is in flight.
        #pragma unroll
        for(k=1; k<BLOCK_K_LOOP; k++){
            // volatile: NVCC may otherwise drop or move asm that has no outputs.
            asm volatile("cp.async.commit_group;\n" ::);
            asm volatile("cp.async.wait_group 0;\n" ::: "memory");
            // Also guarantees every warp finished reading buffer k%2 (used for
            // tile k-2) before it is overwritten below.
            __syncthreads();
            #pragma unroll
            for(int j = 0; j < BLOCK_SIZE_K/ROW_CHUNK_BYTES; j++){
                #pragma unroll
                for(int i=warp_id; i<(BLOCK_SIZE_M/TC_SIZE); i+=WAPR_NUM)
                {
                    load_gmem_addr_a = BA + (i*TC_SIZE + lane_id_y) * LDA + k*BLOCK_SIZE_K + j*ROW_CHUNK_BYTES + lane_id_x*CP_SIZE_BYTES;
                    store_smem_addr_a = __cvta_generic_to_shared(smem_a[k%2] + (i*TC_SIZE + lane_id_y)*BLOCK_SIZE_K + j*ROW_CHUNK_BYTES + lane_id_x*CP_SIZE_BYTES);
                    asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" :: "r"(store_smem_addr_a), "l"(load_gmem_addr_a), "n"(CP_SIZE_BYTES));
                }
                #pragma unroll
                for(int i=warp_id; i<(BLOCK_SIZE_N/TC_SIZE); i+=WAPR_NUM)
                {
                    load_gmem_addr_b = BB + (i*TC_SIZE + lane_id_y) * LDB + k*BLOCK_SIZE_K + j*ROW_CHUNK_BYTES + lane_id_x*CP_SIZE_BYTES;
                    store_smem_addr_b = __cvta_generic_to_shared(smem_b[k%2] + (i*TC_SIZE + lane_id_y)*BLOCK_SIZE_K + j*ROW_CHUNK_BYTES + lane_id_x*CP_SIZE_BYTES);
                    asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" :: "r"(store_smem_addr_b), "l"(load_gmem_addr_b), "n"(CP_SIZE_BYTES));
                }
            }
            #pragma unroll
            for(int yi=0; yi<WARP_M_LOOP; yi++){
                #pragma unroll
                for(int ki=0; ki<WARP_K_LOOP; ki++){
                    wmma::load_matrix_sync(frag_a[yi][ki], &smem_a[(k-1)%2][(WCM*WARP_SIZE_M+yi*TC_SIZE)*BLOCK_SIZE_K+ki*TC_SIZE], BLOCK_SIZE_K);
                }
            }
            #pragma unroll
            for(int ki=0; ki<WARP_K_LOOP; ki++){
                #pragma unroll
                for(int xi=0; xi<WARP_N_LOOP; xi++){
                    wmma::load_matrix_sync(frag_b[ki][xi], &smem_b[(k-1)%2][(WCN*WARP_SIZE_N+xi*TC_SIZE)*BLOCK_SIZE_K+ki*TC_SIZE], BLOCK_SIZE_K);
                }
            }
            #pragma unroll
            for(int ki=0; ki<WARP_K_LOOP; ki++)
                #pragma unroll
                for(int yi=0; yi<WARP_M_LOOP; yi++){
                    #pragma unroll
                    for(int xi=0; xi<WARP_N_LOOP; xi++){
                        wmma::mma_sync(frag_c[yi][xi], frag_a[yi][ki], frag_b[ki][xi], frag_c[yi][xi]);
                    }
                }
        }
        // Drain: wait for the last tile and consume it.
        asm volatile("cp.async.commit_group;\n" ::);
        asm volatile("cp.async.wait_group 0;\n" ::: "memory");
        __syncthreads();
        k = BLOCK_K_LOOP - 1;
        #pragma unroll
        for(int yi=0; yi<WARP_M_LOOP; yi++){
            #pragma unroll
            for(int ki=0; ki<WARP_K_LOOP; ki++){
                wmma::load_matrix_sync(frag_a[yi][ki], &smem_a[k%2][(WCM*WARP_SIZE_M+yi*TC_SIZE)*BLOCK_SIZE_K+ki*TC_SIZE], BLOCK_SIZE_K);
            }
        }
        #pragma unroll
        for(int ki=0; ki<WARP_K_LOOP; ki++){
            #pragma unroll
            for(int xi=0; xi<WARP_N_LOOP; xi++){
                wmma::load_matrix_sync(frag_b[ki][xi], &smem_b[k%2][(WCN*WARP_SIZE_N+xi*TC_SIZE)*BLOCK_SIZE_K+ki*TC_SIZE], BLOCK_SIZE_K);
            }
        }
        #pragma unroll
        for(int ki=0; ki<WARP_K_LOOP; ki++)
            #pragma unroll
            for(int yi=0; yi<WARP_M_LOOP; yi++){
                #pragma unroll
                for(int xi=0; xi<WARP_N_LOOP; xi++){
                    wmma::mma_sync(frag_c[yi][xi], frag_a[yi][ki], frag_b[ki][xi], frag_c[yi][xi]);
                }
            }
    }
    // Epilogue: store straight from the accumulator registers as int8.
    // The code treats frag_c.x[tc_xi*4 + tc_yi*2 + {0,1}] as the element pair at
    // row tc_yi*8 + lane/4, columns tc_xi*8 + (lane%4)*2 + {0,1} of the 16x16 tile.
    // NOTE(review): WMMA does not document fragment layouts; this assumes the
    // SM80 mma.m16n8k16 accumulator layout — confirm on other architectures.
    // static_cast<int8_t> is a plain narrowing conversion (no clamping).
    const int gmem_lane_id_x = lane_id % 4; // [0,4)
    const int gmem_lane_id_y = lane_id / 4; // [0,8)
    #pragma unroll
    for(int yi=0; yi<WARP_M_LOOP; yi++)
        #pragma unroll
        for(int xi=0; xi<WARP_N_LOOP; xi++)
        {
            for(int tc_yi=0; tc_yi<2; tc_yi++){
                for(int tc_xi=0; tc_xi<2; tc_xi++){
                    auto* store_gmem_addr = reinterpret_cast<char2*>(BWC + (yi*TC_SIZE + tc_yi*TC_SIZE/2 + gmem_lane_id_y) * LDC + xi*TC_SIZE + tc_xi*TC_SIZE/2 + gmem_lane_id_x*2);
                    char2 tmp_char2;
                    tmp_char2.x = static_cast<int8_t>(frag_c[yi][xi].x[tc_xi*4+tc_yi*2+0]);
                    tmp_char2.y = static_cast<int8_t>(frag_c[yi][xi].x[tc_xi*4+tc_yi*2+1]);
                    *store_gmem_addr = tmp_char2;
                }
            }
        }
}

}
namespace cutlass_kernel{

// INT8 x INT8 -> INT8 tensor-core GEMM (WMMA m16n16k16): C = A * B.
// NOTE(review): despite the namespace name, no CUTLASS code is used. This
// variant differs from a fully-cached version by reloading a single A and a
// single B fragment inside the MMA loop, which keeps fewer fragments live.
//
// Operand layouts (pinned by the static_asserts below):
//   A: row-major M x K, leading dimension K.
//   B: N rows of K bytes (leading dimension K), consumed by WMMA as col-major K x N.
//   C: row-major M x N int8, leading dimension N.
//
// Launch contract (NOT checked at runtime):
//   * blockIdx.x selects a BLOCK_SIZE_N-wide column tile of C, blockIdx.y a
//     BLOCK_SIZE_M-tall row tile;
//   * blockDim.x must be exactly 32 * (BLOCK_SIZE_M/WARP_SIZE_M)*(BLOCK_SIZE_N/WARP_SIZE_N);
//     each warp owns one WARP_SIZE_M x WARP_SIZE_N sub-tile of C;
//   * there are no bounds checks: M and N must be multiples of the block tile,
//     and only the first (K / BLOCK_SIZE_K) * BLOCK_SIZE_K columns of K are
//     accumulated;
//   * every 16-byte cp.async source must be 16-byte aligned (A/B base pointers
//     16-byte aligned and K % 16 == 0); N even and C 2-byte aligned for the
//     char2 stores.
// Uses cp.async (SM80+). Shared memory: max(STAGE, 2) * BLOCK_SIZE_K *
// (BLOCK_SIZE_M + BLOCK_SIZE_N) bytes, of which two buffers per operand are used;
// the copy of K-tile k overlaps the MMAs on K-tile k-1.
template <int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K, int WARP_SIZE_M, int WARP_SIZE_N, int STAGE, bool NoTransA, bool NoTransB, bool RowMajorC>
__global__ void GEMMI8TCU(const int8_t* A, const int8_t* B, int8_t* C, int M, int N, int K)
{
    constexpr int WARP_SIZE = 32;
    constexpr int TC_SIZE = 16;                                  // WMMA tile edge
    constexpr int CP_SIZE_BYTES = 16;                            // bytes moved by one cp.async
    constexpr int WARP_SIZE_X = 2;                               // lanes sharing one tile row
    constexpr int ROW_CHUNK_BYTES = CP_SIZE_BYTES * WARP_SIZE_X; // row bytes covered per warp pass
    constexpr int WAPR_NUM_N = BLOCK_SIZE_N / WARP_SIZE_N;
    constexpr int WAPR_NUM_M = BLOCK_SIZE_M / WARP_SIZE_M;
    constexpr int WAPR_NUM = WAPR_NUM_M * WAPR_NUM_N;
    // Two buffers per operand are always used; with STAGE < 2 the second A
    // buffer would alias the first B buffer.
    constexpr int SMEM_STAGES = STAGE < 2 ? 2 : STAGE;
    static_assert(NoTransA == GEMM_OP_T, "NoTransA == GEMM_OP_T");
    static_assert(NoTransB == GEMM_OP_N, "NoTransB == GEMM_OP_N");
    static_assert(RowMajorC == GEMM_OP_T, "RowMajorC == GEMM_OP_T");
    static_assert(BLOCK_SIZE_M % WARP_SIZE_M == 0 && BLOCK_SIZE_N % WARP_SIZE_N == 0,
                  "block tile must be a whole number of warp tiles");
    static_assert(WARP_SIZE_M % TC_SIZE == 0 && WARP_SIZE_N % TC_SIZE == 0,
                  "warp tile must be a whole number of 16x16 WMMA tiles");
    static_assert(BLOCK_SIZE_K % ROW_CHUNK_BYTES == 0,
                  "BLOCK_SIZE_K must be a multiple of 32");
    int warp_id = threadIdx.x/WARP_SIZE;
    int lane_id = threadIdx.x%WARP_SIZE;
    // Explicit alignment: int8_t alone only guarantees 1-byte alignment, while
    // the cp.async destinations and WMMA tile pointers below need more.
    __shared__ __align__(128) int8_t SLB[SMEM_STAGES * (BLOCK_SIZE_K*BLOCK_SIZE_M + BLOCK_SIZE_K*BLOCK_SIZE_N)];
    int8_t* smem_a[2];
    int8_t* smem_b[2];
    smem_a[0] = SLB;
    smem_a[1] = SLB + BLOCK_SIZE_K*BLOCK_SIZE_M;
    smem_b[0] = SLB + SMEM_STAGES*BLOCK_SIZE_K*BLOCK_SIZE_M;
    smem_b[1] = smem_b[0] + BLOCK_SIZE_K*BLOCK_SIZE_N;
    const int BCM = BLOCK_SIZE_M * blockIdx.y;
    const int BCN = BLOCK_SIZE_N * blockIdx.x;
    const int LDA = NoTransA ? K : M;
    const int LDB = NoTransB ? N : K;
    const int LDC = RowMajorC ? N : M;
    const int WCM = warp_id / WAPR_NUM_N;
    const int WCN = warp_id % WAPR_NUM_N;
    const int BLOCK_K_LOOP = K / BLOCK_SIZE_K;
    const int8_t* BA = A + BCM * LDA;
    const int8_t* BB = B + BCN * LDB;
    int8_t* BC = C + BCM * LDC + BCN;
    int8_t* BWC = BC + WCM * WARP_SIZE_M * LDC + WCN * WARP_SIZE_N;
    constexpr int WARP_M_LOOP = WARP_SIZE_M / TC_SIZE;
    constexpr int WARP_N_LOOP = WARP_SIZE_N / TC_SIZE;
    constexpr int WARP_K_LOOP = BLOCK_SIZE_K / TC_SIZE;
    wmma::fragment<wmma::matrix_a, TC_SIZE, TC_SIZE, TC_SIZE, int8_t, wmma::row_major> frag_a;
    wmma::fragment<wmma::matrix_b, TC_SIZE, TC_SIZE, TC_SIZE, int8_t, wmma::col_major> frag_b;
    wmma::fragment<wmma::accumulator, TC_SIZE, TC_SIZE, TC_SIZE, int> frag_c[WARP_M_LOOP][WARP_N_LOOP];
    #pragma unroll
    for (int i = 0; i < WARP_M_LOOP; i++) {
        #pragma unroll
        for (int j = 0; j < WARP_N_LOOP; j++) {
            wmma::fill_fragment(frag_c[i][j], 0);
        }
    }
    // Copy mapping: 32 lanes = 16 rows x 2 lanes, each lane moving 16 bytes.
    const int lane_id_x = lane_id % WARP_SIZE_X; // [0,2)
    const int lane_id_y = lane_id / WARP_SIZE_X; // [0,16)
    const int8_t* load_gmem_addr_a, *load_gmem_addr_b;
    int store_smem_addr_a, store_smem_addr_b;
    // BLOCK_K_LOOP depends only on K, so it is uniform across the block and the
    // barriers inside this branch are reached by all threads or by none.
    // Skipping the pipeline when K < BLOCK_SIZE_K avoids indexing
    // smem_a[(0 - 1) % 2] and reading past the rows of A/B; the accumulators then
    // stay zero, consistent with K being truncated to whole K-tiles.
    if (BLOCK_K_LOOP > 0) {
        int k = 0;
        // Prologue: start copying K-tile 0 into buffer 0.
        #pragma unroll
        for(int j = 0; j < BLOCK_SIZE_K/ROW_CHUNK_BYTES; j++){
            #pragma unroll
            for(int i=warp_id; i<(BLOCK_SIZE_M/TC_SIZE); i+=WAPR_NUM)
            {
                load_gmem_addr_a = BA + (i*TC_SIZE + lane_id_y) * LDA + k*BLOCK_SIZE_K + j*ROW_CHUNK_BYTES + lane_id_x*CP_SIZE_BYTES;
                store_smem_addr_a = __cvta_generic_to_shared(smem_a[k%2] + (i*TC_SIZE + lane_id_y)*BLOCK_SIZE_K + j*ROW_CHUNK_BYTES + lane_id_x*CP_SIZE_BYTES);
                asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" :: "r"(store_smem_addr_a), "l"(load_gmem_addr_a), "n"(CP_SIZE_BYTES));
            }
            #pragma unroll
            for(int i=warp_id; i<(BLOCK_SIZE_N/TC_SIZE); i+=WAPR_NUM)
            {
                load_gmem_addr_b = BB + (i*TC_SIZE + lane_id_y) * LDB + k*BLOCK_SIZE_K + j*ROW_CHUNK_BYTES + lane_id_x*CP_SIZE_BYTES;
                store_smem_addr_b = __cvta_generic_to_shared(smem_b[k%2] + (i*TC_SIZE + lane_id_y)*BLOCK_SIZE_K + j*ROW_CHUNK_BYTES + lane_id_x*CP_SIZE_BYTES);
                asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" :: "r"(store_smem_addr_b), "l"(load_gmem_addr_b), "n"(CP_SIZE_BYTES));
            }
        }
        // Steady state: wait for tile k-1, issue the copy of tile k into the
        // other buffer, then run the MMAs on tile k-1 while that copy is in flight.
        #pragma unroll
        for(k=1; k<BLOCK_K_LOOP; k++){
            // volatile: NVCC may otherwise drop or move asm that has no outputs.
            asm volatile("cp.async.commit_group;\n" ::);
            asm volatile("cp.async.wait_group 0;\n" ::: "memory");
            // Also guarantees every warp finished reading buffer k%2 (used for
            // tile k-2) before it is overwritten below.
            __syncthreads();
            #pragma unroll
            for(int j = 0; j < BLOCK_SIZE_K/ROW_CHUNK_BYTES; j++){
                #pragma unroll
                for(int i=warp_id; i<(BLOCK_SIZE_M/TC_SIZE); i+=WAPR_NUM)
                {
                    load_gmem_addr_a = BA + (i*TC_SIZE + lane_id_y) * LDA + k*BLOCK_SIZE_K + j*ROW_CHUNK_BYTES + lane_id_x*CP_SIZE_BYTES;
                    store_smem_addr_a = __cvta_generic_to_shared(smem_a[k%2] + (i*TC_SIZE + lane_id_y)*BLOCK_SIZE_K + j*ROW_CHUNK_BYTES + lane_id_x*CP_SIZE_BYTES);
                    asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" :: "r"(store_smem_addr_a), "l"(load_gmem_addr_a), "n"(CP_SIZE_BYTES));
                }
                #pragma unroll
                for(int i=warp_id; i<(BLOCK_SIZE_N/TC_SIZE); i+=WAPR_NUM)
                {
                    load_gmem_addr_b = BB + (i*TC_SIZE + lane_id_y) * LDB + k*BLOCK_SIZE_K + j*ROW_CHUNK_BYTES + lane_id_x*CP_SIZE_BYTES;
                    store_smem_addr_b = __cvta_generic_to_shared(smem_b[k%2] + (i*TC_SIZE + lane_id_y)*BLOCK_SIZE_K + j*ROW_CHUNK_BYTES + lane_id_x*CP_SIZE_BYTES);
                    asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" :: "r"(store_smem_addr_b), "l"(load_gmem_addr_b), "n"(CP_SIZE_BYTES));
                }
            }
            #pragma unroll
            for(int ki=0; ki<WARP_K_LOOP; ki++)
                #pragma unroll
                for(int yi=0; yi<WARP_M_LOOP; yi++){
                    wmma::load_matrix_sync(frag_a, &smem_a[(k-1)%2][(WCM*WARP_SIZE_M+yi*TC_SIZE)*BLOCK_SIZE_K+ki*TC_SIZE], BLOCK_SIZE_K);
                    #pragma unroll
                    for(int xi=0; xi<WARP_N_LOOP; xi++){
                        wmma::load_matrix_sync(frag_b, &smem_b[(k-1)%2][(WCN*WARP_SIZE_N+xi*TC_SIZE)*BLOCK_SIZE_K+ki*TC_SIZE], BLOCK_SIZE_K);
                        wmma::mma_sync(frag_c[yi][xi], frag_a, frag_b, frag_c[yi][xi]);
                    }
                }
        }
        // Drain: wait for the last tile and consume it.
        asm volatile("cp.async.commit_group;\n" ::);
        asm volatile("cp.async.wait_group 0;\n" ::: "memory");
        __syncthreads();
        k = BLOCK_K_LOOP - 1;
        #pragma unroll
        for(int ki=0; ki<WARP_K_LOOP; ki++)
            #pragma unroll
            for(int yi=0; yi<WARP_M_LOOP; yi++){
                wmma::load_matrix_sync(frag_a, &smem_a[k%2][(WCM*WARP_SIZE_M+yi*TC_SIZE)*BLOCK_SIZE_K+ki*TC_SIZE], BLOCK_SIZE_K);
                #pragma unroll
                for(int xi=0; xi<WARP_N_LOOP; xi++){
                    wmma::load_matrix_sync(frag_b, &smem_b[k%2][(WCN*WARP_SIZE_N+xi*TC_SIZE)*BLOCK_SIZE_K+ki*TC_SIZE], BLOCK_SIZE_K);
                    wmma::mma_sync(frag_c[yi][xi], frag_a, frag_b, frag_c[yi][xi]);
                }
            }
    }
    // Epilogue: store straight from the accumulator registers as int8.
    // The code treats frag_c.x[tc_xi*4 + tc_yi*2 + {0,1}] as the element pair at
    // row tc_yi*8 + lane/4, columns tc_xi*8 + (lane%4)*2 + {0,1} of the 16x16 tile.
    // NOTE(review): WMMA does not document fragment layouts; this assumes the
    // SM80 mma.m16n8k16 accumulator layout — confirm on other architectures.
    // static_cast<int8_t> is a plain narrowing conversion (no clamping).
    const int gmem_lane_id_x = lane_id % 4; // [0,4)
    const int gmem_lane_id_y = lane_id / 4; // [0,8)
    #pragma unroll
    for(int yi=0; yi<WARP_M_LOOP; yi++)
        #pragma unroll
        for(int xi=0; xi<WARP_N_LOOP; xi++)
        {
            for(int tc_yi=0; tc_yi<2; tc_yi++){
                for(int tc_xi=0; tc_xi<2; tc_xi++){
                    auto* store_gmem_addr = reinterpret_cast<char2*>(BWC + (yi*TC_SIZE + tc_yi*TC_SIZE/2 + gmem_lane_id_y) * LDC + xi*TC_SIZE + tc_xi*TC_SIZE/2 + gmem_lane_id_x*2);
                    char2 tmp_char2;
                    tmp_char2.x = static_cast<int8_t>(frag_c[yi][xi].x[tc_xi*4+tc_yi*2+0]);
                    tmp_char2.y = static_cast<int8_t>(frag_c[yi][xi].x[tc_xi*4+tc_yi*2+1]);
                    *store_gmem_addr = tmp_char2;
                }
            }
        }
}

}