Currently we first load the PyTorch params map (either from `.bin` or `.safetensors`) with all tensors materialized, which already takes as much memory as all the params. Then, when building the Axon params map, we look for tensors in the PyTorch map and oftentimes need to apply transformations (mostly `Nx.transpose/1`), so for each such tensor we use further memory. Consequently, the memory peak can be almost double the size of the parameters.

Because of this behaviour, loading large models directly onto the GPU could result in OOM. In such cases we recommended loading the params onto the CPU first and only then transferring them, but this (a) assumes we have enough RAM (which for really large models is not necessarily the case!); (b) puts high pressure on RAM; (c) is slower, since those `Nx.transpose/1` calls run on the CPU.
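For reference, the previously recommended flow looked roughly like the sketch below. The repository name is a placeholder and the client names are assumptions about a typical EXLA setup:

```elixir
# Sketch of the old workaround: load params onto the CPU (EXLA :host client),
# then transfer them to the GPU client. Repo name and client names are
# illustrative.
{:ok, model_info} =
  Bumblebee.load_model({:hf, "some-org/some-model"},
    backend: {EXLA.Backend, client: :host}
  )

# Transfer the whole params container to the GPU in one go.
params = Nx.backend_transfer(model_info.params, {EXLA.Backend, client: :cuda})
```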
%{"param" => %Nx.Tensor{}}
from.bin
, we load%{"param" => %FileTensor{}}
, whereFileTensor
is a lazy container. I also added an option tosafetensors
to do the same (elixir-nx/safetensors#9). So now when building an Axon param tensor, we lookup the relevant PyTorch lazy containers, callNx.to_tensor/1
to materialize them, do the necessary transformations. Then we proceed to the next param and the past intermediate tensors can already be garbage collected. This way there is barely any memory footprint other than the params themselves.I tested loading Llama2 onto the GPU with different values of EXLA
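As a rough illustration, a lazy container of this kind can implement the `Nx.LazyContainer` protocol, exposing a tensor template up front and only reading bytes from disk when asked to build the tensor. The struct fields and the byte-reading code below are assumptions made for the sake of the sketch, not the actual `FileTensor` implementation:

```elixir
# Minimal sketch of a lazy tensor container. Field names and the loading
# path are assumed for illustration.
defmodule FileTensor do
  defstruct [:path, :byte_offset, :byte_size, :shape, :type]
end

defimpl Nx.LazyContainer, for: FileTensor do
  # Hand out a template immediately; defer the disk read to the builder fun.
  def traverse(lazy, acc, fun) do
    template = Nx.template(lazy.shape, lazy.type)
    fun.(template, fn -> materialize(lazy) end, acc)
  end

  defp materialize(lazy) do
    # Read only this tensor's bytes from the file and build the tensor.
    {:ok, file} = File.open(lazy.path, [:read, :raw, :binary])
    {:ok, binary} = :file.pread(file, lazy.byte_offset, lazy.byte_size)
    File.close(file)

    binary
    |> Nx.from_binary(lazy.type)
    |> Nx.reshape(lazy.shape)
  end
end
```

The key point is that `traverse/3` receives a zero-arity builder function, so nothing is read from the file until the conversion code actually needs that particular parameter.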
I tested loading Llama2 onto the GPU with different values of EXLA's `:memory_fraction` to force an upper memory limit. The parameters are 13.5GB; prior to this change loading required 24.6GB, now 13.6GB was enough.
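For completeness, a memory cap like the one used in this test can be set through EXLA's client configuration; the sketch below assumes a CUDA client and an arbitrary fraction:

```elixir
# config/config.exs (or runtime.exs). Client name and fraction are
# illustrative; :memory_fraction limits how much device memory EXLA may use.
import Config

config :exla, :clients,
  cuda: [platform: :cuda, memory_fraction: 0.5]

config :nx, default_backend: {EXLA.Backend, client: :cuda}
```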