diff --git a/CMakeLists.txt b/CMakeLists.txt
index edc7989..407a88b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -30,6 +30,11 @@ if (GGML_CUBLAS)
   set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES ${CUDA_ARCHITECTURES})
 endif ()
 
+if (GGML_METAL)
+  add_compile_definitions(GGML_USE_METAL)
+  configure_file(third_party/ggml/src/ggml-metal.metal ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
+endif ()
+
 file(GLOB CPP_SOURCES
   ${PROJECT_SOURCE_DIR}/*.h
   ${PROJECT_SOURCE_DIR}/*.cpp)
diff --git a/README.md b/README.md
index 009d09a..7ea7d95 100644
--- a/README.md
+++ b/README.md
@@ -89,6 +89,13 @@ cuBLAS uses NVIDIA GPU to accelerate BLAS. Add the CMake flag `-DGGML_CUBLAS=ON`
 cmake -B build -DGGML_CUBLAS=ON && cmake --build build -j
 ```
 
+**Metal**
+
+MPS (Metal Performance Shaders) allows computation to run on the powerful Apple Silicon GPU. Add the CMake flag `-DGGML_METAL=ON` to enable it.
+```sh
+cmake -B build -DGGML_METAL=ON && cmake --build build -j
+```
+
 ## Python Binding
 
 The Python binding provides high-level `chat` and `stream_chat` interface similar to the original Hugging Face Qwen-7B.
diff --git a/qwen.cpp b/qwen.cpp
index 271daa5..910cbf3 100644
--- a/qwen.cpp
+++ b/qwen.cpp
@@ -61,7 +61,21 @@ auto ggml_graph_compute_helper(std::vector<uninitialized_char> &buf, ggml_cgraph
   ggml_graph_compute(graph, &plan);
 }
 
-auto ModelContext::init_device_context() -> void {}
+auto ModelContext::init_device_context() -> void {
+#ifdef GGML_USE_METAL
+  ctx_metal = make_unique_ggml_metal_context(1);
+  const size_t max_size = ggml_get_max_tensor_size(ctx_w.get());
+  void *weight_data = weight_buffer.empty() ? ggml_get_mem_buffer(ctx_w.get()) : (void *)weight_buffer.data();
+  size_t weight_size = weight_buffer.empty() ? ggml_get_mem_size(ctx_w.get()) : weight_buffer.size();
+  QWEN_CHECK(ggml_metal_add_buffer(ctx_metal.get(), "weights", weight_data, weight_size, max_size));
+  QWEN_CHECK(ggml_metal_add_buffer(ctx_metal.get(), "kv", ggml_get_mem_buffer(ctx_kv.get()),
+                                   ggml_get_mem_size(ctx_kv.get()), 0));
+  void *compute_data = ctx_b ? ggml_get_mem_buffer(ctx_b.get()) : compute_buffer.data();
+  size_t compute_size = ctx_b ? ggml_get_mem_size(ctx_b.get()) : compute_buffer.size();
+  QWEN_CHECK(ggml_metal_add_buffer(ctx_metal.get(), "compute", compute_data, compute_size, 0));
+  QWEN_CHECK(ggml_metal_add_buffer(ctx_metal.get(), "scratch", scratch.data, scratch.size, 0));
+#endif
+}
 
 // ===== streamer =====
 
@@ -482,7 +496,7 @@ auto get_num_physical_cores() -> int {
 }
 
 auto get_default_num_threads() -> int {
-#ifdef GGML_USE_CUBLAS
+#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_METAL)
   return 1;
 #else
   return std::min(get_num_physical_cores(), 16);
@@ -583,7 +597,11 @@ auto QwenForCausalLM::generate_next_token(
   }
 
   ggml_build_forward_expand(&ctx_.gf, lm_logits);
+#ifdef GGML_USE_METAL
+  ggml_metal_graph_compute(ctx_.ctx_metal.get(), &ctx_.gf);
+#else
   ggml_graph_compute_helper(ctx_.work_buffer, &ctx_.gf, n_threads);
+#endif
 
   int vocab_size = lm_logits->ne[0];
   float *next_token_logits = (float *)lm_logits->data;
diff --git a/qwen.h b/qwen.h
index 66bb11c..f8d87cc 100644
--- a/qwen.h
+++ b/qwen.h
@@ -12,6 +12,10 @@
 #include <ggml-cuda.h>
 #endif
 
+#ifdef GGML_USE_METAL
+#include <ggml-metal.h>
+#endif
+
 namespace qwen {
 
 class QwenTokenizer;
@@ -58,6 +62,20 @@ static inline auto make_unique_ggml_context(
   return unique_ggml_context_t(ggml_init({mem_size, mem_buffer, no_alloc}));
 }
 
+#ifdef GGML_USE_METAL
+struct ggml_metal_context_deleter_t {
+  auto operator()(ggml_metal_context *ctx) const noexcept -> void { ggml_metal_free(ctx); }
+};
+
+using unique_ggml_metal_context_t = std::unique_ptr<ggml_metal_context, ggml_metal_context_deleter_t>;
+
+static inline auto make_unique_ggml_metal_context(
+  int n_cb
+) -> unique_ggml_metal_context_t {
+  return unique_ggml_metal_context_t(ggml_metal_init(n_cb));
+}
+#endif
+
 struct uninitialized_char {
   char m;
   uninitialized_char() {}
@@ -70,6 +88,9 @@ struct ModelContext {
   unique_ggml_context_t ctx_w;  // weight
   unique_ggml_context_t ctx_kv; // kv cache
   unique_ggml_context_t ctx_b;  // buffer
+#ifdef GGML_USE_METAL
+  unique_ggml_metal_context_t ctx_metal;
+#endif
   ggml_cgraph gf;
   ggml_scratch scratch;
   std::vector<uninitialized_char> compute_buffer; // BLAS buffer
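
For readers unfamiliar with ggml's Metal backend, here is a minimal, self-contained sketch of the same register-buffers-then-compute flow the patch wires up, written against the pre-gguf ggml API this repo vendors. It is illustrative only (a toy `c = a * b` graph, not part of the patch); in the patch the buffer registration happens in `ModelContext::init_device_context` and the dispatch in `QwenForCausalLM::generate_next_token`.

```cpp
// Illustrative sketch only -- not part of the patch. Assumes the pre-gguf
// ggml API used by this repo (ggml_build_forward, ggml_metal_init(n_cb)).
#include <cstdio>

#include "ggml.h"
#include "ggml-metal.h"

int main() {
  // Build a trivial graph c = a * b in a CPU-side ggml context.
  ggml_init_params params = {/* mem_size   = */ 16 * 1024 * 1024,
                             /* mem_buffer = */ nullptr,
                             /* no_alloc   = */ false};
  ggml_context *ctx = ggml_init(params);
  ggml_tensor *a = ggml_new_f32(ctx, 3.0f);
  ggml_tensor *b = ggml_new_f32(ctx, 4.0f);
  ggml_tensor *c = ggml_mul(ctx, a, b);
  ggml_cgraph gf = ggml_build_forward(c);

  // Metal can only touch memory that has been registered with the backend.
  // This mirrors init_device_context(): every region the graph reads or
  // writes (here, the whole ggml context) is mapped once, up front.
  ggml_metal_context *mctx = ggml_metal_init(1 /* n_cb */);
  ggml_metal_add_buffer(mctx, "ctx", ggml_get_mem_buffer(ctx),
                        ggml_get_mem_size(ctx), ggml_get_max_tensor_size(ctx));

  // Run the graph on the GPU; the result lands back in the mapped buffer.
  ggml_metal_graph_compute(mctx, &gf);
  printf("c = %f\n", ggml_get_f32_1d(c, 0)); // expect 12.0

  ggml_metal_free(mctx);
  ggml_free(ctx);
  return 0;
}
```

The `max_size` argument passed when registering a buffer tells the backend the largest single tensor it must be able to view, which is why the patch computes `ggml_get_max_tensor_size(ctx_w.get())` before registering the weights.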