
Allow for sharing of donated buffers between different pmap-ed functions #7144

Closed
Jeevesh8 opened this issue Jun 30, 2021 · 3 comments
Labels
enhancement New feature or request needs info More information is required to diagnose & prioritize the issue.

Comments

@Jeevesh8

Jeevesh8 commented Jun 30, 2021

I wanted to have two pmap()-ed functions, piping the output of one into the other. For example, in the following way:

for i, epoch in enumerate(tqdm(range(1, num_train_epochs + 1), desc="Epoch ...", position=0, leave=True)):
    rng, input_rng = jax.random.split(rng)

    # train
    for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):
        state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)

    # evaluate
    for batch in glue_eval_data_loader(eval_dataset, total_batch_size):
        predictions = parallel_eval_step(state, batch)
        metric.add_batch(predictions=chain(*predictions), references=chain(*labels))

    eval_metric = metric.compute()

parallel_train_step() and parallel_eval_step() are the two pmap()-ed functions, and state is the object piped from one into the other. The above code is adapted from this HuggingFace Notebook.

Now, for small models, the state is small, and the output of parallel_train_step() or parallel_eval_step() gets stored in a buffer distinct from the input buffer. But for large models, the output state may get stored in the input buffer itself. Since the input buffer is marked as donated, passing the output state into the second function (parallel_eval_step()) raises RuntimeError: Invalid argument: Buffer has been deleted or donated.
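For illustration, here is a minimal sketch (mine, not from the notebook; the functions and shapes are made up) of how donation invalidates a buffer. Donation is honored on TPU/GPU; on CPU it is ignored with a warning:

import jax
import jax.numpy as jnp

n = jax.local_device_count()

# donate_argnums=(0,) tells XLA it may reuse the first argument's
# buffers for the output.
train_step = jax.pmap(lambda x: x + 1, donate_argnums=(0,))
eval_step = jax.pmap(lambda x: x * 2)

x = jnp.ones((n, 4))
y = train_step(x)  # x's buffers are donated during this call

eval_step(y)  # fine: y refers to fresh output buffers
eval_step(x)  # RuntimeError: Invalid argument: Buffer has been deleted or donated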

I notice that we can print() the returned state, and we can also serialize() it, but we can't pass its buffers to the second pmap()-ed function.

I tried working around the error by running state = jax.tree_util.tree_map(lambda x: jnp.array(x.tolist()), state) before passing it to the second function, and also tried lambda x: jnp.array(x, copy=True). Neither worked. Is there some other way to solve the problem?

A possible solution could be to un-mark the buffer as donated when the input buffer is reused to store the output. But how could that be done?
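As far as I know there is no public API to un-donate a buffer: by the time the call returns, the donated buffer has already been deleted, which would also explain why copying afterwards fails. A hedged sketch of the only copy-based workaround I can see, assuming the problem is as described above (copy before the donating call, not after):

# Assumption: parallel_train_step donates its state argument.
# A copy taken *before* the call survives the donation; a copy
# attempted *after* the call cannot, since the buffer is gone.
state_backup = jax.tree_util.tree_map(lambda x: jnp.array(x, copy=True), state)
state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
# state is now bound to fresh output buffers; state_backup
# still holds the pre-step values if anything else needs them.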

Related: #1733


@Jeevesh8 Jeevesh8 added the enhancement New feature or request label Jun 30, 2021
@skye
Member

skye commented Jun 30, 2021

I'm a bit confused -- the input state buffer is donated, but parallel_train_step then returns a new buffer that is reassigned to state, so you shouldn't get that Buffer has been deleted or donated error. Are you able to provide a minimal repro we can run that shows the issue?
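(For reference, an assumed-minimal version of the pattern described above -- my illustration, not skye's -- showing that rebinding the name to the output keeps a donating loop valid:)

import jax
import jax.numpy as jnp

step = jax.pmap(lambda x: x + 1, donate_argnums=(0,))
x = jnp.ones((jax.local_device_count(), 4))

# Each iteration donates the old buffers and rebinds x to the fresh
# output buffers, so the deleted buffers are never touched again.
for _ in range(3):
    x = step(x)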

@skye skye added the needs info More information is required to diagnose & prioritize the issue. label Jun 30, 2021
@Jeevesh8
Author

Jeevesh8 commented Jul 1, 2021

I am really sorry for the confusion and trouble.
I found the source of the error. You are indeed correct that the output buffer is always distinct from the input one.

@Jeevesh8 Jeevesh8 closed this as completed Jul 1, 2021
@skye
Member

skye commented Jul 1, 2021

No worries, thanks for the update!
