Commit 28537c4

youkaichao authored and jimpang committed
[core] fix sleep mode in pytorch 2.6 (vllm-project#13456)
Signed-off-by: youkaichao <youkaichao@gmail.com>
1 parent: 7fbc0df

File tree

1 file changed: +10 −3 lines changed


vllm/device_allocator/cumem.py (+10 −3)
@@ -9,7 +9,7 @@
 # the only successful approach is to call cuda driver API in C.
 import dataclasses
 from contextlib import contextmanager
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union

 import torch

@@ -97,7 +97,7 @@ def use_memory_pool_with_allocator(
     new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
     mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
     with torch.cuda.memory.use_mem_pool(mem_pool):
-        yield mem_pool
+        yield mem_pool, new_alloc


 class CuMemAllocator:
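
The hunk above widens the context manager's contract: instead of yielding only the MemPool, use_memory_pool_with_allocator now yields the pool together with the pluggable allocator that backs it, so the caller can hold a reference to both. Below is a minimal, CUDA-free sketch of that pattern; ToyAllocator, ToyPool, and use_toy_pool are hypothetical illustration names, not vLLM or PyTorch APIs.

# Sketch: a contextmanager that yields both the pool and the allocator
# backing it, mirroring `yield mem_pool, new_alloc` in the hunk above.
from contextlib import contextmanager
from typing import Iterator, Tuple


class ToyAllocator:
    def malloc(self, size: int) -> bytearray:
        return bytearray(size)


class ToyPool:
    def __init__(self, allocator: ToyAllocator) -> None:
        # The pool references its allocator, but callers may still need
        # their own handle to keep Python-level state reachable.
        self.allocator = allocator


@contextmanager
def use_toy_pool() -> Iterator[Tuple[ToyPool, ToyAllocator]]:
    alloc = ToyAllocator()
    pool = ToyPool(alloc)
    # Yield the pair so the caller can stash it and neither object is
    # garbage collected while allocations are still in flight.
    yield pool, alloc


with use_toy_pool() as (pool, alloc):
    buf = pool.allocator.malloc(16)
    assert len(buf) == 16
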
@@ -142,6 +142,7 @@ def get_instance() -> "CuMemAllocator":
     def __init__(self):
         self.pointer_to_data: Dict[int, AllocationData] = {}
         self.current_tag: str = CuMemAllocator.default_tag
+        self.allocator_and_pools: Dict[str, Any] = {}

     def python_malloc_callback(self, allocation_handle: HandleType) -> None:
         """
@@ -231,7 +232,13 @@ def use_memory_pool(self, tag: Optional[str] = None):
         old_tag = self.current_tag
         self.current_tag = tag
         with use_memory_pool_with_allocator(self.python_malloc_callback,
-                                            self.python_free_callback):
+                                            self.python_free_callback) as data:
+            # start to hit another PyTorch bug in PyTorch 2.6,
+            # possibly because of gc-related issue w.r.t. the allocator and
+            # the memory pool.
+            # to avoid the issue, we keep a reference of the data.
+            # see https://github.com/pytorch/pytorch/issues/146431 .
+            self.allocator_and_pools[tag] = data
             yield
             # PyTorch's bug, calling torch.cuda.empty_cache() will error
             # when using pluggable allocator, see
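
The new comment references https://github.com/pytorch/pytorch/issues/146431: under PyTorch 2.6, if the only reference to the allocator/pool pair lives inside the generator frame of the context manager, the pair can be garbage collected while the pool's memory is still in use. Stashing it in self.allocator_and_pools, keyed by tag, pins it for the allocator's lifetime. The snippet below is a hedged, CUDA-free demonstration of that keep-alive idea; FakeAllocator, keep_alive, and make_allocator are hypothetical names invented for this illustration.

# Demonstrates why holding a strong reference (as the hunk above does with
# self.allocator_and_pools[tag] = data) prevents premature collection.
import gc
import weakref


class FakeAllocator:
    pass


keep_alive: dict = {}  # plays the role of self.allocator_and_pools


def make_allocator(tag: str, pin: bool) -> weakref.ref:
    alloc = FakeAllocator()
    if pin:
        keep_alive[tag] = alloc  # strong reference outlives this frame
    return weakref.ref(alloc)


unpinned = make_allocator("kv_cache", pin=False)
pinned = make_allocator("weights", pin=True)
gc.collect()  # explicit, though CPython refcounting already freed the unpinned one
print(unpinned() is None)  # True: nothing kept it alive after the call
print(pinned() is None)    # False: the dict entry pins it
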
