 import torch
+from vllm.platforms import current_platform
 from vllm.utils import seed_everything

@@ -28,4 +29,25 @@ def set_weight_attrs(
     for key, value in weight_attrs.items():
         assert not hasattr(
             weight, key), (f"Overwriting existing tensor attribute: {key}")
+
+        # NOTE(woosuk): During weight loading, we often do something like:
+        # narrowed_tensor = param.data.narrow(0, offset, len)
+        # narrowed_tensor.copy_(real_weight)
+        # expecting narrowed_tensor and param.data to share the same storage.
+        # However, on TPUs, narrowed_tensor will lazily propagate to the base
+        # tensor, which is param.data, leading to redundant memory usage.
+        # This sometimes causes OOM errors during model loading. To avoid this,
+        # we sync the param tensor after its weight loader is called.
+        # TODO(woosuk): Remove this hack once we have a better solution.
+        if current_platform.is_tpu() and key == "weight_loader":
+            value = _make_synced_weight_loader(value)
         setattr(weight, key, value)
+
+
+def _make_synced_weight_loader(original_weight_loader):
+
+    def _synced_weight_loader(param, *args, **kwargs):
+        original_weight_loader(param, *args, **kwargs)
+        torch._sync(param)
+
+    return _synced_weight_loader
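For reference, a minimal usage sketch of the new behavior (not part of the diff): it assumes set_weight_attrs is the helper imported from vllm.model_executor.utils, and the toy loader below is hypothetical, only illustrating the narrow-then-copy pattern described in the NOTE. On TPU, the attached weight_loader ends up being the wrapper produced by _make_synced_weight_loader; on other platforms the original loader is attached unchanged.

    # Usage sketch; set_weight_attrs import path assumed, toy loader hypothetical.
    import torch

    from vllm.model_executor.utils import set_weight_attrs


    def toy_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
        # Copy into a narrowed view of param.data, the pattern from the NOTE above.
        param.data.narrow(0, 0, loaded_weight.shape[0]).copy_(loaded_weight)


    param = torch.nn.Parameter(torch.empty(8, 4), requires_grad=False)
    set_weight_attrs(param, {"weight_loader": toy_weight_loader})

    # On TPU the stored loader is the synced wrapper, elsewhere the original;
    # either way it is invoked the same way during weight loading.
    param.weight_loader(param, torch.randn(4, 4))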