
Buffer donation to a jit function on GPU #1273

Closed
romanngg opened this issue Aug 30, 2019 · 7 comments
Labels
enhancement New feature or request

Comments

@romanngg
Contributor

Below is a CNN applied iteratively to a 2 GB input. It produces a peak memory consumption of 4 x 2 GB = 8 GB.

import jax.numpy as np
import jax.random as random
from jax import lax
from jax import jit

@jit
def f(x):
  for _ in range(10):
    x = lax.conv_general_dilated(x, np.ones((3, 3, 1, 1)), (1, 1), 'SAME', 
                                 dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
  return x

x = random.normal(random.PRNGKey(1), (2**19, 2**5, 2**5, 1))  
# (2**20, 2**5, 2**5, 1)) OOMs!
x = f(x)

Without JIT, the peak memory consumption is 2 x 2 GB = 4 GB, as expected.
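
For reference, a minimal sketch of the same loop without jit, which is the variant behind the 2 x 2 GB figure (it reuses the imports above; the name f_nojit is just for illustration):

# Same computation without @jit; each intermediate buffer can be freed as
# the loop runs, so peak memory stays around 2x the input size.
def f_nojit(x):
  for _ in range(10):
    x = lax.conv_general_dilated(x, np.ones((3, 3, 1, 1)), (1, 1), 'SAME',
                                 dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
  return x

x = f_nojit(x)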

It would be great to achieve comparable memory usage with JIT by donating the input buffer to the jit function (not sure of the exact terminology).

Thanks a lot!

@hawkinsp
Collaborator

Buffer donation has been checked in!

@romanngg
Contributor Author

Thanks Peter, do you know how I can leverage it to reduce the memory consumption in the example above?

So far, even if I do

f = jit(f, donate_argnums=0)

I still get a peak memory of 4 x 2 GB = 8 GB, along with the message:

jax/interpreters/xla.py:660: UserWarning: Some donated buffers were not usable: f32[524288,32,32,1]{3,2,1,0}

@jekbradbury
Contributor

I believe that means there wasn't an output with the same shape that could have reused that buffer (or there weren't as many such outputs as there were inputs).
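
For illustration, a minimal sketch of that shape-matching requirement (names here are illustrative, and it assumes a backend that supports donation): the donated input can only be reused when some output has the same shape and dtype.

from functools import partial

import jax.numpy as jnp
from jax import jit

@partial(jit, donate_argnums=0)
def scale(x):
  # The output is f32 with the same shape as the donated input,
  # so XLA can alias the input buffer to the output.
  return x * 2.0

x = jnp.ones((1024, 1024), jnp.float32)
x = scale(x)  # the donated buffer may be reused; don't reference the old x afterwards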

@romanngg
Contributor Author

Interesting - how come it doesn't work in this example then? From my understanding, there is 1 input and 1 output here, both of shape and type f32[524288,32,32,1].

@tomhennigan
Collaborator

FYI, buffer donation is only supported on TPU at the moment. The XLA team is working to support this on CPU/GPU, which may be why we cannot use the donation here.
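
A minimal sketch of checking which backend you are on before relying on donation (as of this comment, donation is only honored on TPU):

import jax

platform = jax.devices()[0].platform  # 'cpu', 'gpu', or 'tpu'
if platform != 'tpu':
  # Donation is silently ignored here; expect the
  # "Some donated buffers were not usable" warning from jax.
  print(f'donate_argnums may be a no-op on {platform}')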

@romanngg romanngg changed the title Buffer donation to a jit function Buffer donation to a jit function on GPU Jun 30, 2020
@romanngg
Contributor Author

I see, thanks! Could you please reopen this issue then?

@tomhennigan tomhennigan reopened this Jun 30, 2020
@hawkinsp
Collaborator

Fixed by #3800
