
Commit 5e8295e

[ET-VK] Implement linear_qcs4w (#10772)
## Context

Title says it all!

## Changes

Extended the implementation of `linear_qcsnw` to support packed 4-bit weight tensors.

Differential Revision: [D73941991](https://our.internmc.facebook.com/intern/diff/D73941991/)
1 parent b1d00e2 commit 5e8295e

11 files changed: +528 −150

backends/vulkan/op_registry.py

+6 −1

@@ -377,7 +377,12 @@ def register_mm_op(features: OpFeatures):
     return features
 
 
-@update_features(exir_ops.edge.aten._weight_int8pack_mm.default)
+@update_features(
+    [
+        exir_ops.edge.aten._weight_int8pack_mm.default,
+        exir_ops.edge.et_vk.linear_qcs4w.default,
+    ]
+)
 def register_int8_mm_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
         uses_axis_map=False,

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

+14 −4

@@ -41,22 +41,32 @@
 /*
  * Fast division by 4 using bit shifting
  */
-#define div4(x) (x >> 2)
+#define div4(x) ((x) >> 2)
+
+/*
+ * Fast multiplication by 4 using bit shifting
+ */
+#define mul4(x) ((x) << 2)
 
 /*
  * Divides input and rounds up to 4
  */
-#define divup4(x) ((x + 3) >> 2)
+#define divup4(x) (((x) + 3) >> 2)
+
+/*
+ * Divides input by denominator and rounds up
+ */
+#define divup(x, d) (((x) + (d) - 1) / (d))
 
 /*
  * Aligns input to the next multiple of 4
  */
-#define alignup4(x) ((x + 3) & -4)
+#define alignup4(x) (((x) + 3) & -4)
 
 /*
  * Fast modulo by 4 using bit masking
  */
-#define mod4(x) (x & 3)
+#define mod4(x) ((x) & 3)
 
 /*
  * Find the packed dimension of a tensor given its strides. The packed dimension
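As an aside, the extra parentheses added around the macro arguments guard against operator-precedence surprises when the argument is itself an expression, and `divup` generalizes `divup4` to an arbitrary denominator. Below is a minimal C sketch, not part of this commit, illustrating both points; the `divup4_old`/`divup4_new` names are hypothetical and exist only to contrast the two definitions.

```c
#include <stdio.h>

/* Old form (before this commit): argument is not parenthesized. */
#define divup4_old(x) ((x + 3) >> 2)
/* New form (after this commit): argument fully parenthesized. */
#define divup4_new(x) (((x) + 3) >> 2)
/* New general-purpose ceiling division. */
#define divup(x, d) (((x) + (d) - 1) / (d))

int main(void) {
  int n = 5;
  /* divup4_old(n << 1) expands to ((n << 1 + 3) >> 2). Since + binds
   * tighter than <<, this is ((n << 4) >> 2) = 20, not ceil(10 / 4). */
  printf("%d\n", divup4_old(n << 1)); /* prints 20 (wrong) */
  /* divup4_new(n << 1) expands to (((n << 1) + 3) >> 2) = 3. */
  printf("%d\n", divup4_new(n << 1)); /* prints 3 (ceil(10 / 4)) */
  printf("%d\n", divup(10, 4));       /* prints 3 */
  return 0;
}
```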

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl

+102 −43

@@ -14,6 +14,7 @@
 #define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
 
 #define TILE_ROWS ${TILE_ROWS}
+#define TILE_TXCOLS ${TILE_TXCOLS}
 
 #define NGROUPS 8
 #define NWORKERS 8
@@ -29,7 +30,10 @@ layout(std430) buffer;
 
 ${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
 ${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
+$if QUANT_NBITS == 4:
+  ${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
+$else:
+  ${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
 ${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)}
 
 layout(push_constant) uniform restrict Block {
@@ -42,12 +46,23 @@ layout(push_constant) uniform restrict Block {
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
-shared VEC4_T partial_c[NGROUPS][NWORKERS][TILE_ROWS];
+shared VEC4_T partial_sums[NGROUPS][NWORKERS][TILE_ROWS][TILE_TXCOLS];
 
 void main() {
-  const uint out_width_ntexels = divup4(out_sizes.x);
-  const uint out_col = (gl_GlobalInvocationID.x % out_width_ntexels) << 2;
-  const uint out_row = (gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS;
+  // txcol stands for "texel column". One txcol corresponds to 4 scalar columns.
+  $if TILE_TXCOLS > 1:
+    const uint global_wg_x = uint(divup(out_sizes.x, 4 * TILE_TXCOLS));
+    const uint out_txcol = uint(
+        (gl_GlobalInvocationID.x % global_wg_x) * TILE_TXCOLS);
+  $else:
+    const uint global_wg_x = uint(divup4(out_sizes.x));
+    const uint out_txcol = uint(gl_GlobalInvocationID.x % global_wg_x);
+
+  const uint out_row = uint(
+      (gl_GlobalInvocationID.x / global_wg_x) * TILE_ROWS);
+
+  $if QUANT_NBITS == 4:
+    const uint weight_txcol = uint(out_txcol / 2);
 
   const int gid = int(gl_LocalInvocationID.x); // group id
   const int wid = int(gl_LocalInvocationID.z); // worker id
@@ -56,46 +71,78 @@ void main() {
     return;
   }
 
-  VEC4_T a[TILE_ROWS];
-  VEC4_T b[4];
-  VEC4_T local_c[TILE_ROWS];
+  VEC4_T mat1[TILE_ROWS];
+  VEC4_T qmat2[4][TILE_TXCOLS];
+  VEC4_T local_sums[TILE_ROWS][TILE_TXCOLS];
 
-  [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
-    local_c[i] = VEC4_T(0.0);
+  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
+    $for c in range(TILE_TXCOLS):
+      local_sums[r][${c}] = VEC4_T(0.0);
   }
 
-  $if SCALES_STORAGE == "buffer":
-    const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
-  $else:
-    const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0));
-
-  for (int pos = 4 * wid; pos < in_sizes.x; pos += (4 * NWORKERS)) {
-    // Preload t_weight
-    [[unroll]] for (int i = 0; i < 4; i++) {
-      $if WEIGHT_STORAGE == "buffer":
-        b[i] = t_weight[((pos + i) * weight_sizes.x + out_col) >> 2];
+  VEC4_T scales[TILE_TXCOLS];
+  $for c in range(TILE_TXCOLS):
+    $if SCALES_STORAGE == "buffer":
+      scales[${c}] = VEC4_T(t_scales[out_txcol + ${c}]);
+    $else:
+      scales[${c}] = VEC4_T(
+          texelFetch(t_scales, ivec2(out_txcol + ${c}, 0), 0));
+
+  for (int pos = (4 * wid), txpos = wid;
+       pos < in_sizes.x;
+       pos += (4 * NWORKERS), txpos += NWORKERS) {
+    $if WEIGHT_STORAGE == "buffer":
+      uint qmat2_bufi;
+      uint weight_row_txstride = div4(weight_sizes.x);
+
+    // Preload weight tensor
+    [[unroll]] for (int r = 0; r < 4; r++) {
+      $if QUANT_NBITS == 4:
+        $for c in range(0, TILE_TXCOLS, 2):
+          $if WEIGHT_STORAGE == "buffer":
+            qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol;
+            const u8vec4 packed_weight_tex = t_weight[qmat2_bufi + ${c}];
+          $else:
+            const uvec4 packed_weight_tex = texelFetch(
+                t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);
+
+          qmat2[r][${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0);
+          qmat2[r][${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0);
       $else:
-        b[i] = VEC4_T(texelFetch(t_weight, ivec2(out_col >> 2, pos + i), 0));
+        $for c in range(TILE_TXCOLS):
+          $if WEIGHT_STORAGE == "buffer":
+            qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol;
+            qmat2[r][${c}] = t_weight[qmat2_bufi + ${c}];
+          $else:
+            qmat2[r][${c}] = VEC4_T(
+                texelFetch(t_weight, ivec2(out_txcol + ${c}, pos + r), 0));
     }
-    // Preload t_in
-    for (int i = 0; i < TILE_ROWS; i++) {
+
+    $if IN_STORAGE == "buffer":
+      uint in_row_txstride = div4(in_sizes.x);
+
+    // Preload input tensor
+    [[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
       $if IN_STORAGE == "buffer":
-        a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2];
+        mat1[i] = t_in[(out_row + i) * in_row_txstride + txpos];
       $else:
-        a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
+        mat1[i] = VEC4_T(
+            texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0));
     }
 
     // Accumulate partial output
-    [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
-      local_c[i] += a[i].x * b[0] +
-        a[i].y * b[1] +
-        a[i].z * b[2] +
-        a[i].w * b[3];
+    [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
+      $for c in range(TILE_TXCOLS):
+        local_sums[r][${c}] += mat1[r].x * qmat2[0][${c}] +
+            mat1[r].y * qmat2[1][${c}] +
+            mat1[r].z * qmat2[2][${c}] +
+            mat1[r].w * qmat2[3][${c}];
     }
   }
 
-  [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
-    partial_c[gid][wid][i] = local_c[i];
+  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
+    $for c in range(TILE_TXCOLS):
+      partial_sums[gid][wid][r][${c}] = local_sums[r][${c}];
  }
 
   memoryBarrierShared();
@@ -105,21 +152,33 @@
     return;
   }
 
-  VEC4_T c[TILE_ROWS];
+  VEC4_T sums[TILE_ROWS][TILE_TXCOLS];
+
+  for (int r = 0; r < TILE_ROWS; ++r) {
+    $for c in range(TILE_TXCOLS):
+      sums[r][${c}] = VEC4_T(0.0);
 
-  for (int row = 0; row < TILE_ROWS; ++row) {
-    c[row] = VEC4_T(0.0);
     [[unroll]] for (int worker = 0; worker < NWORKERS; ++worker) {
-      c[row] += partial_c[gid][worker][row];
+      $for c in range(TILE_TXCOLS):
+        sums[r][${c}] += partial_sums[gid][worker][r][${c}];
    }
  }
 
-  [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
-    $if OUT_STORAGE == "buffer":
-      if (out_row + i < out_sizes.y) {
-        t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
-      }
-    $else:
-      imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales);
+  $if OUT_STORAGE == "buffer":
+    uint out_bufi;
+    uint out_row_txstride = div4(out_sizes.x);
+
+  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
+    $for c in range(TILE_TXCOLS):
+      $if OUT_STORAGE == "buffer":
+        if (out_row + r < out_sizes.y) {
+          out_bufi = (out_row + r) * out_row_txstride + out_txcol;
+          t_out[out_bufi + ${c}] = sums[r][${c}] * scales[${c}];
+        }
+      $else:
+        imageStore(
+            t_out,
+            ivec3(out_txcol + ${c}, out_row + r, 0),
+            sums[r][${c}] * scales[${c}]);
  }
 }
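For reference, the `QUANT_NBITS == 4` path above reads each weight byte as two 4-bit values stored with a +8 offset: the high nibble is unpacked into the lower texel column (`qmat2[r][c]`) and the low nibble into the next one (`qmat2[r][c + 1]`), each shifted back into the signed range [-8, 7] before the per-channel scale is applied. The following is a minimal C sketch of that per-byte arithmetic only; it is not part of this commit, `pack_q4x2` is a hypothetical helper, and how the full packed weight tensor is laid out on the host side is outside the scope of this sketch.

```c
#include <stdint.h>
#include <stdio.h>

/* Hypothetical helper: pack two 4-bit quantized values (range [-8, 7])
 * into one byte with a +8 offset, high nibble first. This mirrors the
 * unpacking arithmetic in linear_qcsnw_coop.glsl:
 *   hi = ((byte & 0xF0) >> 4) - 8
 *   lo =  (byte & 0x0F)       - 8
 */
static uint8_t pack_q4x2(int hi, int lo) {
  return (uint8_t)((((hi + 8) & 0x0F) << 4) | ((lo + 8) & 0x0F));
}

int main(void) {
  const int hi = -3, lo = 7;                /* quantized weights in [-8, 7] */
  const uint8_t packed = pack_q4x2(hi, lo); /* -> 0x5F */

  /* Unpack exactly as the shader does, then apply a per-channel scale. */
  const int unpacked_hi = ((packed & 0xF0) >> 4) - 8; /* -> -3 */
  const int unpacked_lo = (packed & 0x0F) - 8;        /* ->  7 */
  const float scale = 0.05f;                /* hypothetical channel scale */

  printf("packed=0x%02X hi=%d lo=%d dequant_hi=%f\n",
         packed, unpacked_hi, unpacked_lo, unpacked_hi * scale);
  return 0;
}
```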

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml

+10 −0

@@ -12,6 +12,8 @@ linear_qcsnw_coop:
     WEIGHT_STORAGE: texture2d
     SCALES_STORAGE: texture2d
     TILE_ROWS: 4
+    TILE_TXCOLS: 1
+    QUANT_NBITS: 8
   generate_variant_forall:
     TILE_ROWS:
       - VALUE: 1
@@ -26,3 +28,11 @@ linear_qcsnw_coop:
       OUT_STORAGE: buffer
       WEIGHT_STORAGE: buffer
       SCALES_STORAGE: buffer
+    - NAME: linear_qcs4w_coop_texture3d_texture3d_texture2d_texture2d_float
+      TILE_TXCOLS: 2
+      QUANT_NBITS: 4
+    - NAME: linear_qcs4w_coop_buffer_buffer_texture2d_texture2d_float
+      IN_STORAGE: buffer
+      OUT_STORAGE: buffer
+      TILE_TXCOLS: 2
+      QUANT_NBITS: 4
