@@ -147,9 +147,41 @@ inline static void* ggml_aligned_malloc(size_t size) {
147
147
#include <Accelerate/Accelerate.h>
148
148
#elif defined(GGML_USE_OPENBLAS )
149
149
#include <cblas.h>
150
- #elif defined(GGML_USE_CUBLAS )
150
+ #elif defined(GGML_USE_CUBLAS ) || defined(GGML_USE_HIPBLAS )
151
+
152
// HIP/hipBLAS compatibility shim.
// When GGML_USE_HIPBLAS is defined, the hipBLAS header is included and every
// cuBLAS / CUDA runtime identifier listed below is aliased to its hip*
// counterpart via object-like macros, so code written against the cu* names is
// preprocessed into hip* calls. Otherwise (#else branch) the cuBLAS and CUDA
// runtime headers are included directly.
// Note: CUBLAS_COMPUTE_32F and CUDA_R_32F both expand to HIPBLAS_R_32F here.
// NOTE(review): no HIP runtime header is included explicitly in this branch;
// presumably hipError_t, hipMalloc, etc. are declared via hipblas.h - confirm.
+ #if defined(GGML_USE_HIPBLAS )
153
+ #include "hipblas/hipblas.h"
154
+ #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
155
+ #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
156
+ #define CUBLAS_OP_N HIPBLAS_OP_N
157
+ #define CUBLAS_OP_T HIPBLAS_OP_T
158
+ #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
159
+ #define cublasCreate hipblasCreate
160
+ #define cublasGemmEx hipblasGemmEx
161
+ #define cublasHandle_t hipblasHandle_t
162
+ #define cublasSetStream hipblasSetStream
163
+ #define cublasSgemm hipblasSgemm
164
+ #define cublasStatus_t hipblasStatus_t
165
+ #define CUDA_R_16F HIPBLAS_R_16F
166
+ #define CUDA_R_32F HIPBLAS_R_32F
167
+ #define cudaError_t hipError_t
168
+ #define cudaFree hipFree
169
+ #define cudaGetErrorString hipGetErrorString
170
+ #define cudaGetLastError hipGetLastError
171
+ #define cudaMalloc hipMalloc
172
+ #define cudaMemcpyAsync hipMemcpyAsync
173
+ #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
174
+ #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
175
+ #define cudaStream_t hipStream_t
176
+ #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
177
+ #define cudaStreamNonBlocking hipStreamNonBlocking
178
+ #define cudaStreamSynchronize hipStreamSynchronize
179
+ #define cudaSuccess hipSuccess
180
// Defining GGML_USE_CUBLAS makes later `#if defined(GGML_USE_CUBLAS)` sections
// of this file compile for the HIP build as well.
// NOTE(review): this define is unguarded. The enclosing #elif is taken when
// either macro is defined, so if GGML_USE_CUBLAS was already defined with a
// non-empty value (e.g. -DGGML_USE_CUBLAS on the command line) this is a
// conflicting redefinition - consider wrapping it in #ifndef.
+ #define GGML_USE_CUBLAS
181
// Native CUDA build: use the real cuBLAS and CUDA runtime headers.
+ #else
151
182
#include <cublas_v2.h>
152
183
#include <cuda_runtime.h>
184
+ #endif
153
185
#include "ggml-cuda.h"
154
186
155
187
#define CUDA_CHECK (err ) \
@@ -8073,7 +8105,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
8073
8105
const float * x = wdata ;
8074
8106
#endif
8075
8107
8076
-
8077
8108
#if defined(GGML_USE_CUBLAS )
8078
8109
// copy data to device
8079
8110
CUDA_CHECK (cudaMemcpyAsync (d_Y , y , sizeof (float ) * y_ne , cudaMemcpyHostToDevice , cudaStream ));
0 commit comments