Skip to content

Commit b7225ec

Browse files
committed
Review: add type traits and make function more generic
1 parent 789c697 commit b7225ec

File tree

3 files changed

+134
-207
lines changed

3 files changed

+134
-207
lines changed

ggml/src/ggml-cuda/convert.cu

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,3 +739,14 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
739739
return nullptr;
740740
}
741741
}
742+
743+
// Picks the conversion routine that produces float output for a source
// tensor of the given ggml type.
//
// Returns:
//   convert_unary_cuda<half, float>         for GGML_TYPE_F16
//   convert_unary_cuda<nv_bfloat16, float>  for GGML_TYPE_BF16
//   nullptr                                 for every other type
to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
    if (type == GGML_TYPE_F16) {
        return convert_unary_cuda<half, float>;
    }
    if (type == GGML_TYPE_BF16) {
        return convert_unary_cuda<nv_bfloat16, float>;
    }
    // Any other source type has no converter in this lookup.
    return nullptr;
}

ggml/src/ggml-cuda/convert.cuh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ using to_t_nc_cuda_t = void (*)(const void * x, T * y,
2222
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
2323
int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
2424

25+
// Concrete instantiations of the to_t_nc_cuda_t function-pointer alias,
// one per destination element type (float, half, nv_bfloat16).
typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;
2526
typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
2627
typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
28+
29+
// Lookup functions: each takes a ggml_type and returns a function pointer of
// the matching alias type above.
// NOTE(review): presumably the argument is the source tensor type and the
// result may be null for unsupported types - confirm against the definitions.
to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
2730
to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
2831
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);

0 commit comments

Comments
 (0)