Allow for sharing of donated buffers between different `pmap`-ed functions #7144
Labels: `enhancement`, `needs info`
I wanted to have two `pmap()`-ed functions, piping the output of one into the other. For example, in the following way:

Here, `parallel_train_step()` and `parallel_eval_step()` are the two `pmap()`-ed functions, and `state` is the object passed from one to the other. The above code is adapted from this HuggingFace Notebook.

For small models, `state` is small, and the output of `parallel_train_step` or `parallel_eval_step` is stored in a buffer distinct from the input buffer. For large models, however, the output `state` may be stored in the input buffer itself. Because that input buffer is marked as donated, passing the output `state` into the second function (`parallel_eval_step()`) raises `RuntimeError: Invalid argument: Buffer has been deleted or donated`.

I notice that we can `print()` the returned `state` and also `serialize()` it, but we can't pass its buffers to the second `pmap()`-ed function.

I tried to work around the error manually by running

`state = jax.tree_util.tree_map(lambda x: jnp.array(x.tolist()), state)`

before passing `state` to the second function. I also tried `lambda x: jnp.array(x, copy=True)`. Neither worked. Is there another way to solve this problem?

A possible solution would be to un-mark the buffer as donated when the input buffer is reused to store the output. But how would I do that?

Related: #1733
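For reference, here is a minimal sketch of the pipeline being described. It assumes a toy `state` (a dict holding one parameter array) and hypothetical `train_step`/`eval_step` bodies; it is not the notebook's code. Only the train step donates its `state` argument via `donate_argnums`. Per JAX's buffer-donation semantics, donation invalidates the arrays that were passed *in*, while the arrays returned by the call are valid, so the sketch rebinds `state` to the output before handing it to the eval step. It does not depend on whether the backend actually reuses the donated buffer (CPU may ignore donation with a warning).

```python
import jax
import jax.numpy as jnp

# Number of local devices; pmap maps over a leading axis of this size.
n_dev = jax.local_device_count()


def train_step(state, batch):
    # Hypothetical update: the new state has the same shape and dtype as the
    # old one, which is what lets XLA reuse a donated input buffer.
    return {"w": state["w"] - 0.1 * jnp.mean(batch)}


def eval_step(state, batch):
    # Hypothetical metric computed from the train step's output state.
    return jnp.sum(state["w"] * jnp.mean(batch))


# Donate only the `state` argument (position 0) of the train step.
parallel_train_step = jax.pmap(train_step, donate_argnums=(0,))
parallel_eval_step = jax.pmap(eval_step)

state = {"w": jnp.ones((n_dev, 4))}
batch = jnp.ones((n_dev, 8))

# Rebind `state` to the returned value. The arrays that were passed in may be
# invalidated by donation; the returned arrays are new, valid arrays.
state = parallel_train_step(state, batch)

# Pass the train step's output straight into the second pmap-ed function.
metric = parallel_eval_step(state, batch)
print(metric)

# By contrast, keeping a reference to the pre-donation state, e.g.
#   old = state; state = parallel_train_step(state, batch)
#   parallel_eval_step(old, batch)
# can raise a "deleted or donated" error on backends that implement donation.
```

If the extra memory is affordable, the other option is to drop `donate_argnums` from the `jax.pmap(train_step, ...)` call, at the cost of holding both the old and the new `state` at the same time.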