@@ -9,7 +9,7 @@
 # the only successful approach is to call cuda driver API in C.
 import dataclasses
 from contextlib import contextmanager
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 
@@ -97,7 +97,7 @@ def use_memory_pool_with_allocator(
     new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
     mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
     with torch.cuda.memory.use_mem_pool(mem_pool):
-        yield mem_pool
+        yield mem_pool, new_alloc
 
 
 class CuMemAllocator:
@@ -142,6 +142,7 @@ def get_instance() -> "CuMemAllocator":
     def __init__(self):
         self.pointer_to_data: Dict[int, AllocationData] = {}
         self.current_tag: str = CuMemAllocator.default_tag
+        self.allocator_and_pools: Dict[str, Any] = {}
 
     def python_malloc_callback(self, allocation_handle: HandleType) -> None:
         """
@@ -231,7 +232,13 @@ def use_memory_pool(self, tag: Optional[str] = None):
         old_tag = self.current_tag
         self.current_tag = tag
         with use_memory_pool_with_allocator(self.python_malloc_callback,
-                                            self.python_free_callback):
+                                            self.python_free_callback) as data:
+            # Starting with PyTorch 2.6, we hit another PyTorch bug,
+            # possibly a GC-related issue between the allocator and
+            # the memory pool.
+            # To avoid it, we keep a reference to the yielded data.
+            # See https://github.com/pytorch/pytorch/issues/146431 .
+            self.allocator_and_pools[tag] = data
             yield
             # PyTorch's bug, calling torch.cuda.empty_cache() will error
             # when using pluggable allocator, see
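
For orientation, here is a minimal caller-side sketch (not part of the commit) of the first change: `use_memory_pool_with_allocator` now yields the pool together with the allocator that backs it, so the caller can keep both alive. The `my_malloc`/`my_free` callbacks below are hypothetical placeholders; in this file the real callbacks are `CuMemAllocator.python_malloc_callback` and `python_free_callback`.

```python
import torch

def my_malloc(allocation_handle):
    # Hypothetical placeholder; the real callback records the allocation.
    ...

def my_free(ptr):
    # Hypothetical placeholder; the real callback releases the handle.
    ...

# The context manager now yields a (pool, allocator) tuple instead of
# just the pool, so both objects can be referenced past the `with` block.
with use_memory_pool_with_allocator(my_malloc, my_free) as (pool, alloc):
    # Tensors allocated here are routed through the custom memory pool.
    x = torch.empty(1 << 20, device="cuda")
```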
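The keep-a-reference workaround in the last hunk can be demonstrated in isolation. Below is a self-contained sketch (standard library only; names such as `PoolLike` and `keepalive` are hypothetical) of why storing the yielded object in a long-lived dict, as `self.allocator_and_pools[tag] = data` does, prevents it from being garbage collected once the `with` block exits and no other reference remains.

```python
import gc
import weakref
from contextlib import contextmanager

class PoolLike:
    """Stand-in for the (pool, allocator) pair yielded by the real code."""

@contextmanager
def make_pool():
    pool = PoolLike()
    yield pool  # after the with block exits, this local is dropped too

keepalive = {}  # plays the role of self.allocator_and_pools

def use_pool(tag, keep):
    with make_pool() as pool:
        if keep:
            keepalive[tag] = pool  # the workaround: hold a strong reference
        return weakref.ref(pool)

ref = use_pool("weights", keep=False)
gc.collect()
assert ref() is None  # collected: nothing kept the pool alive

ref = use_pool("kv_cache", keep=True)
gc.collect()
assert ref() is not None  # alive: the dict entry still references it
```

The commit applies the same pattern: the `(pool, allocator)` tuple is stashed per tag so that neither object can be collected while allocations from the pool are still in use, per the linked PyTorch issue.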