diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py
index f74ad9ac33852..7f63fc1437872 100644
--- a/vllm/device_allocator/cumem.py
+++ b/vllm/device_allocator/cumem.py
@@ -9,7 +9,7 @@
 # the only successful approach is to call cuda driver API in C.
 import dataclasses
 from contextlib import contextmanager
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 
@@ -97,7 +97,7 @@ def use_memory_pool_with_allocator(
     new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
     mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
     with torch.cuda.memory.use_mem_pool(mem_pool):
-        yield mem_pool
+        yield mem_pool, new_alloc
 
 
 class CuMemAllocator:
@@ -142,6 +142,7 @@ def get_instance() -> "CuMemAllocator":
     def __init__(self):
         self.pointer_to_data: Dict[int, AllocationData] = {}
         self.current_tag: str = CuMemAllocator.default_tag
+        self.allocator_and_pools: Dict[str, Any] = {}
 
     def python_malloc_callback(self, allocation_handle: HandleType) -> None:
         """
@@ -231,7 +232,13 @@ def use_memory_pool(self, tag: Optional[str] = None):
         old_tag = self.current_tag
         self.current_tag = tag
         with use_memory_pool_with_allocator(self.python_malloc_callback,
-                                            self.python_free_callback):
+                                            self.python_free_callback) as data:
+            # Starting with PyTorch 2.6 we hit another PyTorch bug,
+            # possibly a gc-related issue between the allocator and
+            # the memory pool.
+            # To avoid the issue, we keep a reference to the data.
+            # See https://github.com/pytorch/pytorch/issues/146431 .
+            self.allocator_and_pools[tag] = data
             yield
             # PyTorch's bug, calling torch.cuda.empty_cache() will error
             # when using pluggable allocator, see
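
Below is a minimal usage sketch of the changed context manager, not part of the patch itself; the tag name "weights" and the tensor shape are illustrative, and it assumes the cumem_allocator extension is available so CuMemAllocator can be instantiated. It shows the effect of the new keep-alive reference: the (mem_pool, new_alloc) pair yielded by use_memory_pool_with_allocator is stored in allocator_and_pools, so neither the MemPool nor the pluggable allocator can be garbage-collected while the pool is still in use (the workaround for https://github.com/pytorch/pytorch/issues/146431).

    import torch

    from vllm.device_allocator.cumem import CuMemAllocator

    allocator = CuMemAllocator.get_instance()

    # Allocations made inside the context go through the pluggable allocator
    # and are tracked under the given tag ("weights" is an illustrative choice).
    with allocator.use_memory_pool(tag="weights"):
        weights = torch.empty(16, 1024, device="cuda")

    # The keep-alive reference added by this patch is now held on the
    # allocator singleton, preventing premature gc of the MemPool and the
    # pluggable allocator while allocations from the pool are still alive.
    assert "weights" in allocator.allocator_and_pools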