Skip to content

Commit

Permalink
Add --buffer-size argument
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed Sep 11, 2024
1 parent 0c7af7a commit 5a26174
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 1 deletion.
4 changes: 4 additions & 0 deletions candle-holder-serve/entrypoint-cuda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ if [ ! -z "$CANDLE_HOLDER_NUM_WORKERS" ]; then
ARGS+=("--num-workers" "$CANDLE_HOLDER_NUM_WORKERS")
fi

if [ ! -z "$CANDLE_HOLDER_BUFFER_SIZE" ]; then
ARGS+=("--buffer-size" "$CANDLE_HOLDER_BUFFER_SIZE")
fi

CANDLE_HOLDER_HOST=${CANDLE_HOLDER_HOST:-0.0.0.0:8080}
ARGS+=("--host" "$CANDLE_HOLDER_HOST")

Expand Down
4 changes: 4 additions & 0 deletions candle-holder-serve/entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ if [ ! -z "$CANDLE_HOLDER_NUM_WORKERS" ]; then
ARGS+=("--num-workers" "$CANDLE_HOLDER_NUM_WORKERS")
fi

if [ ! -z "$CANDLE_HOLDER_BUFFER_SIZE" ]; then
ARGS+=("--buffer-size" "$CANDLE_HOLDER_BUFFER_SIZE")
fi

CANDLE_HOLDER_HOST=${CANDLE_HOLDER_HOST:-0.0.0.0:8080}
ARGS+=("--host" "$CANDLE_HOLDER_HOST")

Expand Down
8 changes: 8 additions & 0 deletions candle-holder-serve/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ pub(crate) struct Cli {
/// The number of workers to use for inference.
#[arg(long, default_value = "1")]
num_workers: usize,

/// Channel buffer size for the inference worker.
#[arg(long, default_value = "32")]
buffer_size: usize,
}

impl Cli {
Expand All @@ -49,6 +53,10 @@ impl Cli {
self.num_workers
}

pub fn buffer_size(&self) -> usize {
self.buffer_size
}

/// Get the [`candle_core::Device`] corresponding to the selected device option.
///
/// # Errors
Expand Down
3 changes: 2 additions & 1 deletion candle-holder-serve/src/router_macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ macro_rules! generate_router {
tracing::error!("Failed to warm up the model: {}", e);
});

tracing::info!("Channel buffer size: {}", args.buffer_size());
let (tx, rx) =
mpsc::channel::<InferenceTask<$request, Result<$response, ErrorResponse>>>(32);
mpsc::channel::<InferenceTask<$request, Result<$response, ErrorResponse>>>(args.buffer_size());

tokio::spawn(task_distributor::<
$pipeline,
Expand Down

0 comments on commit 5a26174

Please # to comment.