Skip to content

Commit 13b87f2

Browse files
committed
metal : fix support check
ggml-ci
1 parent e9565cc commit 13b87f2

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

ggml/src/ggml-metal.m

+5-1
Original file line numberDiff line numberDiff line change
@@ -949,6 +949,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
949949
case GGML_OP_LEAKY_RELU:
950950
return true;
951951
case GGML_OP_FLASH_ATTN_EXT:
952+
if (op->src[1]->type != op->src[2]->type) {
953+
return false;
954+
}
952955
return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
953956
case GGML_OP_SSM_CONV:
954957
case GGML_OP_SSM_SCAN:
@@ -2893,6 +2896,7 @@ static void ggml_metal_encode_node(
28932896
GGML_ASSERT(ne11 % 32 == 0);
28942897

28952898
GGML_ASSERT(src0->type == GGML_TYPE_F32);
2899+
GGML_ASSERT(src1->type == src2->type);
28962900

28972901
GGML_ASSERT(ggml_are_same_shape (src1, src2));
28982902

@@ -3165,7 +3169,7 @@ static void ggml_metal_encode_node(
31653169

31663170
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
31673171
} else {
3168-
// half1x4 kernel
3172+
// half4x4 kernel
31693173
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
31703174
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
31713175

0 commit comments

Comments
 (0)