Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Load parameter tensors lazily #344

Merged
merged 1 commit into from
Feb 23, 2024
Merged

Load parameter tensors lazily #344

merged 1 commit into from
Feb 23, 2024

Conversation

jonatanklosko
Copy link
Member

Currently we first load PyTorch params map (either from .bin or .safetensors) with all tensors materialized, which already takes as much memory as all the params. Then when building the Axon params map, we look for tensors in the PyTorch map and oftentimes need to apply transformations (mostly Nx.transpose/1), so for each such tensor we use further memory. Consequently the memory peak can be almost double the size of the parameters.

Because of this behaviour, loading large models directly onto the GPU could result in OOM. In such cases we recommended loading the params onto the CPU first and only then transferring, but this (a) assumes we have enough RAM (which for really large models is not necessarily the case!); (b) puts high pressure on RAM; (c) is slower since those Nx.transpose/1 calls are on the CPU.

With this PR, instead of loading %{"param" => %Nx.Tensor{}} from .bin, we load %{"param" => %FileTensor{}}, where FileTensor is a lazy container. I also added an option to safetensors to do the same (elixir-nx/safetensors#9). So now when building an Axon param tensor, we lookup the relevant PyTorch lazy containers, call Nx.to_tensor/1 to materialize them, do the necessary transformations. Then we proceed to the next param and the past intermediate tensors can already be garbage collected. This way there is barely any memory footprint other than the params themselves.

I tested loading Llama2 onto the GPU with different values of EXLA :memory_fraction to force an upper memory limit. The parameters are 13.5GB, prior to this change loading required 24.6GB, now 13.6GB was enough.

Copy link
Contributor

@josevalim josevalim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beautiful! Does it mean we can also simplify code in kino_bumblebee and avoid preallocate_params there too? Or we don't use preallocate_params there?

@jonatanklosko
Copy link
Member Author

Or we don't use preallocate_params there?

We don't :) Also, when Livebook adds EXLA it sets EXLA.Backend (not :host), so params will be loaded onto the GPU if available. I think :host can still be a good default for production usage, so that we are explicit about what runs on the GPU, and don't accidentally block other side computations, but defaulting to the GPU in most notebooks is fine, even more so with this change :D

@jonatanklosko jonatanklosko merged commit 78c7694 into main Feb 23, 2024
2 checks passed
@jonatanklosko jonatanklosko deleted the jk-lazy-params branch February 23, 2024 09:10
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants