Skip to content

Support for additional vLLM server arguments #89

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 46 additions & 11 deletions vec_inf/cli/_cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Command line interface for Vector Inference."""

import time
from typing import Optional, Union
from typing import Optional, Union, cast

import click
from rich.console import Console
Expand Down Expand Up @@ -125,36 +125,71 @@ def cli() -> None:
is_flag=True,
help="Output in JSON string",
)
@click.option(
"--vllm-arg",
multiple=True,
help='Extra vLLM server args (use --vllm-arg "--foo=bar")',
)
def launch(
    model_name: str,
    vllm_arg: tuple[str, ...],
    json_mode: bool,
    **cli_kwargs: Optional[Union[str, int, float, bool]],
) -> None:
    """Launch a model on the cluster.

    Collects the CLI options into a ``LaunchOptions`` payload (including any
    extra vLLM server arguments supplied via repeated ``--vllm-arg`` flags),
    submits the launch through ``VecInfClient``, and prints the resulting
    configuration either as a JSON string or as a formatted table.

    Raises:
        click.ClickException: On any launch failure, with the underlying
            error chained as the cause.
    """
    try:
        # Parse extra vLLM args, e.g. --vllm-arg "--max-num-seqs=256".
        vllm_optional_args = _parse_vllm_optional_args(vllm_arg)

        # Drop unset CLI options so None values do not override the
        # model-config defaults resolved downstream.
        kwargs: dict[
            str, Union[str, int, float, bool, dict[str, Union[str, int, float, bool]]]
        ] = {k: v for k, v in cli_kwargs.items() if v is not None}
        kwargs["vllm_optional_args"] = vllm_optional_args

        # cast(): the dict is built dynamically, so the type checker cannot
        # verify it matches the TypedDict shape statically.
        options_dict: LaunchOptionsDict = cast(LaunchOptionsDict, kwargs)
        launch_options = LaunchOptions(**options_dict)

        # Launch the model inference server via the API client.
        client = VecInfClient()
        launch_response = client.launch_model(model_name, launch_options)

        # Display launch information.
        formatter = LaunchResponseFormatter(model_name, launch_response.config)
        if json_mode:
            click.echo(launch_response.config)
        else:
            CONSOLE.print(formatter.format_table_output())

    except click.ClickException as e:
        raise e
    except Exception as e:
        raise click.ClickException(f"Launch failed: {str(e)}") from e


def _parse_vllm_optional_args(
vllm_arg: tuple[str],
) -> dict[str, Union[str, int, float, bool]]:
parsed: dict[str, Union[str, int, float, bool]] = {}
for raw_arg in vllm_arg:
arg = raw_arg.removeprefix("--")
if "=" in arg:
key, val = arg.split("=", maxsplit=1)
if val.lower() == "true":
parsed[key.replace("-", "_")] = True
elif val.lower() == "false":
parsed[key.replace("-", "_")] = False
elif val.isdigit():
parsed[key.replace("-", "_")] = int(val)
else:
try:
parsed[key.replace("-", "_")] = float(val)
except ValueError:
parsed[key.replace("-", "_")] = val
else:
parsed[arg.replace("-", "_")] = True
return parsed


@cli.command("status")
@click.argument("slurm_job_id", type=int, nargs=1)
@click.option(
Expand Down
7 changes: 7 additions & 0 deletions vec_inf/cli/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ def format_table_output(self) -> Table:
)
if self.params.get("enforce_eager"):
table.add_row("Enforce Eager", self.params["enforce_eager"])

vllm_optional_args = self.params.get("vllm_optional_args")
if isinstance(vllm_optional_args, dict):
for key, val in vllm_optional_args.items():
label = f"VLLM: {key.replace('_', ' ').title()}"
table.add_row(label, str(val))

table.add_row(
"Model Weights Directory",
str(Path(self.params["model_weights_parent_dir"], self.model_name)),
Expand Down
16 changes: 12 additions & 4 deletions vec_inf/client/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def __init__(self, model_name: str, kwargs: Optional[dict[str, Any]]):
"""
self.model_name = model_name
self.kwargs = kwargs or {}
self.vllm_optional_args: dict[str, Any] = (
self.kwargs.pop("vllm_optional_args", {}) or {}
)
self.slurm_job_id = ""
self.slurm_script_path = Path("")
self.model_config = self._get_model_configuration()
Expand Down Expand Up @@ -169,7 +172,7 @@ def _build_launch_command(self) -> str:
)
# Add slurm script
self.slurm_script_path = SlurmScriptGenerator(
self.params, SRC_DIR
self.params, self.vllm_optional_args, SRC_DIR
).write_to_log_dir()
command_list.append(str(self.slurm_script_path))
return " ".join(command_list)
Expand Down Expand Up @@ -204,13 +207,18 @@ def launch(self) -> LaunchResponse:
job_log_dir / f"{self.model_name}.{self.slurm_job_id}.slurm"
)

with job_json.open("w") as file:
json.dump(self.params, file, indent=4)
json_payload = {
**self.params,
**self.vllm_optional_args,
}

with job_json.open("w") as f:
json.dump(json_payload, f, indent=4)

return LaunchResponse(
slurm_job_id=int(self.slurm_job_id),
model_name=self.model_name,
config=self.params,
config={**self.params, "vllm_optional_args": self.vllm_optional_args},
raw_output=command_output,
)

Expand Down
2 changes: 2 additions & 0 deletions vec_inf/client/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class LaunchOptions:
pipeline_parallelism: Optional[bool] = None
compilation_config: Optional[str] = None
enforce_eager: Optional[bool] = None
vllm_optional_args: Optional[dict[str, Union[str, int, float, bool]]] = None


class LaunchOptionsDict(TypedDict):
Expand All @@ -115,6 +116,7 @@ class LaunchOptionsDict(TypedDict):
pipeline_parallelism: NotRequired[Optional[bool]]
compilation_config: NotRequired[Optional[str]]
enforce_eager: NotRequired[Optional[bool]]
vllm_optional_args: NotRequired[dict[str, Union[str, int, float, bool]]]


@dataclass
Expand Down
13 changes: 12 additions & 1 deletion vec_inf/client/_slurm_script_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@


class SlurmScriptGenerator:
def __init__(self, params: dict[str, Any], src_dir: str):
def __init__(
self, params: dict[str, Any], vllm_optional_args: dict[str, Any], src_dir: str
):
self.params = params
self.vllm_optional_args = vllm_optional_args
self.src_dir = src_dir
self.is_multinode = int(self.params["num_nodes"]) > 1
self.model_weights_path = str(
Expand Down Expand Up @@ -74,6 +77,14 @@ def _generate_shared_args(self) -> str:
if self.params.get("enforce_eager") == "True":
args.append("--enforce-eager")

for key, value in self.vllm_optional_args.items():
cli_key = key.replace("_", "-")
if isinstance(value, bool):
if value:
args.append(f"--{cli_key} \\")
else:
args.append(f"--{cli_key} {value} \\")

return "\n".join(args)

def _generate_server_script(self) -> str:
Expand Down
Loading