@@ -6318,18 +6318,19 @@ kernel void kernel_mul_mm(device const uchar * src0,
6318
6318
const uint im = tgpig.z ;
6319
6319
6320
6320
// if this block is of 64x32 shape or smaller
6321
- short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
6322
- short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
6321
+ short n_rows = (ne0 - r0* BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0* BLOCK_SIZE_M) : BLOCK_SIZE_M;
6322
+ short n_cols = (ne1 - r1* BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1* BLOCK_SIZE_N) : BLOCK_SIZE_N;
6323
6323
6324
6324
// a thread shouldn't load data outside of the matrix
6325
6325
short thread_row = ((short )tiitg/THREAD_PER_ROW) < n_rows ? ((short )tiitg/THREAD_PER_ROW) : n_rows - 1 ;
6326
6326
short thread_col = ((short )tiitg/THREAD_PER_COL) < n_cols ? ((short )tiitg/THREAD_PER_COL) : n_cols - 1 ;
6327
6327
6328
6328
simdgroup_T8x8 ma[4 ];
6329
6329
simdgroup_float8x8 mb[2 ];
6330
- simdgroup_float8x8 c_res[8 ];
6331
- for (int i = 0 ; i < 8 ; i++){
6332
- c_res[i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
6330
+ simdgroup_float8x8 mc[8 ];
6331
+
6332
+ for (short i = 0 ; i < 8 ; i++){
6333
+ mc[i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
6333
6334
}
6334
6335
6335
6336
short il = (tiitg % THREAD_PER_ROW);
@@ -6340,7 +6341,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
6340
6341
uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03;
6341
6342
ushort offset1 = il/nl;
6342
6343
6343
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
6344
+ device const block_q * x = (device const block_q *)(src0 + (r0* BLOCK_SIZE_M + thread_row)* nb01 + offset0) + offset1;
6344
6345
device const float * y = (device const float *)(src1
6345
6346
+ nb13 * i13
6346
6347
+ nb12 * i12
@@ -6354,13 +6355,13 @@ kernel void kernel_mul_mm(device const uchar * src0,
6354
6355
threadgroup_barrier (mem_flags::mem_threadgroup);
6355
6356
6356
6357
#pragma unroll(16)
6357
- for (int i = 0 ; i < 16 ; i++) {
6358
- *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8 ) \
6359
- + (tiitg % THREAD_PER_ROW) * 16 + (i / 8 ) * 8 ) \
6360
- + (tiitg / THREAD_PER_ROW) % 8 + (i & 7 ) * 8 ) = temp_a[i/4 ][i%4 ];
6358
+ for (short i = 0 ; i < 16 ; i++) {
6359
+ *(sa + SG_MAT_SIZE * ((tiitg/ THREAD_PER_ROW/ 8 ) \
6360
+ + (tiitg% THREAD_PER_ROW)* 16 + (i/ 8 )* 8 ) \
6361
+ + (tiitg/ THREAD_PER_ROW)% 8 + (i& 7 )* 8 ) = temp_a[i/4 ][i%4 ];
6361
6362
}
6362
6363
6363
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
6364
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL)* 8 * 32 + 8 * (tiitg/ THREAD_PER_COL)) = *((device float2x4 *) y);
6364
6365
6365
6366
il = (il + 2 < nl) ? il + 2 : il % 2 ;
6366
6367
x = (il < 2 ) ? x + (2 +nl-1 )/nl : x;
@@ -6369,53 +6370,64 @@ kernel void kernel_mul_mm(device const uchar * src0,
6369
6370
threadgroup_barrier (mem_flags::mem_threadgroup);
6370
6371
6371
6372
// load matrices from threadgroup memory and conduct outer products
6372
- threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2 ));
6373
- threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2 ));
6373
+ threadgroup T * lsma = (sa + THREAD_MAT_M* SG_MAT_SIZE* (sgitg% 2 ));
6374
+ threadgroup float * lsmb = (sb + THREAD_MAT_N* SG_MAT_SIZE* (sgitg/ 2 ));
6374
6375
6375
6376
#pragma unroll(4)
6376
- for (int ik = 0 ; ik < BLOCK_SIZE_K / 8 ; ik++) {
6377
+ for (short ik = 0 ; ik < BLOCK_SIZE_K / 8 ; ik++) {
6377
6378
#pragma unroll(4)
6378
- for (int i = 0 ; i < 4 ; i++) {
6379
- simdgroup_load (ma[i],lsma + SG_MAT_SIZE * i);
6379
+ for (short i = 0 ; i < 4 ; i++) {
6380
+ simdgroup_load (ma[i], lsma + SG_MAT_SIZE * i);
6380
6381
}
6381
6382
simdgroup_barrier (mem_flags::mem_none);
6382
6383
#pragma unroll(2)
6383
- for (int i = 0 ; i < 2 ; i++) {
6384
- simdgroup_load (mb[i],lsmb + SG_MAT_SIZE * i);
6384
+ for (short i = 0 ; i < 2 ; i++) {
6385
+ simdgroup_load (mb[i], lsmb + SG_MAT_SIZE * i);
6385
6386
}
6386
6387
6387
- lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
6388
- lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
6388
+ lsma += BLOCK_SIZE_M/ SG_MAT_ROW * SG_MAT_SIZE;
6389
+ lsmb += BLOCK_SIZE_N/ SG_MAT_ROW * SG_MAT_SIZE;
6389
6390
6390
6391
#pragma unroll(8)
6391
- for (int i = 0 ; i < 8 ; i++){
6392
- simdgroup_multiply_accumulate (c_res [i], mb[i/4 ], ma[i%4 ], c_res [i]);
6392
+ for (short i = 0 ; i < 8 ; i++){
6393
+ simdgroup_multiply_accumulate (mc [i], mb[i/4 ], ma[i%4 ], mc [i]);
6393
6394
}
6394
6395
}
6395
6396
}
6396
6397
6397
6398
if ((r0 + 1 ) * BLOCK_SIZE_M <= ne0 && (r1 + 1 ) * BLOCK_SIZE_N <= ne1) {
6398
6399
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1 )) \
6399
6400
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1 )) * ne0 + im*ne1*ne0;
6400
- for (int i = 0 ; i < 8 ; i++) {
6401
- simdgroup_store (c_res [i], C + 8 * (i%4 ) + 8 * ne0 * (i/4 ), ne0);
6401
+ for (short i = 0 ; i < 8 ; i++) {
6402
+ simdgroup_store (mc [i], C + 8 * (i%4 ) + 8 * ne0 * (i/4 ), ne0);
6402
6403
}
6403
6404
} else {
6404
6405
// block is smaller than 64x32, we should avoid writing data outside of the matrix
6405
6406
threadgroup_barrier (mem_flags::mem_threadgroup);
6406
- threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
6407
- + 32 * (sgitg&1 ) + (16 * (sgitg>>1 )) * BLOCK_SIZE_M;
6408
- for (int i = 0 ; i < 8 ; i++) {
6409
- simdgroup_store (c_res [i], temp_str + 8 * (i%4 ) + 8 * BLOCK_SIZE_M * (i/4 ), BLOCK_SIZE_M);
6407
+ threadgroup float * temp_str = ((threadgroup float *) shared_memory) \
6408
+ + 32 * (sgitg&1 ) + (16 * (sgitg>>1 ))* BLOCK_SIZE_M;
6409
+ for (short i = 0 ; i < 8 ; i++) {
6410
+ simdgroup_store (mc [i], temp_str + 8 * (i%4 ) + 8 * BLOCK_SIZE_M* (i/4 ), BLOCK_SIZE_M);
6410
6411
}
6411
6412
6412
6413
threadgroup_barrier (mem_flags::mem_threadgroup);
6413
6414
6414
- device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
6415
6415
if (sgitg == 0 ) {
6416
- for (int i = 0 ; i < n_rows; i++) {
6417
- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
6418
- *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
6416
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
6417
+ device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0;
6418
+ device float4 * D4 = (device float4 *) D;
6419
+
6420
+ threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
6421
+ threadgroup float4 * C4 = (threadgroup float4 *) C;
6422
+
6423
+ int i = 0 ;
6424
+ for (; i < n_rows/4 ; i++) {
6425
+ *(D4 + i) = *(C4 + i);
6426
+ }
6427
+
6428
+ i *= 4 ;
6429
+ for (; i < n_rows; i++) {
6430
+ *(D + i) = *(C + i);
6419
6431
}
6420
6432
}
6421
6433
}
0 commit comments