Unfortunately, I am pretty sure the redesign won't happen, as I do not have the bandwidth to fix it. Revert so that the code is more easily found.
1 parent 0dd05ff · commit f2f1d79
Showing 23 changed files with 2,011 additions and 635 deletions.
@@ -0,0 +1,13 @@
from . import constants, gpus
from .eager import EagerTensor
from .errors import UnsupportedError
from .interfaces import (
    BatchedPair,
    BatchInfo,
    Runnable,
    RunnableTensor,
    TensorMixin,
    run,
)
from .lazy import Evaluation, LazyFunction, LazyTensor, lazy
from .prepasses import CallBack, MetaData, PrePass, PrePassFunc
@@ -0,0 +1,34 @@
from typing import Dict

import torch
from torch import dtype

UNITS: Dict[str, int] = {
    "b": 1,
    "kb": 10 ** 3,
    "kib": 2 ** 10,
    "mb": 10 ** 6,
    "mib": 2 ** 20,
    "gb": 10 ** 9,
    "gib": 2 ** 30,
    "tb": 10 ** 12,
    "tib": 2 ** 40,
}

MEMORY_BYTES: Dict[dtype, int] = {
    torch.bool: 1,
    torch.uint8: 1,
    torch.int8: 1,
    torch.short: 2,
    torch.int16: 2,
    torch.int: 4,
    torch.int32: 4,
    torch.long: 8,
    torch.int64: 8,
    torch.half: 2,
    torch.float16: 2,
    torch.float: 4,
    torch.float32: 4,
    torch.double: 8,
    torch.float64: 8,
}
Empty file.
This file was deleted.
This file was deleted.
This file was deleted.
@@ -0,0 +1,62 @@
from __future__ import annotations

import logging
from typing import Any, Callable, Dict, Sequence, Tuple, Type

from rich.logging import RichHandler
from torch import Tensor
from torch import device as Device
from torch import dtype as DType

from .interfaces import BatchInfo, RunnableTensor, TensorLike

logger = logging.getLogger(__name__)
logger.addHandler(RichHandler())
logger.setLevel(logging.DEBUG)

# So, it seems that torch's Tensor base class utilizes metaclass
# to pretend to be a parent of LongTensor, FloatTensor etc.
# Perhaps I'll be using the same paradigm.


class EagerTensor(RunnableTensor):
    def __init__(self, data: Tensor) -> None:
        self.data = data

    def __getattr__(self, name: str) -> Any:
        return getattr(self.data, name)

    def batch(self) -> BatchInfo | None:
        raise NotImplementedError

    def run(self, partial: Tuple[int, int] | None = None) -> Tensor:
        del partial
        return self.data

    def visit(self, nodes: Dict[int, TensorLike]) -> None:
        raise NotImplementedError

    def device(self) -> str | Device:
        raise NotImplementedError

    def dtype(self) -> DType:
        raise NotImplementedError

    def size(self) -> Tuple[int, ...]:
        return self.data.size()

    @classmethod
    def __torch_function__(
        cls,
        func: Callable[..., Tensor],
        types: Tuple[Type[Any], ...],
        args: Sequence[TensorLike] = (),
        kwargs: Dict[str, TensorLike] | None = None,
    ) -> TensorLike:
        if kwargs is None:
            kwargs = {}

        if not all(issubclass(typ, (Tensor, EagerTensor)) for typ in types):
            return NotImplemented

        return EagerTensor(func(*args, **kwargs))
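__torch_function__ is PyTorch's documented hook for intercepting torch.* calls on wrapper types: whenever an argument's class defines it, the call is routed to that classmethod instead of the built-in implementation. The standalone sketch below illustrates the mechanism with a hypothetical LoggingTensor, which is not part of this commit; unlike EagerTensor above, it unwraps its arguments before re-invoking the intercepted function:

import torch
from torch import Tensor


class LoggingTensor:
    """Minimal wrapper that reports every torch function routed through it."""

    def __init__(self, data: Tensor) -> None:
        self.data = data

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        print(f"intercepted {func.__name__}")
        # Unwrap wrapped arguments so the real torch implementation runs.
        unwrapped = [a.data if isinstance(a, LoggingTensor) else a for a in args]
        return cls(func(*unwrapped, **kwargs))


x = LoggingTensor(torch.ones(2, 3))
y = torch.add(x, torch.ones(2, 3))  # prints "intercepted add"
print(y.data.sum())                 # tensor(12.)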
@@ -0,0 +1,11 @@
from typing import NoReturn


class UnsupportedError(RuntimeError):
    "Sorry, this function is currently not supported."

    @classmethod
    def raise_error(cls, *args, **kwargs) -> NoReturn:
        del args
        del kwargs
        raise cls
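The raise_error classmethod accepts and discards arbitrary arguments, so the class can stand in wherever a callback is expected and simply raise instead of computing. A hypothetical usage sketch, with an illustrative dispatch table that is not from this commit:

from typing import Callable, Dict

# Assumption: the class above is importable as koila.errors.UnsupportedError.
from koila.errors import UnsupportedError

# Hypothetical dispatch table: supported operations map to real handlers,
# unsupported ones map to UnsupportedError.raise_error.
HANDLERS: Dict[str, Callable[..., object]] = {
    "add": lambda a, b: a + b,
    "einsum": UnsupportedError.raise_error,
}

print(HANDLERS["add"](2, 3))     # 5
HANDLERS["einsum"]("ij->ji", 0)  # raises UnsupportedError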
@@ -0,0 +1,111 @@
from __future__ import annotations

import math
from typing import Generator

from pynvml.smi import nvidia_smi
from torch import cuda

from . import constants
from .interfaces import BatchedPair

NVSMI = None


def nvidia_free_memory() -> int:
    """
    Calls nvidia's nvml library and queries available GPU memory.
    Currently the function only works with 1 GPU.

    Returns
    -------
    Free GPU memory in terms of bytes.
    """

    global NVSMI
    if NVSMI is None:
        NVSMI = nvidia_smi.getInstance()

    assert NVSMI is not None
    query = NVSMI.DeviceQuery("memory.free")

    # Only works on one GPU as of now.
    gpu = query["gpu"][0]["fb_memory_usage"]

    unit = constants.UNITS[gpu["unit"].lower()]
    free = gpu["free"]

    return free * unit


def torch_free_memory() -> int:
    """
    Calls torch's memory statistics to calculate the amount of GPU memory unused.
    Currently the function only works with 1 GPU.

    Returns
    -------
    Unused, reserved GPU memory in terms of bytes.
    """

    if not cuda.is_available():
        return 0

    # Only works on one GPU as of now.

    reserved_memory = cuda.memory_reserved(0)
    active_memory = cuda.memory_allocated(0)
    unused_memory = reserved_memory - active_memory
    return unused_memory


def free_memory() -> int | None:
    """
    The amount of free GPU memory that can be used.

    Returns
    -------
    Unused GPU memory, or None if no GPUs are available.
    """

    if cuda.is_available():
        return nvidia_free_memory() + torch_free_memory()
    else:
        return None


def maximum_batch(memory: BatchedPair, total_memory: int | None = None) -> int | None:
    # batch * x + no_batch = unused_memory
    if total_memory is None:
        total_memory = free_memory()

    if total_memory is None:
        return None

    return (total_memory - memory.no_batch) // memory.batch


def split_batch(
    memory: BatchedPair, current_batch: int, total_memory: int | None = None
) -> Generator[int, None, None]:
    max_batch = maximum_batch(memory, total_memory)

    if max_batch is None:
        yield current_batch
        return

    batch_size = 2 ** (math.floor(math.log2(max_batch)))
    (times, current_batch) = divmod(current_batch, batch_size)

    for _ in range(times):
        yield batch_size

    while current_batch > 0:
        batch_size >>= 1
        if current_batch >= batch_size:
            current_batch -= batch_size
            yield batch_size
    assert current_batch < batch_size, [current_batch, batch_size]
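split_batch rounds the largest batch that fits in memory down to a power of two, yields as many full chunks of that size as possible, then decomposes the remainder into strictly smaller powers of two. A rough worked sketch follows; it assumes the module is importable as koila.gpus and that BatchedPair accepts batch and no_batch as constructor keywords:

# Assumption: this commit's package is importable as koila, and BatchedPair
# can be constructed with the batch/no_batch fields used by maximum_batch.
from koila.gpus import split_batch
from koila.interfaces import BatchedPair

# With 100 bytes free, 1 byte per sample and no fixed overhead,
# maximum_batch is 100, so the largest power-of-two chunk is 64.
# A requested batch of 23 is then emitted as its binary decomposition.
memory = BatchedPair(batch=1, no_batch=0)
chunks = list(split_batch(memory, current_batch=23, total_memory=100))
print(chunks)  # [16, 4, 2, 1]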