@@ -214,6 +214,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif
+ struct ggml_tensor_extra_gpu {
+ void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
+ cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
+ };
+
static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -1995,7 +2000,6 @@ inline void ggml_cuda_op_add(
} else {
GGML_ASSERT(false);
}
- CUDA_CHECK(cudaGetLastError());
(void) src1;
(void) dst;
@@ -2027,7 +2031,6 @@ inline void ggml_cuda_op_mul(
// compute
mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
- CUDA_CHECK(cudaGetLastError());
}
(void) dst;
@@ -2048,7 +2051,6 @@ inline void ggml_cuda_op_silu(
// compute
silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
- CUDA_CHECK(cudaGetLastError());
(void) src1;
(void) dst;
@@ -2071,7 +2073,6 @@ inline void ggml_cuda_op_rms_norm(
// compute
rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
- CUDA_CHECK(cudaGetLastError());
(void) src1;
(void) dst;
@@ -2150,7 +2151,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
GGML_ASSERT(false);
break;
}
- CUDA_CHECK(cudaGetLastError());
#ifdef GGML_CUDA_DMMV_F16
if (src1_convert_f16) {
@@ -2230,7 +2230,6 @@ inline void ggml_cuda_op_rope(
// compute
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
- CUDA_CHECK(cudaGetLastError());
(void) dst;
(void) src0_ddq_i;
@@ -2254,7 +2253,6 @@ inline void ggml_cuda_op_diag_mask_inf(
// compute
diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
- CUDA_CHECK(cudaGetLastError());
(void) dst;
(void) src0_ddq_i;
@@ -2276,7 +2274,6 @@ inline void ggml_cuda_op_soft_max(
// compute
soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
- CUDA_CHECK(cudaGetLastError());
(void) src1;
(void) dst;
@@ -2372,10 +2369,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
- // if multiple GPUs are used they need to wait for the main GPU to finish
+ // if multiple devices are used they need to wait for the main device
+ // here an event is recorded that signifies that the main device has finished calculating the input data
if (split && g_device_count > 1) {
CUDA_CHECK(cudaSetDevice(g_main_device));
- CUDA_CHECK(cudaDeviceSynchronize());
+ CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device]));
}
for (int id = 0; id < g_device_count; ++id) {
@@ -2401,6 +2399,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
int64_t row_diff = row_high - row_low;
cudaSetDevice(id);
+ cudaStream_t cudaStream_main = g_cudaStreams_main[id];
+
+ // wait for main GPU data if necessary
+ if (split && id != g_main_device) {
+ CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device]));
+ }
if (src0_on_device && src0_is_contiguous) {
if (src0_is_f32) {
@@ -2476,8 +2480,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
}
const int64_t i11 = i13*ne12 + i12;
- cudaStream_t cudaStream_main = g_cudaStreams_main[id];
-
// for split tensors the data begins at i0 == i0_offset_low
char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
@@ -2537,6 +2539,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
// do the computation
op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
+ CUDA_CHECK(cudaGetLastError());
// copy dst to host or other device if necessary
if (!dst_on_device) {
@@ -2566,6 +2569,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
}
}
+
+ // signify to main device that other device is done
+ if (split && g_device_count > 1 && id != g_main_device) {
+ CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main));
+ }
}
}
}
@@ -2577,7 +2585,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
}
CUDA_CHECK(cudaSetDevice(id));
- CUDA_CHECK(cudaDeviceSynchronize());
if (src0_asq[id] > 0) {
ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
@@ -2592,6 +2599,21 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
}
}
+
+ // main device waits for all other devices to be finished
+ if (split && g_device_count > 1) {
+ CUDA_CHECK(cudaSetDevice(g_main_device));
+ for (int id = 0; id < g_device_count; ++id) {
+ if (id != g_main_device) {
+ CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id]));
+ }
+ }
+ }
+
+ if (dst->backend == GGML_BACKEND_CPU) {
+ CUDA_CHECK(cudaSetDevice(g_main_device));
+ CUDA_CHECK(cudaDeviceSynchronize());
+ }
}
void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2831,6 +2853,10 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
extra->data_device[id] = buf;
+
+ if (backend == GGML_BACKEND_GPU_SPLIT) {
+ CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming));
+ }
}
tensor->extra = extra;
@@ -2844,12 +2870,15 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
for (int id = 0; id < g_device_count; ++id) {
- if (extra->data_device[id] == nullptr) {
- continue;
+ if (extra->data_device[id] != nullptr) {
+ CUDA_CHECK(cudaSetDevice(id));
+ CUDA_CHECK(cudaFree(extra->data_device[id]));
}
- CUDA_CHECK(cudaSetDevice(id));
- CUDA_CHECK(cudaFree(extra->data_device[id]));
+ if (extra->events[id] != nullptr) {
+ CUDA_CHECK(cudaSetDevice(id));
+ CUDA_CHECK(cudaEventDestroy(extra->events[id]));
+ }
}
delete extra;
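The diff above replaces blocking cudaDeviceSynchronize() calls with per-device CUDA events: the main device records an event after its input is queued, the other devices make their streams wait on that event before launching their work and record their own events afterwards, and the main device's stream waits on those events before any final synchronization. Below is a minimal standalone sketch of that pattern; the names (CHECK, scale_kernel, streams, events, main_device) are illustrative placeholders and not the actual ggml symbols or logic.

// sketch.cu -- event-based ordering of work across multiple GPUs (illustrative only)
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

#define CHECK(call) do { \
    cudaError_t err_ = (call); \
    if (err_ != cudaSuccess) { \
        fprintf(stderr, "CUDA error %s at %s:%d\n", cudaGetErrorString(err_), __FILE__, __LINE__); \
        exit(1); \
    } \
} while (0)

__global__ void scale_kernel(float * data, int n) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i < n) {
        data[i] *= 2.0f;
    }
}

int main() {
    const int n = 1 << 20;
    const int main_device = 0;
    int n_devices = 0;
    CHECK(cudaGetDeviceCount(&n_devices));
    if (n_devices > 16) n_devices = 16; // keep within the fixed-size arrays below

    cudaStream_t streams[16];
    cudaEvent_t  events[16];
    float *      data[16];

    for (int id = 0; id < n_devices; ++id) {
        CHECK(cudaSetDevice(id));
        CHECK(cudaStreamCreate(&streams[id]));
        // timing disabled, as in the diff, to keep the events cheap
        CHECK(cudaEventCreateWithFlags(&events[id], cudaEventDisableTiming));
        CHECK(cudaMalloc(&data[id], n*sizeof(float)));
    }

    // main device queues its work and records an event on its stream
    CHECK(cudaSetDevice(main_device));
    scale_kernel<<<(n + 255)/256, 256, 0, streams[main_device]>>>(data[main_device], n);
    CHECK(cudaGetLastError());
    CHECK(cudaEventRecord(events[main_device], streams[main_device]));

    // secondary devices wait on the main device's event, queue their own work,
    // and record their own events once that work is queued
    for (int id = 0; id < n_devices; ++id) {
        if (id == main_device) continue;
        CHECK(cudaSetDevice(id));
        CHECK(cudaStreamWaitEvent(streams[id], events[main_device], 0));
        scale_kernel<<<(n + 255)/256, 256, 0, streams[id]>>>(data[id], n);
        CHECK(cudaGetLastError());
        CHECK(cudaEventRecord(events[id], streams[id]));
    }

    // the main device's stream waits for every other device before the one final sync
    CHECK(cudaSetDevice(main_device));
    for (int id = 0; id < n_devices; ++id) {
        if (id != main_device) {
            CHECK(cudaStreamWaitEvent(streams[main_device], events[id], 0));
        }
    }
    CHECK(cudaDeviceSynchronize());
    return 0;
}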