diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 3918e3e867695..9d05ff4c2cfdd 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -531,6 +531,7 @@ steps:
   - pip uninstall vllm_add_dummy_platform -y
   # end platform plugin tests
   # other tests continue here:
+  - pytest -v -s plugins_tests/test_scheduler_plugins.py
   - pip install -e ./plugins/vllm_add_dummy_model
   - pytest -v -s distributed/test_distributed_oot.py
   - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
diff --git a/tests/plugins_tests/test_scheduler_plugins.py b/tests/plugins_tests/test_scheduler_plugins.py
new file mode 100644
index 0000000000000..84688cee96609
--- /dev/null
+++ b/tests/plugins_tests/test_scheduler_plugins.py
@@ -0,0 +1,33 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from vllm.core.scheduler import Scheduler
+
+
+class DummyScheduler(Scheduler):
+
+    def schedule(self):
+        raise Exception("Exception raised by DummyScheduler")
+
+
+def test_scheduler_plugins():
+    import pytest
+
+    from vllm.engine.arg_utils import EngineArgs
+    from vllm.engine.llm_engine import LLMEngine
+    from vllm.sampling_params import SamplingParams
+
+    with pytest.raises(Exception) as exception_info:
+
+        engine_args = EngineArgs(
+            model="facebook/opt-125m",
+            enforce_eager=True,  # reduce test time
+            scheduler_cls=DummyScheduler,
+        )
+
+        engine = LLMEngine.from_engine_args(engine_args=engine_args)
+
+        sampling_params = SamplingParams(max_tokens=1)
+        engine.add_request("0", "foo", sampling_params)
+        engine.step()
+
+    assert str(exception_info.value) == "Exception raised by DummyScheduler"
diff --git a/vllm/config.py b/vllm/config.py
index 59fa60fd8b0c2..56315aacbe517 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1495,6 +1495,10 @@ class SchedulerConfig:
 
     chunked_prefill_enabled: bool = field(init=False)
 
+    # scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
+    # or "mod.custom_class".
+    scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 5f076f05d0465..78681008b62ef 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -192,6 +192,7 @@ class EngineArgs:
     collect_detailed_traces: Optional[str] = None
     disable_async_output_proc: bool = False
     scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
+    scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
 
     override_neuron_config: Optional[Dict[str, Any]] = None
     override_pooler_config: Optional[PoolerConfig] = None
@@ -938,6 +939,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             'priority (lower value means earlier handling) and time of '
             'arrival deciding any ties).')
 
+        parser.add_argument(
+            '--scheduler-cls',
+            default=EngineArgs.scheduler_cls,
+            help='The scheduler class to use. "vllm.core.scheduler.Scheduler" '
+            'is the default scheduler. Can be a class directly or the path to '
+            'a class of form "mod.custom_class".')
+
         parser.add_argument(
             '--override-neuron-config',
             type=json.loads,
@@ -1273,10 +1281,12 @@ def create_engine_config(self,
             send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                              and parallel_config.use_ray),
             policy=self.scheduling_policy,
+            scheduler_cls=self.scheduler_cls,
             max_num_partial_prefills=self.max_num_partial_prefills,
             max_long_partial_prefills=self.max_long_partial_prefills,
             long_prefill_token_threshold=self.long_prefill_token_threshold,
         )
+
         lora_config = LoRAConfig(
             bias_enabled=self.enable_lora_bias,
             max_lora_rank=self.max_lora_rank,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 2e5bc75c6db38..3ce9a0461368d 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -19,8 +19,7 @@
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                          ObservabilityConfig, ParallelConfig, SchedulerConfig,
                          VllmConfig)
-from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
-                                 SchedulerOutputs)
+from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.metrics_types import StatLoggerBase, Stats
 from vllm.engine.output_processor.interfaces import (
@@ -58,7 +57,8 @@
     BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
-from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
+from vllm.utils import (Counter, Device, deprecate_kwargs,
+                        resolve_obj_by_qualname, weak_bind)
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -346,6 +346,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
         # Create the scheduler.
         # NOTE: the cache_config here have been updated with the numbers of
         # GPU and CPU blocks, which are profiled in the distributed executor.
+        if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
+            Scheduler = resolve_obj_by_qualname(
+                self.vllm_config.scheduler_config.scheduler_cls)
+        else:
+            Scheduler = self.vllm_config.scheduler_config.scheduler_cls
         self.scheduler = [
             Scheduler(
                 self.scheduler_config, self.cache_config, self.lora_config,
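
For reference, here is a minimal sketch of how an out-of-tree scheduler could be plugged in through the new `scheduler_cls` hook. The module name `my_scheduler.py` and the class name `CustomScheduler` are illustrative, not part of this change; only `Scheduler`, `EngineArgs`, `LLMEngine`, and `SamplingParams` come from the diff above.

```python
# my_scheduler.py -- hypothetical plugin module; any importable
# "mod.custom_class" path works the same way.
from vllm.core.scheduler import Scheduler


class CustomScheduler(Scheduler):
    """Delegates to the default scheduler; hook custom logic in schedule()."""

    def schedule(self):
        # Custom admission or reordering logic could run here
        # before delegating to the built-in implementation.
        return super().schedule()


if __name__ == "__main__":
    from vllm.engine.arg_utils import EngineArgs
    from vllm.engine.llm_engine import LLMEngine
    from vllm.sampling_params import SamplingParams

    engine_args = EngineArgs(
        model="facebook/opt-125m",
        enforce_eager=True,
        # Pass the class object directly, or the qualified path string
        # "my_scheduler.CustomScheduler"; strings are resolved via
        # resolve_obj_by_qualname in LLMEngine.
        scheduler_cls=CustomScheduler,
    )
    engine = LLMEngine.from_engine_args(engine_args=engine_args)
    engine.add_request("0", "foo", SamplingParams(max_tokens=1))
    engine.step()
```

The string form is what the new `--scheduler-cls` flag accepts on the command line (e.g. `--scheduler-cls my_scheduler.CustomScheduler`), provided the module is importable in the server's environment.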