On Device Sampling #440
Merged: quic-rishinr merged 3 commits into quic:release/v1.20.0_dev from quic-sanising:on-device-sampling-pr on Jun 12, 2025.
Conversation
quic-hemagnih approved these changes on Jun 11, 2025.
quic-rishinr pushed a commit that referenced this pull request on Jun 12, 2025.
Closing this PR, as its changes have been migrated to another PR: #447
## ✨ On Device Sampling

### 📌 Overview

This PR introduces **On Device Sampling** for `QEffForCausalLM` models, enabling sampling operations to be executed directly on the **QAIC device** rather than the host CPU. This enhancement significantly reduces host-device communication overhead and improves inference throughput and scalability.

### 🚀 Motivation
Traditionally, sampling (e.g., greedy, top-k, top-p) is performed on the host CPU after logits are returned from the device. This approach incurs:

- High PCIe traffic due to large logits tensors `[batch_size, vocab_size]` (quantified in the sketch below)
- Latency bottlenecks from CPU-bound sampling logic
- Limited scalability due to CPU thread constraints

**On Device Sampling** addresses these issues by:

- Performing sampling directly on the QAIC device
- Returning only the selected next tokens `[batch_size, 1]`
- Leveraging the device's parallelism and optimized compute paths
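For a rough sense of scale, here is a back-of-the-envelope sketch of the per-step PCIe payload; the fp16 logits dtype and Llama-3.1 vocabulary size are assumptions for illustration, not measurements from this PR:

```python
batch_size = 16
vocab_size = 128_256      # Llama-3.1 vocabulary size (assumed for illustration)
bytes_per_logit = 2       # fp16 logits (assumed)

# Host-side sampling: the full logits tensor crosses PCIe every decode step
logits_bytes = batch_size * vocab_size * bytes_per_logit    # ~4.1 MB

# On-device sampling: only the sampled token ids come back
token_bytes = batch_size * 1 * 4                            # int32 ids: 64 B

print(f"{logits_bytes / 1e6:.1f} MB -> {token_bytes} B per step "
      f"(~{logits_bytes // token_bytes:,}x less traffic)")
```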
### ⚙️ Supported Sampling Strategies

The following sampling techniques are now supported natively on the QAIC device (a host-side reference sketch follows this list):

1. **Repetition Penalty**: Penalize tokens that have appeared in the prompt and the generated text so far. Values > 1 encourage the model to use new tokens, while values < 1 encourage the model to repeat tokens.
2. **Presence Penalty**: Penalize tokens that are present in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens.
3. **Temperature Scaling**: Adjust the sharpness of the logits distribution. Lower values make the model more deterministic, while higher values make it more random. Zero means greedy sampling.
4. **Top K**: Sample from the `k` largest tokens by value.
5. **Top P**: Sample from the smallest set of tokens whose cumulative probability is greater than or equal to `p`. Must be in (0, 1]. Set to 1 to consider all tokens.
6. **Min P**: The minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in [0, 1]. Set to 0 to disable it.
7. **Greedy Sampling**: Choose the token with the highest value.
8. **Random Sampling**: Choose a token at random, with its probability of being chosen given by its value.
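For intuition only, here is a minimal host-side PyTorch sketch of how these strategies compose. The function name, buffer arguments, and exact ordering are assumptions for illustration; this is not the PR's on-device kernel code:

```python
import torch

def sample_next_tokens(logits, generated_counts, prompt_mask,
                       repetition_penalty=1.0, presence_penalty=0.0,
                       temperature=1.0, top_k=0, top_p=1.0, min_p=0.0):
    # logits: [batch, vocab]; generated_counts: [batch, vocab] occurrence counts;
    # prompt_mask: [batch, vocab] bool, True where a token appeared in the prompt.
    seen = prompt_mask | (generated_counts > 0)
    # Repetition penalty: shrink positive logits, push negative ones further down
    penalized = torch.where(logits > 0, logits / repetition_penalty,
                            logits * repetition_penalty)
    logits = torch.where(seen, penalized, logits)
    # Presence penalty: flat subtraction for tokens already generated
    logits = logits - presence_penalty * (generated_counts > 0).float()
    if temperature == 0:                       # zero temperature -> greedy
        return logits.argmax(dim=-1, keepdim=True)
    probs = torch.softmax(logits / temperature, dim=-1)
    if top_k > 0:                              # keep only the k largest tokens
        kth = probs.topk(top_k, dim=-1).values[..., -1:]
        probs = torch.where(probs < kth, torch.zeros_like(probs), probs)
    if top_p < 1.0:                            # smallest set with cumsum >= p
        sorted_p, idx = probs.sort(dim=-1, descending=True)
        keep = sorted_p.cumsum(dim=-1) - sorted_p < top_p
        mask = torch.zeros_like(probs, dtype=torch.bool).scatter(-1, idx, keep)
        probs = torch.where(mask, probs, torch.zeros_like(probs))
    if min_p > 0.0:                            # drop tokens far below the mode
        cutoff = min_p * probs.max(dim=-1, keepdim=True).values
        probs = torch.where(probs < cutoff, torch.zeros_like(probs), probs)
    probs = probs / probs.sum(dim=-1, keepdim=True)
    return torch.multinomial(probs, num_samples=1)   # random sampling
```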
### 🛠️ Implementation Details

- **Sampler Integration**: Sampling logic is injected via `include_sampler=True` during model loading. No changes to the model architecture are required.
- **Memory Optimization**: Two scratch buffers of shape `[batch_size, vocab_size]` are used to track token occurrences for applying repetition and presence penalties efficiently on-device (see the sketch after this list).
- **Performance Gains**:
  - Reduced PCIe traffic (logits → tokens)
  - Higher throughput via device-level parallelism
  - Scalable to 64+ concurrent inference streams
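A minimal sketch of what tracking token occurrences in two `[batch_size, vocab_size]` buffers could look like; the buffer names and update rules are hypothetical, and this is written host-side in PyTorch rather than as device kernels:

```python
import torch

batch_size, vocab_size = 16, 128_256   # illustrative sizes

# One buffer marks tokens seen in the prompt; the other counts generated tokens.
# These play the roles of the `prompt_mask` and `generated_counts` arguments
# assumed in the sampling sketch above.
prompt_mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool)
generated_counts = torch.zeros(batch_size, vocab_size, dtype=torch.int32)

def record_prompt(prompt_ids):        # prompt_ids: [batch, seq_len]
    prompt_mask.scatter_(1, prompt_ids, True)

def record_generated(next_tokens):    # next_tokens: [batch, 1] from the sampler
    generated_counts.scatter_add_(
        1, next_tokens, torch.ones_like(next_tokens, dtype=torch.int32))
```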
### 🧪 Usage
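Enable the sampler when loading the model, then compile as usual:

```python
from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM

# Load model with On Device Sampler enabled
qeff_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    include_sampler=True,
    return_pdfs=False,
)

# Compile as usual
qeff_model.compile(
    prefill_seq_length=128,
    ctx_len=256,
    full_batch_size=16,
    num_devices=4,
    num_speculative_tokens=0,
    mxint8_kv_cache=True,
    mxfp6_matmul=True,
)
```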
To disable On Device Sampling and revert to host-side sampling, simply set `include_sampler=False`.