From 2cbca03ff6516360b9dc04a281684161902a33b9 Mon Sep 17 00:00:00 2001
From: Divakar Verma
Date: Mon, 30 Sep 2024 13:53:42 +0000
Subject: [PATCH 1/2] add process_weights_after_loading for DummyLoader

---
 vllm/model_executor/model_loader/loader.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index c21b10d661ecc..dfebd81911219 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -441,6 +441,18 @@ def load_model(self, *, model_config: ModelConfig,
             # NOTE(woosuk): For accurate performance evaluation, we assign
             # random values to the weights.
             initialize_dummy_weights(model)
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    # When quant methods need to process weights after loading
+                    # (for repacking, quantizing, etc), they expect parameters
+                    # to be on the global target device. This scope is for the
+                    # case where cpu offloading is used, where we will move the
+                    # parameters onto device for processing and back off after.
+                    with device_loading_context(
+                            module, torch.device(device_config.device)):
+                        quant_method.process_weights_after_loading(module)
         return model.eval()
 
 

From 1231279c3f266559a6a7bbe3414f39f40a3dc049 Mon Sep 17 00:00:00 2001
From: Divakar Verma
Date: Mon, 30 Sep 2024 14:10:05 +0000
Subject: [PATCH 2/2] yapf fix

---
 vllm/model_executor/model_loader/loader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index dfebd81911219..8fed5267a9eb5 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -441,7 +441,7 @@ def load_model(self, *, model_config: ModelConfig,
             # NOTE(woosuk): For accurate performance evaluation, we assign
             # random values to the weights.
             initialize_dummy_weights(model)
-            
+
             for _, module in model.named_modules():
                 quant_method = getattr(module, "quant_method", None)
                 if quant_method is not None:
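
Note for reviewers: the hunk in PATCH 1/2 relies on the existing
device_loading_context helper in loader.py to cover the cpu-offloading case
described in the inline comment. As an illustrative sketch only (not the
exact vLLM implementation, which also handles details such as pinned memory
when copying tensors back to the CPU), such a context manager moves any
CPU-resident parameters of the module onto the target device, yields so
process_weights_after_loading can run on-device, and then moves them back:

    from contextlib import contextmanager
    from typing import Dict

    import torch


    @contextmanager
    def device_loading_context_sketch(module: torch.nn.Module,
                                      target_device: torch.device):
        """Illustrative simplification of loader.py's device_loading_context."""
        if target_device.type == "cpu":
            # Nothing to move when the target itself is the CPU.
            yield module
            return

        original_devices: Dict[str, torch.device] = {}

        # Move CPU-offloaded parameters onto the target device so that
        # quant_method.process_weights_after_loading() sees on-device tensors.
        for name, param in module.named_parameters():
            if param.device.type == "cpu":
                original_devices[name] = param.device
                param.data = param.data.to(target_device)

        try:
            yield module
        finally:
            # Move only the parameters we relocated back to their original
            # (CPU) devices once processing is done.
            for name, param in module.named_parameters():
                if name in original_devices:
                    param.data = param.data.to(original_devices[name])

Usage mirrors the hunk: entering the context before calling
quant_method.process_weights_after_loading(module) guarantees the quant
method sees on-device parameters even when cpu offloading is enabled.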