@@ -67,6 +67,14 @@ static void atomic_store(atomic_int * ptr, LONG val) {
67
67
// Emulate C11 atomic_load on Windows: a compare-exchange whose comparand
// and exchange are both 0 never changes *ptr but returns its current value.
static LONG atomic_load(atomic_int * ptr) {
    const LONG current = InterlockedCompareExchange(ptr, 0, 0);
    return current;
}
70
+
71
+ static LONG atomic_compare_exchange_weak(atomic_int * ptr, atomic_int* pcomparand, LONG exchange)
72
+ {
73
+ LONG comparand = *pcomparand;
74
+ LONG ret = InterlockedCompareExchange(ptr, exchange, comparand);
75
+ return (ret == comparand);
76
+ }
77
+
70
78
// Emulate C11 atomic_fetch_add on Windows; like the C11 original, the
// return value is what *ptr held BEFORE the addition.
static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
    const LONG previous = InterlockedExchangeAdd(ptr, inc);
    return previous;
}
@@ -100,6 +108,12 @@ static int sched_yield (void) {
100
108
Sleep (0);
101
109
return 0;
102
110
}
111
+
112
+ /* Windows has no semaphore.h, so the thread pool falls back to
+    active polling (busy-waiting) instead of sleeping on a semaphore.
+    TODO: find a better mechanism (e.g. Win32 events)? */
115
+ #define GGML_THREADPOOL_ACTIVE_POLL
116
+
103
117
#else
104
118
#include <pthread.h>
105
119
#include <stdatomic.h>
@@ -114,6 +128,68 @@ typedef void * thread_ret_t;
114
128
115
129
typedef pthread_t ggml_thread_t;
116
130
131
+ // Implementing a persistent thread-pool
132
+
133
+ #ifndef GGML_THREADPOOL_ACTIVE_POLL
134
+ #include <semaphore.h>
135
+ #endif
136
+
137
+ #define GGML_THREAD_POOL_SIZE 512
138
+
139
/** Per-slot state for one lazily-created worker thread in the pool. */
struct ggml_thread_pool_context {
    /** Function to run for the current job */
    void * (*fn)(void *);
    /** Argument to pass to fn */
    void * arg;
    /** Return value of fn, captured for the joiner */
    void * ret;

    /** At 1 while the thread is running a job; cleared by the worker,
        polled by join to detect completion */
    atomic_int executing;
    /** Stop condition for the worker main loop (1 = keep running) */
    atomic_int running;
#ifndef GGML_THREADPOOL_ACTIVE_POLL
    /** Used to pause idle threads on POSIX systems (avoids busy-waiting) */
    sem_t sem;
#endif
    /** Slot ownership flag: 0 = free, 1 = claimed; acquired via CAS
        so concurrent creators cannot grab the same slot */
    atomic_int flag;

    /** Thread associated with this context */
    ggml_thread_t thread;
    /** Threads are created lazily using this flag (set on first use) */
    short has_thread;
};
163
+
164
+ void ggml_thread_pool_context_init(struct ggml_thread_pool_context * ctx);
165
+
166
/** Main structure for the thread pool: a fixed-size array of worker
    slots, each of which spawns its thread lazily on first use. */
struct ggml_thread_pool {
    /** Each context is a lazily created worker thread */
    struct ggml_thread_pool_context ctx[GGML_THREAD_POOL_SIZE];
};
171
+
172
+ /** Static instance of the GGML thread pool */
173
+ static struct ggml_thread_pool __thp;
174
+
175
/** This is the object representing a thread that is part of the
    threadpool (or a plain fallback thread when the pool is full). */
typedef struct ggml_thread_pool_thread_s {
    /** Id of the thread (offset in threadpool array); -1 if the thread
        was created externally because no pool slot was free */
    int id;
    /** Handle for threads created externally (only meaningful if id == -1) */
    ggml_thread_t th;
} ggml_thread_pool_thread_t;
182
+
183
+ /** This is the mainloop for threads in the threadpool */
184
+ static void * ggml_thread_pool_main(void * pctx);
185
+
186
+ /** Called once to initialize the threadpool */
187
+ void ggml_thread_pool_init(void);
188
+
189
+ /* known_index is a recommendation to try a given thread ID it allows to limit locking contention */
190
+ int ggml_thread_pool_create_thread(ggml_thread_pool_thread_t * th, void * (*fn)(void *), void * arg, int known_index);
191
+ int ggml_thread_pool_join_thread(ggml_thread_pool_thread_t th, void **retval);
192
+
117
193
#ifdef GGML_USE_CPU_HBM
118
194
#include <hbwmalloc.h>
119
195
#endif
@@ -1579,7 +1655,7 @@ struct ggml_compute_state_shared {
1579
1655
};
1580
1656
1581
1657
struct ggml_compute_state {
1582
- ggml_thread_t thrd;
1658
+ ggml_thread_pool_thread_t thrd;
1583
1659
int ith;
1584
1660
struct ggml_compute_state_shared* shared;
1585
1661
enum ggml_status ec;
@@ -3130,6 +3206,7 @@ static inline int ggml_up(int n, int m) {
3130
3206
3131
3207
////////////////////////////////////////////////////////////////////////////////
3132
3208
3209
+
3133
3210
struct ggml_context * ggml_init(struct ggml_init_params params) {
3134
3211
// make this function thread safe
3135
3212
ggml_critical_section_start();
@@ -3139,6 +3216,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3139
3216
if (is_first_call) {
3140
3217
// initialize time system (required on Windows)
3141
3218
ggml_time_init();
3219
+ ggml_thread_pool_init();
3142
3220
3143
3221
// initialize GELU, Quick GELU, SILU and EXP F32 tables
3144
3222
{
@@ -3249,6 +3327,7 @@ void ggml_free(struct ggml_context * ctx) {
3249
3327
// make this function thread safe
3250
3328
ggml_critical_section_start();
3251
3329
3330
+
3252
3331
bool found = false;
3253
3332
3254
3333
for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
@@ -19305,6 +19384,130 @@ typedef int ggml_lock_t;
19305
19384
19306
19385
#endif
19307
19386
19387
+ /* Thread Pool Implementation */
19388
+
19389
+ void ggml_thread_pool_context_init(struct ggml_thread_pool_context * ctx) {
19390
+ memset(ctx, 0, sizeof(struct ggml_thread_pool_context));
19391
+
19392
+ ctx->running = 1;
19393
+ /* Initially no threads are created */
19394
+ ctx->has_thread = 0;
19395
+ #ifndef GGML_THREADPOOL_ACTIVE_POLL
19396
+ sem_init(&ctx->sem, 0, 0);
19397
+ #endif
19398
+ }
19399
+
19400
+ static void * ggml_thread_pool_main(void * pctx) {
19401
+ struct ggml_thread_pool_context * ctx = (struct ggml_thread_pool_context*)pctx;
19402
+
19403
+ while(ctx->running) {
19404
+ /* Here we wait for start notification */
19405
+ #ifdef GGML_THREADPOOL_ACTIVE_POLL
19406
+ /* Wait for run flag */
19407
+ while(atomic_load(&ctx->executing) == 0) {
19408
+ sched_yield();
19409
+ }
19410
+ #else
19411
+ sem_wait(&ctx->sem);
19412
+ #endif
19413
+ ctx->ret = NULL;
19414
+
19415
+ if(ctx->fn) {
19416
+ /* Call the actual function */
19417
+ ctx->ret = (ctx->fn)(ctx->arg);
19418
+ }
19419
+
19420
+ /* Flag done for join */
19421
+ atomic_store(&ctx->executing, 0);
19422
+
19423
+ }
19424
+
19425
+ return NULL;
19426
+ }
19427
+
19428
+ void ggml_thread_pool_init(void) {
19429
+ int i = 0;
19430
+
19431
+ for(i = 0 ; i < GGML_THREAD_POOL_SIZE ; i++) {
19432
+ /* Thread is running */
19433
+ ggml_thread_pool_context_init(&__thp.ctx[i]);
19434
+ }
19435
+ }
19436
+
19437
+ int ggml_thread_pool_create_thread(ggml_thread_pool_thread_t * th, void * (*fn)(void *), void * arg, int known_index) {
19438
+ /* Find a free thread */
19439
+ int i = 0;
19440
+
19441
+ if(known_index < 0) {
19442
+ known_index = 0;
19443
+ }
19444
+
19445
+ assert(known_index < GGML_THREAD_POOL_SIZE);
19446
+
19447
+ for( i = known_index ; i < GGML_THREAD_POOL_SIZE + known_index; i++) {
19448
+ int zero = 0;
19449
+ int targ = i % GGML_THREAD_POOL_SIZE;
19450
+ struct ggml_thread_pool_context * ctx = &__thp.ctx[targ];
19451
+
19452
+ if( atomic_compare_exchange_weak(&ctx->flag, &zero, 1) ) {
19453
+ /* We have the thread */
19454
+ ctx->fn = fn;
19455
+ ctx->arg = arg;
19456
+
19457
+ /* Save ID*/
19458
+ th->id = i;
19459
+
19460
+ atomic_store(&ctx->executing, 1);
19461
+
19462
+ /* Is this thread already created ? */
19463
+ if(!ctx->has_thread) {
19464
+ ggml_thread_create(&ctx->thread, NULL, ggml_thread_pool_main, &__thp.ctx[i]);
19465
+ ctx->has_thread = 1;
19466
+ }
19467
+
19468
+ #ifndef GGML_THREADPOOL_ACTIVE_POLL
19469
+ /* Signal start */
19470
+ sem_post(&ctx->sem);
19471
+ #endif
19472
+ return 0;
19473
+ }
19474
+ }
19475
+
19476
+ /* if we are here we failed to get from pool create a "normal" thread and flag it withj ID -1 for handling in join */
19477
+ th->id = -1;
19478
+ return ggml_thread_create(&th->th, NULL, fn, arg);
19479
+ }
19480
+
19481
/** Wait for completion of a handle returned by
 *  ggml_thread_pool_create_thread and release its pool slot.
 *  If retval is non-NULL it receives the job function's return value.
 *  Returns 0 on success (or ggml_thread_join's result for non-pool threads).
 */
int ggml_thread_pool_join_thread(ggml_thread_pool_thread_t th, void **retval)
{
    /* Normal thread case: created outside the pool, join it directly */
    if(th.id < 0) {
        return ggml_thread_join(th.th, retval);
    }

    struct ggml_thread_pool_context * ctx = &__thp.ctx[th.id];

    /* Thread must be taken (slot still claimed by the caller) */
    assert(atomic_load(&ctx->flag) == 1);

    /* Busy-wait until the worker clears its `executing` flag */
    while(atomic_load(&ctx->executing)) {
        sched_yield();
    }

    /* Done executing if we are here: hand the result back */
    if(retval) {
        *retval = ctx->ret;
    }

    /* Clear the job BEFORE releasing the slot below, so a new claimer
       can never observe a stale fn/arg */
    ctx->fn = NULL;
    ctx->arg = NULL;

    /* Set the thread free (slot becomes reusable) */
    atomic_store(&ctx->flag, 0);

    return 0;
}
19510
+
19308
19511
// Android's libc implementation "bionic" does not support setting affinity
19309
19512
#if defined(__gnu_linux__)
19310
19513
static void set_numa_thread_affinity(int thread_n) {
@@ -20061,13 +20264,13 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
20061
20264
if (n_threads > 1) {
20062
20265
for (int j = 1; j < n_threads; ++j) {
20063
20266
workers[j] = (struct ggml_compute_state) {
20064
- .thrd = 0 ,
20267
+ .thrd = {0} ,
20065
20268
.ith = j,
20066
20269
.shared = &state_shared,
20067
20270
.ec = GGML_STATUS_SUCCESS,
20068
20271
};
20069
20272
20070
- const int rc = ggml_thread_create (&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
20273
+ const int rc = ggml_thread_pool_create_thread (&workers[j].thrd, ggml_graph_compute_thread, &workers[j], j );
20071
20274
GGML_ASSERT(rc == 0);
20072
20275
UNUSED(rc);
20073
20276
}
@@ -20090,7 +20293,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
20090
20293
// join or kill thread pool
20091
20294
if (n_threads > 1) {
20092
20295
for (int j = 1; j < n_threads; j++) {
20093
- const int rc = ggml_thread_join (workers[j].thrd, NULL);
20296
+ const int rc = ggml_thread_pool_join_thread (workers[j].thrd, NULL);
20094
20297
GGML_ASSERT(rc == 0);
20095
20298
if (workers[j].ec != GGML_STATUS_SUCCESS)
20096
20299
compute_status = workers[j].ec;
0 commit comments