diff --git a/candle-holder-serve/entrypoint-cuda.sh b/candle-holder-serve/entrypoint-cuda.sh index 895eb53..cf4101e 100644 --- a/candle-holder-serve/entrypoint-cuda.sh +++ b/candle-holder-serve/entrypoint-cuda.sh @@ -35,6 +35,10 @@ if [ ! -z "$CANDLE_HOLDER_NUM_WORKERS" ]; then ARGS+=("--num-workers" "$CANDLE_HOLDER_NUM_WORKERS") fi +if [ ! -z "$CANDLE_HOLDER_BUFFER_SIZE" ]; then + ARGS+=("--buffer-size" "$CANDLE_HOLDER_BUFFER_SIZE") +fi + CANDLE_HOLDER_HOST=${CANDLE_HOLDER_HOST:-0.0.0.0:8080} ARGS+=("--host" "$CANDLE_HOLDER_HOST") diff --git a/candle-holder-serve/entrypoint.sh b/candle-holder-serve/entrypoint.sh index 32e1f2f..4cec613 100644 --- a/candle-holder-serve/entrypoint.sh +++ b/candle-holder-serve/entrypoint.sh @@ -23,6 +23,14 @@ if [ ! -z "$CANDLE_HOLDER_DTYPE" ]; then ARGS+=("--dtype" "$CANDLE_HOLDER_DTYPE") fi +if [ ! -z "$CANDLE_HOLDER_NUM_WORKERS" ]; then + ARGS+=("--num-workers" "$CANDLE_HOLDER_NUM_WORKERS") +fi + +if [ ! -z "$CANDLE_HOLDER_BUFFER_SIZE" ]; then + ARGS+=("--buffer-size" "$CANDLE_HOLDER_BUFFER_SIZE") +fi + CANDLE_HOLDER_HOST=${CANDLE_HOLDER_HOST:-0.0.0.0:8080} ARGS+=("--host" "$CANDLE_HOLDER_HOST") diff --git a/candle-holder-serve/src/cli.rs b/candle-holder-serve/src/cli.rs index 7947877..13ae5b0 100644 --- a/candle-holder-serve/src/cli.rs +++ b/candle-holder-serve/src/cli.rs @@ -30,6 +30,10 @@ pub(crate) struct Cli { /// The number of workers to use for inference. #[arg(long, default_value = "1")] num_workers: usize, + + /// Channel buffer size for the inference worker. + #[arg(long, default_value = "32")] + buffer_size: usize, } impl Cli { @@ -49,6 +53,10 @@ impl Cli { self.num_workers } + pub fn buffer_size(&self) -> usize { + self.buffer_size + } + /// Get the [`candle_core::Device`] corresponding to the selected device option. /// /// # Errors diff --git a/candle-holder-serve/src/router_macro.rs b/candle-holder-serve/src/router_macro.rs index 64c078d..8b4e4a6 100644 --- a/candle-holder-serve/src/router_macro.rs +++ b/candle-holder-serve/src/router_macro.rs @@ -32,8 +32,9 @@ macro_rules! generate_router { tracing::error!("Failed to warm up the model: {}", e); }); + tracing::info!("Channel buffer size: {}", args.buffer_size()); let (tx, rx) = - mpsc::channel::>>(32); + mpsc::channel::>>(args.buffer_size()); tokio::spawn(task_distributor::< $pipeline,