Skip to content

Commit 4e9c8f9

Browse files
Isotr0pyrickyyx
authored andcommitted
[Misc] Add uninitialized params tracking for AutoWeightsLoader (vllm-project#10327)
Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: rickyx <rickyx@anyscale.com>
1 parent 48443fb commit 4e9c8f9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

74 files changed

+454
-185
lines changed

vllm/model_executor/model_loader/loader.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,17 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
334334
with target_device:
335335
model = _initialize_model(vllm_config=vllm_config)
336336

337-
model.load_weights(self._get_all_weights(model_config, model))
337+
weights_to_load = {name for name, _ in model.named_parameters()}
338+
loaded_weights = model.load_weights(
339+
self._get_all_weights(model_config, model))
340+
# We only enable strict check for non-quantiized models
341+
# that have loaded weights tracking currently.
342+
if model_config.quantization is None and loaded_weights is not None:
343+
weights_not_loaded = weights_to_load - loaded_weights
344+
if weights_not_loaded:
345+
raise ValueError(
346+
"Following weights were not initialized from "
347+
f"checkpoint: {weights_not_loaded}")
338348

339349
for _, module in model.named_modules():
340350
quant_method = getattr(module, "quant_method", None)

vllm/model_executor/models/arctic.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Inference-only Snowflake Arctic model."""
2-
from typing import Iterable, List, Optional, Tuple, Union
2+
from typing import Iterable, List, Optional, Set, Tuple, Union
33

44
import torch
55
from torch import nn
@@ -480,7 +480,8 @@ def sample(
480480
next_tokens = self.sampler(logits, sampling_metadata)
481481
return next_tokens
482482

483-
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
483+
def load_weights(self, weights: Iterable[Tuple[str,
484+
torch.Tensor]]) -> Set[str]:
484485
stacked_params_mapping = [
485486
# (param_name, shard_name, shard_id)
486487
("qkv_proj", "q_proj", "q"),
@@ -518,6 +519,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
518519
("ws", f"experts.{expert_id}.w3.weight", expert_id))
519520

520521
params_dict = dict(self.named_parameters())
522+
loaded_params: Set[str] = set()
521523

522524
logger.info(
523525
"It will take ~10 minutes loading from the 16-bit weights. "
@@ -573,3 +575,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
573575
weight_loader = getattr(param, "weight_loader",
574576
default_weight_loader)
575577
weight_loader(param, loaded_weight)
578+
loaded_params.add(name)
579+
return loaded_params

vllm/model_executor/models/baichuan.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
# limitations under the License.
1919
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
2020
import math
21-
from typing import Iterable, List, Optional, Tuple, Union
21+
from typing import Iterable, List, Optional, Set, Tuple, Union
2222

2323
import torch
2424
from torch import nn
@@ -404,13 +404,15 @@ def sample(
404404
next_tokens = self.sampler(logits, sampling_metadata)
405405
return next_tokens
406406

407-
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
407+
def load_weights(self, weights: Iterable[Tuple[str,
408+
torch.Tensor]]) -> Set[str]:
408409
stacked_params_mapping = [
409410
# (param_name, shard_name, shard_id)
410411
("gate_up_proj", "gate_proj", 0),
411412
("gate_up_proj", "up_proj", 1),
412413
]
413414
params_dict = dict(self.named_parameters())
415+
loaded_params: Set[str] = set()
414416
for name, loaded_weight in weights:
415417
if "rotary_emb.inv_freq" in name:
416418
continue
@@ -449,6 +451,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
449451
weight_loader = getattr(param, "weight_loader",
450452
default_weight_loader)
451453
weight_loader(param, loaded_weight)
454+
loaded_params.add(name)
455+
return loaded_params
452456

453457

454458
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):

vllm/model_executor/models/bert.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Iterable, List, Optional, Tuple
1+
from typing import Iterable, List, Optional, Set, Tuple
22

33
import torch
44
from torch import nn
@@ -337,7 +337,8 @@ def forward(
337337

338338
return self.encoder(hidden_states, kv_caches, attn_metadata)
339339

340-
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
340+
def load_weights(self, weights: Iterable[Tuple[str,
341+
torch.Tensor]]) -> Set[str]:
341342
stacked_params_mapping = [
342343
# (param_name, shard_name, shard_id)
343344
("qkv_proj", "query", "q"),
@@ -346,6 +347,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
346347
]
347348

348349
params_dict = dict(self.named_parameters())
350+
loaded_params: Set[str] = set()
349351
for name, loaded_weight in weights:
350352
if "pooler" in name:
351353
continue
@@ -368,6 +370,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
368370
weight_loader = getattr(param, "weight_loader",
369371
default_weight_loader)
370372
weight_loader(param, loaded_weight)
373+
loaded_params.add(name)
374+
return loaded_params
371375

372376

373377
class BertEmbeddingModel(nn.Module):

vllm/model_executor/models/blip.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Minimal implementation of BlipVisionModel intended to be only used
22
within a vision language model."""
3-
from typing import Iterable, Optional, Tuple, Union
3+
from typing import Iterable, Optional, Set, Tuple, Union
44

55
import torch
66
import torch.nn as nn
@@ -415,14 +415,16 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
415415

416416
return self.post_layernorm(hidden_states)
417417

418-
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
418+
def load_weights(self, weights: Iterable[Tuple[str,
419+
torch.Tensor]]) -> Set[str]:
419420
stacked_params_mapping = [
420421
# (param_name, shard_name, shard_id)
421422
("qkv_proj", "q_proj", "q"),
422423
("qkv_proj", "k_proj", "k"),
423424
("qkv_proj", "v_proj", "v"),
424425
] if self.shard_weight else []
425426
params_dict = dict(self.named_parameters())
427+
loaded_params: Set[str] = set()
426428
layer_count = len(self.encoder.layers)
427429

428430
for name, loaded_weight in weights:
@@ -440,8 +442,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
440442
for (param_name, weight_name, shard_id) in stacked_params_mapping:
441443
if weight_name not in name:
442444
continue
443-
444-
param = params_dict[name.replace(weight_name, param_name)]
445+
name = name.replace(weight_name, param_name)
446+
param = params_dict[name]
445447
weight_loader = param.weight_loader
446448
weight_loader(param, loaded_weight, shard_id)
447449
break
@@ -450,3 +452,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
450452
weight_loader = getattr(param, "weight_loader",
451453
default_weight_loader)
452454
weight_loader(param, loaded_weight)
455+
loaded_params.add(name)
456+
return loaded_params

vllm/model_executor/models/blip2.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import cached_property
2-
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
2+
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
33
TypedDict, Union)
44

55
import torch
@@ -692,6 +692,7 @@ def sample(
692692
) -> Optional[SamplerOutput]:
693693
return self.language_model.sample(logits, sampling_metadata)
694694

695-
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
695+
def load_weights(self, weights: Iterable[Tuple[str,
696+
torch.Tensor]]) -> Set[str]:
696697
loader = AutoWeightsLoader(self)
697-
loader.load_weights(weights)
698+
return loader.load_weights(weights)

vllm/model_executor/models/bloom.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# limitations under the License.
1717
"""Inference-only BLOOM model compatible with HuggingFace weights."""
1818
import math
19-
from typing import Iterable, List, Optional, Tuple, Union
19+
from typing import Iterable, List, Optional, Set, Tuple, Union
2020

2121
import torch
2222
from torch import nn
@@ -341,8 +341,10 @@ def sample(
341341
next_tokens = self.sampler(logits, sampling_metadata)
342342
return next_tokens
343343

344-
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
344+
def load_weights(self, weights: Iterable[Tuple[str,
345+
torch.Tensor]]) -> Set[str]:
345346
params_dict = dict(self.named_parameters(remove_duplicate=False))
347+
loaded_params: Set[str] = set()
346348
for name, loaded_weight in weights:
347349
if name == "lm_head.weight":
348350
continue
@@ -371,3 +373,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
371373
weight_loader = getattr(param, "weight_loader",
372374
default_weight_loader)
373375
weight_loader(param, loaded_weight)
376+
loaded_params.add(name)
377+
return loaded_params

vllm/model_executor/models/chameleon.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import cached_property
2-
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
2+
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
33
Tuple, TypedDict, Union)
44

55
import torch
@@ -1034,7 +1034,8 @@ def sample(
10341034
next_tokens = self.sampler(logits, sampling_metadata)
10351035
return next_tokens
10361036

1037-
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
1037+
def load_weights(self, weights: Iterable[Tuple[str,
1038+
torch.Tensor]]) -> Set[str]:
10381039
stacked_params_mapping = [
10391040
# (param_name, shard_name, shard_id)
10401041
(".qkv_proj", ".q_proj", "q"),
@@ -1044,6 +1045,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
10441045
(".gate_up_proj", ".up_proj", 1),
10451046
]
10461047
params_dict = dict(self.named_parameters())
1048+
loaded_params: Set[str] = set()
10471049
for name, loaded_weight in weights:
10481050
if "rotary_emb.inv_freq" in name:
10491051
continue
@@ -1111,3 +1113,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
11111113
weight_loader = getattr(param, "weight_loader",
11121114
default_weight_loader)
11131115
weight_loader(param, loaded_weight)
1116+
loaded_params.add(name)
1117+
return loaded_params

vllm/model_executor/models/chatglm.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
"""Inference-only ChatGLM model compatible with THUDM weights."""
44
from argparse import Namespace
55
from array import array
6-
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict
6+
from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
7+
TypedDict)
78

89
import torch
910
from PIL import Image
@@ -645,7 +646,8 @@ def sample(
645646
next_tokens = self.sampler(logits, sampling_metadata)
646647
return next_tokens
647648

648-
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
649+
def load_weights(self, weights: Iterable[Tuple[str,
650+
torch.Tensor]]) -> Set[str]:
649651
# Merge two ColumnParallelLinear into one MergedColumnParallelLinear
650652
merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
651653
"transformer.vision.linear_proj.merged_proj.weight": {
@@ -655,6 +657,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
655657
}
656658

657659
params_dict = dict(self.named_parameters(remove_duplicate=False))
660+
loaded_params: Set[str] = set()
658661
for name, loaded_weight in weights:
659662
is_weight_to_be_merge = False
660663
for _, merged_weight_dict in merged_weights_dict.items():
@@ -677,6 +680,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
677680
weight_loader = getattr(param, "weight_loader",
678681
default_weight_loader)
679682
weight_loader(param, loaded_weight)
683+
loaded_params.add(name)
680684

681685
for combined_name, merged_weight_dict in merged_weights_dict.items():
682686
if combined_name in params_dict:
@@ -686,3 +690,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
686690
weight_loader = getattr(param, "weight_loader",
687691
default_weight_loader)
688692
weight_loader(param, combined_weight)
693+
loaded_params.add(combined_name)
694+
return loaded_params

vllm/model_executor/models/clip.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Minimal implementation of CLIPVisionModel intended to be only used
22
within a vision language model."""
3-
from typing import Iterable, List, Optional, Tuple, Union
3+
from typing import Iterable, List, Optional, Set, Tuple, Union
44

55
import numpy as np
66
import torch
@@ -483,14 +483,16 @@ def device(self):
483483

484484
# (TODO) Add prefix argument for filtering out weights to be loaded
485485
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
486-
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
486+
def load_weights(self, weights: Iterable[Tuple[str,
487+
torch.Tensor]]) -> Set[str]:
487488
stacked_params_mapping = [
488489
# (param_name, shard_name, shard_id)
489490
("qkv_proj", "q_proj", "q"),
490491
("qkv_proj", "k_proj", "k"),
491492
("qkv_proj", "v_proj", "v"),
492493
] if self.shard_weight else []
493494
params_dict = dict(self.named_parameters())
495+
loaded_params: Set[str] = set()
494496
layer_count = len(self.vision_model.encoder.layers)
495497

496498
for name, loaded_weight in weights:
@@ -508,8 +510,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
508510
for (param_name, weight_name, shard_id) in stacked_params_mapping:
509511
if weight_name not in name:
510512
continue
513+
name = name.replace(weight_name, param_name)
511514

512-
param = params_dict[name.replace(weight_name, param_name)]
515+
param = params_dict[name]
513516
weight_loader = param.weight_loader
514517
weight_loader(param, loaded_weight, shard_id)
515518
break
@@ -518,3 +521,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
518521
weight_loader = getattr(param, "weight_loader",
519522
default_weight_loader)
520523
weight_loader(param, loaded_weight)
524+
loaded_params.add(name)
525+
return loaded_params

vllm/model_executor/models/commandr.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,8 @@ def sample(
402402
next_tokens = self.sampler(logits, sampling_metadata)
403403
return next_tokens
404404

405-
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
405+
def load_weights(self, weights: Iterable[Tuple[str,
406+
torch.Tensor]]) -> Set[str]:
406407
stacked_params_mapping = [
407408
# (param_name, shard_name, shard_id)
408409
("qkv_proj", "q_proj", "q"),
@@ -447,3 +448,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
447448
default_weight_loader)
448449
weight_loader(param, loaded_weight)
449450
loaded_params.add(name)
451+
return loaded_params

vllm/model_executor/models/dbrx.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Iterable, List, Optional, Tuple, Union
1+
from typing import Iterable, List, Optional, Set, Tuple, Union
22

33
import torch
44
import torch.nn as nn
@@ -417,13 +417,15 @@ def sample(
417417
next_tokens = self.sampler(logits, sampling_metadata)
418418
return next_tokens
419419

420-
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
420+
def load_weights(self, weights: Iterable[Tuple[str,
421+
torch.Tensor]]) -> Set[str]:
421422

422423
expert_params_mapping = [(
423424
"w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
424425
f"mlp.{weight_name}",
425426
) for weight_name in ["w1", "v1", "w2"]]
426427
params_dict = dict(self.named_parameters(remove_duplicate=False))
428+
loaded_params: Set[str] = set()
427429
for name, loaded_weight in weights:
428430
for param_name, weight_name in expert_params_mapping:
429431
if weight_name not in name:
@@ -447,3 +449,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
447449
weight_loader = getattr(param, "weight_loader",
448450
default_weight_loader)
449451
weight_loader(param, loaded_weight)
452+
loaded_params.add(name)
453+
return loaded_params

0 commit comments

Comments
 (0)