@@ -2347,11 +2347,12 @@ struct test_soft_max : public test_case {
2347
2347
const ggml_type type;
2348
2348
const std::array<int64_t , 4 > ne;
2349
2349
const bool mask;
2350
+ const ggml_type m_prec;
2350
2351
const float scale;
2351
2352
const float max_bias;
2352
2353
2353
2354
std::string vars () override {
2354
- return VARS_TO_STR5 (type, ne, mask, scale, max_bias);
2355
+ return VARS_TO_STR6 (type, ne, mask, m_prec , scale, max_bias);
2355
2356
}
2356
2357
2357
2358
// the 1024 test with bias occasionally fails:
@@ -2363,9 +2364,10 @@ struct test_soft_max : public test_case {
2363
2364
test_soft_max (ggml_type type = GGML_TYPE_F32,
2364
2365
std::array<int64_t , 4 > ne = {10 , 5 , 4 , 3 },
2365
2366
bool mask = false ,
2367
+ ggml_type m_prec = GGML_TYPE_F32,
2366
2368
float scale = 1 .0f ,
2367
2369
float max_bias = 0 .0f )
2368
- : type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}
2370
+ : type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {}
2369
2371
2370
2372
ggml_tensor * build_graph (ggml_context * ctx) override {
2371
2373
ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , ne.data ());
@@ -2374,7 +2376,7 @@ struct test_soft_max : public test_case {
2374
2376
2375
2377
ggml_tensor * mask = nullptr ;
2376
2378
if (this ->mask ) {
2377
- mask = ggml_new_tensor_2d (ctx, GGML_TYPE_F32 , ne[0 ], ne[1 ]);
2379
+ mask = ggml_new_tensor_2d (ctx, m_prec , ne[0 ], ne[1 ]);
2378
2380
ggml_set_name (mask, " mask" );
2379
2381
}
2380
2382
@@ -4150,17 +4152,28 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
4150
4152
for (float scale : {1 .0f , 0 .1f }) {
4151
4153
for (int64_t ne0 : {16 , 1024 }) {
4152
4154
for (int64_t ne1 : {16 , 1024 }) {
4153
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, scale, max_bias));
4154
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, scale, max_bias));
4155
+ if (mask) {
4156
+ for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
4157
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, m_prec, scale, max_bias));
4158
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, m_prec, scale, max_bias));
4159
+ }
4160
+ } else {
4161
+ /* The precision of mask here doesn't matter as boolean mask is false */
4162
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, GGML_TYPE_F32, scale, max_bias));
4163
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, GGML_TYPE_F32, scale, max_bias));
4164
+ }
4155
4165
}
4156
4166
}
4157
4167
}
4158
4168
}
4159
4169
}
4160
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, true , 0 .1f , 0 .0f ));
4161
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, false , 0 .1f , 0 .0f ));
4162
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , 0 .1f , 0 .0f ));
4163
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , 0 .1f , 8 .0f ));
4170
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, true , GGML_TYPE_F32, 0 .1f , 0 .0f ));
4171
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, true , GGML_TYPE_F16, 0 .1f , 0 .0f ));
4172
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, false , GGML_TYPE_F32, 0 .1f , 0 .0f ));
4173
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F32, 0 .1f , 0 .0f ));
4174
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F16, 0 .1f , 0 .0f ));
4175
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F32, 0 .1f , 8 .0f ));
4176
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F16, 0 .1f , 8 .0f ));
4164
4177
4165
4178
for (float max_bias : {0 .0f , 8 .0f }) {
4166
4179
for (float scale : {1 .0f , 0 .1f }) {
@@ -4296,13 +4309,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
4296
4309
test_cases.emplace_back (new test_cpy (GGML_TYPE_F32, GGML_TYPE_F32, {8192 , 512 , 2 , 1 }, {0 , 2 , 1 , 3 }));
4297
4310
test_cases.emplace_back (new test_cpy (GGML_TYPE_F32, GGML_TYPE_F32, {3072 , 512 , 2 , 1 }, {0 , 2 , 1 , 3 }));
4298
4311
4299
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {4096 , 4096 , 5 , 1 }, false , 1 .0f , 0 .0f ));
4300
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 4096 , 5 , 1 }, false , 1 .0f , 0 .0f ));
4301
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {1024 , 1024 , 10 , 1 }, false , 1 .0f , 0 .0f ));
4302
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 1024 , 10 , 1 }, false , 1 .0f , 0 .0f ));
4303
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {256 , 256 , 20 , 1 }, false , 1 .0f , 0 .0f ));
4304
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {64 , 64 , 20 , 1 }, false , 1 .0f , 0 .0f ));
4305
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 64 , 20 , 1 }, false , 1 .0f , 0 .0f ));
4312
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {4096 , 4096 , 5 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4313
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 4096 , 5 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4314
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {1024 , 1024 , 10 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4315
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 1024 , 10 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4316
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {256 , 256 , 20 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4317
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {64 , 64 , 20 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4318
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 64 , 20 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4306
4319
4307
4320
test_cases.emplace_back (new test_argmax (GGML_TYPE_F32, {32 , 10 , 1 , 1 }));
4308
4321
test_cases.emplace_back (new test_argmax (GGML_TYPE_F32, {1024 , 10 , 1 , 1 }));
0 commit comments