Faster & memory-efficient logprobs calculation #583

Open · wants to merge 1 commit into main
Conversation

@li-plus (Contributor) commented Dec 2, 2023

The current logprobs_of_labels computes logprobs with a log_softmax followed by a gather. When the input logits tensor is not contiguous, log_softmax first makes a contiguous copy of it, which is very large: batch_size * seq_len * vocab_size elements, e.g. 32 * 2048 * 64000 * 2 bytes ≈ 8 GB for typical settings.
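To see why the copy happens: slicing off the last position along the sequence dimension keeps the original strides, so logits[:, :-1, :] is a non-contiguous view, and log_softmax must materialize a contiguous copy before it can run. A quick check:

import torch

x = torch.randn(2, 4, 8)
print(x.is_contiguous())             # True
print(x[:, :-1, :].is_contiguous())  # False: same strides, smaller shape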

This PR feeds the contiguous logits directly into log_softmax, removing the redundant copy and reducing peak CUDA memory.
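For illustration, here is a minimal sketch of the idea (assuming the same logprobs_of_labels signature as in trlx.utils.modeling; the actual diff may differ): accept the full logits, run log_softmax on the contiguous tensor, and slice its output instead of its input.

import torch
import torch.nn.functional as F

def logprobs_of_labels(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits may carry one extra trailing position: [batch, seq, vocab]
    # with labels of shape [batch, seq - 1]. log_softmax sees the full
    # contiguous tensor, so no hidden copy is made.
    logprobs = F.log_softmax(logits, dim=-1)
    # gather handles the non-contiguous sliced view without copying
    logprobs_labels = torch.gather(
        logprobs[:, : labels.shape[1], :], dim=-1, index=labels.unsqueeze(-1)
    )
    return logprobs_labels.squeeze(-1)

With a contiguous input, the peak is just the logits plus the log_softmax output (2 × 8.39 GB ≈ 16.8 GB) instead of an extra 8.39 GB for the hidden copy (≈ 25.2 GB), which matches the reserved-memory numbers below.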

Test script:

import torch
from torch.utils.benchmark import Timer
from trlx.utils.modeling import logprobs_of_labels

def perf():
    batch_size, seq_len, vocab_size = 32, 2048, 64000
    logits = torch.randn((batch_size, seq_len, vocab_size), dtype=torch.half, device='cuda')
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long, device='cuda')

    # correctness: the full-logits path must match the original sliced-logits path
    assert torch.allclose(logprobs_of_labels(logits[:, :-1, :], input_ids[:, 1:]), logprobs_of_labels(logits, input_ids[:, 1:]))

    # peak memory test
    torch.cuda.empty_cache()
    logprobs_of_labels(logits[:, :-1, :], input_ids[:, 1:])
    print(f'original allocated: {torch.cuda.memory_allocated() / 1e9:.3f} GB, reserved: {torch.cuda.memory_reserved() / 1e9:.3f} GB')

    torch.cuda.empty_cache()
    logprobs_of_labels(logits, input_ids[:, 1:])
    print(f'optimized allocated: {torch.cuda.memory_allocated() / 1e9:.3f} GB, reserved: {torch.cuda.memory_reserved() / 1e9:.3f} GB')

    # speed test
    timer = Timer(stmt="logprobs_of_labels(logits[:, :-1, :], input_ids[:, 1:])", globals={**globals(), **locals()})
    elapsed_org = timer.timeit(100).mean
    print(f'original costs: {elapsed_org:.4f} s')

    timer = Timer(stmt="logprobs_of_labels(logits, input_ids[:, 1:])", globals={**globals(), **locals()})
    elapsed_opt = timer.timeit(100).mean
    print(f'optimized costs: {elapsed_opt:.4f} s')

perf()

Tested on a Tesla V100, the method in this PR is both faster (1.6x speedup) and more memory-efficient:

original allocated: 8.389 GB, reserved: 25.164 GB
optimized allocated: 8.389 GB, reserved: 16.779 GB
original costs: 0.0700 s
optimized costs: 0.0435 s

@codecov-commenter commented Dec 2, 2023

Codecov Report

Attention: 6 lines in your changes are missing coverage. Please review.

Comparison is base (91a0f43) 43.58% compared to head (730d900) 43.58%.
Report is 1 commit behind head on main.

❗ Current head 730d900 differs from the pull request's most recent head aa1031a. Consider uploading reports for commit aa1031a to get more accurate results.

Files                                     Patch %   Lines
trlx/models/modeling_nemo_ppo.py          0.00%     3 Missing ⚠️
trlx/trainer/accelerate_ppo_trainer.py    57.14%    3 Missing ⚠️


Additional details and impacted files
@@           Coverage Diff           @@
##             main     #583   +/-   ##
=======================================
  Coverage   43.58%   43.58%           
=======================================
  Files          33       33           
  Lines        4974     4974           
=======================================
  Hits         2168     2168           
  Misses       2806     2806           

