Commit 485c652

[TPU] Call torch._sync(param) during weight loading (vllm-project#9437)
Parent: 4d647bd

1 file changed: vllm/model_executor/utils.py (+22, -0)

@@ -3,6 +3,7 @@
 
 import torch
 
+from vllm.platforms import current_platform
 from vllm.utils import seed_everything
 
 
@@ -28,4 +29,25 @@ def set_weight_attrs(
     for key, value in weight_attrs.items():
         assert not hasattr(
             weight, key), (f"Overwriting existing tensor attribute: {key}")
+
+        # NOTE(woosuk): During weight loading, we often do something like:
+        # narrowed_tensor = param.data.narrow(0, offset, len)
+        # narrowed_tensor.copy_(real_weight)
+        # expecting narrowed_tensor and param.data to share the same storage.
+        # However, on TPUs, narrowed_tensor will lazily propagate to the base
+        # tensor, which is param.data, leading to the redundant memory usage.
+        # This sometimes causes OOM errors during model loading. To avoid this,
+        # we sync the param tensor after its weight loader is called.
+        # TODO(woosuk): Remove this hack once we have a better solution.
+        if current_platform.is_tpu() and key == "weight_loader":
+            value = _make_synced_weight_loader(value)
         setattr(weight, key, value)
+
+
+def _make_synced_weight_loader(original_weight_loader):
+
+    def _synced_weight_loader(param, *args, **kwargs):
+        original_weight_loader(param, *args, **kwargs)
+        torch._sync(param)
+
+    return _synced_weight_loader
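
For readers unfamiliar with the pattern, the sketch below illustrates what the wrapper does: the loader writes into a narrowed view of param.data, and the wrapper then runs a post-load hook on the full parameter. This is a minimal, self-contained sketch, not vLLM's actual API: make_synced_weight_loader, toy_weight_loader, and post_load_hook are illustrative names, and the hook is a no-op here so the script runs on CPU. On a TPU the hook would be torch._sync(param), which, per the commit's NOTE, syncs the parameter so the narrowed view's writes do not leave redundant memory behind.

import torch


def make_synced_weight_loader(original_weight_loader, post_load_hook):
    # Wrap a weight loader so a hook runs on the parameter right after loading.
    def synced_weight_loader(param, *args, **kwargs):
        original_weight_loader(param, *args, **kwargs)
        post_load_hook(param)  # in the commit, this is torch._sync(param) on TPU

    return synced_weight_loader


def toy_weight_loader(param, loaded_weight, offset=0):
    # The pattern described in the NOTE above: write into a narrowed view,
    # relying on it sharing storage with param.data.
    narrowed = param.data.narrow(0, offset, loaded_weight.shape[0])
    narrowed.copy_(loaded_weight)


if __name__ == "__main__":
    param = torch.nn.Parameter(torch.zeros(4, 2), requires_grad=False)
    loader = make_synced_weight_loader(toy_weight_loader,
                                       post_load_hook=lambda p: None)
    loader(param, torch.ones(2, 2), offset=1)
    print(param.data)  # rows 1 and 2 now hold the loaded ones

In the commit itself the wrapper is attached only when current_platform.is_tpu() and the attribute key is "weight_loader", so CPU and GPU code paths are unaffected.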
