14
14
#define VEC4_T ${buffer_gvec_type(DTYPE, 4 )}
15
15
16
16
#define TILE_ROWS ${TILE_ROWS}
17
+ #define TILE_TXCOLS ${TILE_TXCOLS}
17
18
18
19
#define NGROUPS 8
19
20
#define NWORKERS 8
@@ -29,7 +30,10 @@ layout(std430) buffer;
29
30
30
31
${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array= False)}
31
32
${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array= False)}
32
- ${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array= False)}
33
+ $if QUANT_NBITS == 4 :
34
+ ${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array= False)}
35
+ $else :
36
+ ${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array= False)}
33
37
${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array= False)}
34
38
35
39
layout (push_constant) uniform restrict Block {
@@ -42,12 +46,23 @@ layout(push_constant) uniform restrict Block {
42
46
43
47
layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
44
48
45
- shared VEC4_T partial_c [NGROUPS][NWORKERS][TILE_ROWS];
49
+ shared VEC4_T partial_sums [NGROUPS][NWORKERS][TILE_ROWS][TILE_TXCOLS ];
46
50
47
51
void main() {
48
- const uint out_width_ntexels = divup4(out_sizes.x);
49
- const uint out_col = (gl_GlobalInvocationID.x % out_width_ntexels) << 2 ;
50
- const uint out_row = (gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS;
52
+ // txcol stands for "texel column". One txcol corresponds to 4 scalar columns.
53
+ $if TILE_TXCOLS > 1 :
54
+ const uint global_wg_x = uint (divup(out_sizes.x, 4 * TILE_TXCOLS));
55
+ const uint out_txcol = uint (
56
+ (gl_GlobalInvocationID.x % global_wg_x) * TILE_TXCOLS);
57
+ $else :
58
+ const uint global_wg_x = uint (divup4(out_sizes.x));
59
+ const uint out_txcol = uint (gl_GlobalInvocationID.x % global_wg_x);
60
+
61
+ const uint out_row = uint (
62
+ (gl_GlobalInvocationID.x / global_wg_x) * TILE_ROWS);
63
+
64
+ $if QUANT_NBITS == 4 :
65
+ const uint weight_txcol = uint (out_txcol / 2 );
51
66
52
67
const int gid = int (gl_LocalInvocationID.x); // group id
53
68
const int wid = int (gl_LocalInvocationID.z); // worker id
@@ -56,46 +71,78 @@ void main() {
56
71
return ;
57
72
}
58
73
59
- VEC4_T a [TILE_ROWS];
60
- VEC4_T b[ 4 ];
61
- VEC4_T local_c [TILE_ROWS];
74
+ VEC4_T mat1 [TILE_ROWS];
75
+ VEC4_T qmat2[ 4 ][TILE_TXCOLS ];
76
+ VEC4_T local_sums [TILE_ROWS][TILE_TXCOLS ];
62
77
63
- [[unroll]] for (int i = 0 ; i < TILE_ROWS; ++ i) {
64
- local_c[i] = VEC4_T(0.0 );
78
+ [[unroll]] for (int r = 0 ; r < TILE_ROWS; ++ r) {
79
+ $for c in range(TILE_TXCOLS):
80
+ local_sums[r][${c}] = VEC4_T(0.0 );
65
81
}
66
82
67
- $if SCALES_STORAGE == "buffer ":
68
- const VEC4_T scales = VEC4_T(t_scales[out_col >> 2 ]);
69
- $else :
70
- const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2 (out_col >> 2 , 0 ), 0 ));
71
-
72
- for (int pos = 4 * wid; pos < in_sizes.x; pos += (4 * NWORKERS)) {
73
- // Preload t_weight
74
- [[unroll]] for (int i = 0 ; i < 4 ; i++ ) {
75
- $if WEIGHT_STORAGE == "buffer ":
76
- b[i] = t_weight[((pos + i) * weight_sizes.x + out_col) >> 2 ];
83
+ VEC4_T scales[TILE_TXCOLS];
84
+ $for c in range(TILE_TXCOLS):
85
+ $if SCALES_STORAGE == "buffer ":
86
+ scales[${c}] = VEC4_T(t_scales[out_txcol + ${c}]);
87
+ $else :
88
+ scales[${c}] = VEC4_T(
89
+ texelFetch(t_scales, ivec2 (out_txcol + ${c}, 0 ), 0 ));
90
+
91
+ for (int pos = (4 * wid), txpos = wid;
92
+ pos < in_sizes.x;
93
+ pos += (4 * NWORKERS), txpos += NWORKERS) {
94
+ $if WEIGHT_STORAGE == "buffer ":
95
+ uint qmat2_bufi;
96
+ uint weight_row_txstride = div4(weight_sizes.x);
97
+
98
+ // Preload weight tensor
99
+ [[unroll]] for (int r = 0 ; r < 4 ; r++ ) {
100
+ $if QUANT_NBITS == 4 :
101
+ $for c in range(0 , TILE_TXCOLS, 2 ):
102
+ $if WEIGHT_STORAGE == "buffer ":
103
+ qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol;
104
+ const u8vec4 packed_weight_tex = t_weight[qmat2_bufi + ${c}]
105
+ $else :
106
+ const uvec4 packed_weight_tex = texelFetch(
107
+ t_weight, ivec2 (weight_txcol + ${c}, pos + r), 0 );
108
+
109
+ qmat2[r][${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4 ) - 8.0 );
110
+ qmat2[r][${c + 1 }] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0 );
77
111
$else :
78
- b[i] = VEC4_T(texelFetch(t_weight, ivec2 (out_col >> 2 , pos + i), 0 ));
112
+ $for c in range(TILE_TXCOLS):
113
+ $if WEIGHT_STORAGE == "buffer ":
114
+ qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol;
115
+ qmat2[r][${c}] = t_weight[qmat2_bufi + ${c}];
116
+ $else :
117
+ qmat2[r][${c}] = VEC4_T(
118
+ texelFetch(t_weight, ivec2 (out_txcol + ${c}, pos + r), 0 ));
79
119
}
80
- // Preload t_in
81
- for (int i = 0 ; i < TILE_ROWS; i++ ) {
120
+
121
+ $if IN_STORAGE == "buffer ":
122
+ uint in_row_txstride = div4(in_sizes.x);
123
+
124
+ // Preload input tensor
125
+ [[unroll]] for (int i = 0 ; i < TILE_ROWS; i++ ) {
82
126
$if IN_STORAGE == "buffer ":
83
- a [i] = t_in[(( out_row + i) * in_sizes.x + pos) >> 2 ];
127
+ mat1 [i] = t_in[(out_row + i) * in_row_txstride + txpos ];
84
128
$else :
85
- a[i] = VEC4_T(texelFetch(t_in, ivec3 (pos >> 2 , out_row + i, 0 ), 0 ));
129
+ mat1[i] = VEC4_T(
130
+ texelFetch(t_in, ivec3 (txpos, out_row + i, 0 ), 0 ));
86
131
}
87
132
88
133
// Accumulate partial output
89
- [[unroll]] for (int i = 0 ; i < TILE_ROWS; ++ i) {
90
- local_c[i] += a[i].x * b[0 ] +
91
- a[i].y * b[1 ] +
92
- a[i].z * b[2 ] +
93
- a[i].w * b[3 ];
134
+ [[unroll]] for (int r = 0 ; r < TILE_ROWS; ++ r) {
135
+ $for c in range(TILE_TXCOLS):
136
+ local_sums[r][${c}] += mat1[r].x * qmat2[0 ][${c}] +
137
+ mat1[r].y * qmat2[1 ][${c}] +
138
+ mat1[r].z * qmat2[2 ][${c}] +
139
+ mat1[r].w * qmat2[3 ][${c}];
94
140
}
95
141
}
96
142
97
- [[unroll]] for (int i = 0 ; i < TILE_ROWS; ++ i) {
98
- partial_c[gid][wid][i] = local_c[i];
143
+ [[unroll]] for (int r = 0 ; r < TILE_ROWS; ++ r) {
144
+ $for c in range(TILE_TXCOLS):
145
+ partial_sums[gid][wid][r][${c}] = local_sums[r][${c}];
99
146
}
100
147
101
148
memoryBarrierShared();
@@ -105,21 +152,33 @@ void main() {
105
152
return ;
106
153
}
107
154
108
- VEC4_T c[TILE_ROWS];
155
+ VEC4_T sums[TILE_ROWS][TILE_TXCOLS];
156
+
157
+ for (int r = 0 ; r < TILE_ROWS; ++ r) {
158
+ $for c in range(TILE_TXCOLS):
159
+ sums[r][${c}] = VEC4_T(0.0 );
109
160
110
- for (int row = 0 ; row < TILE_ROWS; ++ row) {
111
- c[row] = VEC4_T(0.0 );
112
161
[[unroll]] for (int worker = 0 ; worker < NWORKERS; ++ worker) {
113
- c[row] += partial_c[gid][worker][row];
162
+ $for c in range(TILE_TXCOLS):
163
+ sums[r][${c}] += partial_sums[gid][worker][r][${c}];
114
164
}
115
165
}
116
166
117
- [[unroll]] for (int i = 0 ; i < TILE_ROWS; ++ i) {
118
- $if OUT_STORAGE == "buffer ":
119
- if (out_row + i < out_sizes.y) {
120
- t_out[((out_row + i) * out_sizes.x + out_col) >> 2 ] = c[i] * scales;
121
- }
122
- $else :
123
- imageStore(t_out, ivec3 (out_col >> 2 , out_row + i, 0 ), c[i] * scales);
167
+ $if OUT_STORAGE == "buffer ":
168
+ uint out_bufi;
169
+ uint out_row_txstride = div4(out_sizes.x);
170
+
171
+ [[unroll]] for (int r = 0 ; r < TILE_ROWS; ++ r) {
172
+ $for c in range(TILE_TXCOLS):
173
+ $if OUT_STORAGE == "buffer ":
174
+ if (out_row + r < out_sizes.y) {
175
+ out_bufi = (out_row + r) * out_row_txstride + out_txcol;
176
+ t_out[out_bufi + ${c}] = sums[r][${c}] * scales[${c}];
177
+ }
178
+ $else :
179
+ imageStore(
180
+ t_out,
181
+ ivec3 (out_txcol + ${c}, out_row + r, 0 ),
182
+ sums[r][${c}] * scales[${c}]);
124
183
}
125
184
}
0 commit comments