
On Device Sampling #440

Merged

Conversation

quic-sanising (Contributor) commented Jun 11, 2025

✨ On Device Sampling

📌 Overview

This PR introduces On Device Sampling for QEffForCausalLM models, enabling sampling operations to be executed directly on the QAIC device rather than the host CPU. This enhancement significantly reduces host-device communication overhead and improves inference throughput and scalability.



🚀 Motivation

Traditionally, sampling (e.g., greedy, top-k, top-p) is performed on the host CPU after logits are returned from the device. This approach incurs:

  • High PCIe traffic due to large logits tensors [batch_size, vocab_size]
  • Latency bottlenecks from CPU-bound sampling logic
  • Limited scalability due to CPU thread constraints

On Device Sampling addresses these issues by:

  • Performing sampling directly on the QAIC device
  • Returning only the selected next tokens [batch_size, 1]
  • Leveraging the device’s parallelism and optimized compute paths
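
For example, with a batch of 16 and Llama-3.1-8B's roughly 128K-token vocabulary, the fp32 logits come to about 8 MB of PCIe traffic per decode step, while the sampled token ids amount to only a few dozen bytes.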


⚙️ Supported Sampling Strategies

The following sampling techniques are now supported natively on the QAIC device:

  1. Repetition Penalty: Penalize tokens that have appeared in the prompt and the generated text so far. Values > 1 encourage the model to use new tokens, while values < 1 encourage the model to repeat tokens.

  2. Presence Penalty: Penalize tokens that are present in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens.

  3. Temperature Scaling: Adjust the sharpness of the logits distribution. Lower values make the model more deterministic, while higher values make the model more random. Zero means greedy sampling.

  4. Top K: Sample only from the k tokens with the highest probabilities.

  5. Top P: Sample from the smallest set of tokens whose cumulative probability is greater than or equal to p. Must be in (0, 1]. Set to 1 to consider all tokens.

  6. Min P: The minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in [0, 1]. Set to 0 to disable this.

  7. Greedy Sampling: Choose the token with the highest probability.

  8. Random Sampling: Sample a token at random, with each token's chance of being chosen given by its probability.
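
For orientation, here is a minimal host-side NumPy sketch of how these filters can compose into a single sampling step. It is not the on-device kernel from this PR; the function name, argument names, and the exact ordering of the filters are assumptions made for illustration.

```python
# Hypothetical host-side reference of one sampling step; the on-device
# kernels in this PR may differ in order and implementation details.
import numpy as np

def sample_next_token(logits, prompt_ids, generated_ids,
                      repetition_penalty=1.0, presence_penalty=0.0,
                      temperature=1.0, top_k=0, top_p=1.0, min_p=0.0):
    logits = logits.astype(np.float64).copy()
    vocab_size = logits.shape[-1]

    # Repetition penalty: rescale logits of tokens seen in the prompt or output.
    seen = np.zeros(vocab_size, dtype=bool)
    seen[np.concatenate([prompt_ids, generated_ids]).astype(int)] = True
    logits[seen] = np.where(logits[seen] > 0,
                            logits[seen] / repetition_penalty,
                            logits[seen] * repetition_penalty)

    # Presence penalty: subtract a constant from tokens already generated.
    present = np.zeros(vocab_size, dtype=bool)
    present[np.asarray(generated_ids, dtype=int)] = True
    logits[present] -= presence_penalty

    # Temperature of zero means greedy sampling.
    if temperature == 0:
        return int(np.argmax(logits))
    logits /= temperature

    probs = np.exp(logits - logits.max())
    probs /= probs.sum()

    # Top-k: keep only the k most probable tokens.
    if 0 < top_k < vocab_size:
        probs[probs < np.sort(probs)[-top_k]] = 0.0

    # Top-p: keep the smallest set whose cumulative probability reaches p.
    if top_p < 1.0:
        order = np.argsort(probs)[::-1]
        cutoff = np.searchsorted(np.cumsum(probs[order]), top_p)
        probs[order[cutoff + 1:]] = 0.0

    # Min-p: drop tokens below min_p times the top probability.
    if min_p > 0.0:
        probs[probs < min_p * probs.max()] = 0.0

    # Random sampling from the filtered, renormalized distribution.
    probs /= probs.sum()
    return int(np.random.choice(vocab_size, p=probs))
```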



🛠️ Implementation Details

  • Sampler Integration: Sampling logic is injected via include_sampler=True during model loading. No changes to the model architecture are required.

  • Memory Optimization: Two scratch buffers of shape [batch_size, vocab_size] are used to track token occurrences for applying repetition and presence penalties efficiently on-device.

  • Performance Gains:

    • Reduced PCIe traffic (logits → tokens)
    • Higher throughput via device-level parallelism
    • Scalable to 64+ concurrent inference streams
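
The description above does not spell out how the two scratch buffers are maintained, so the following is only a rough host-side illustration of the bookkeeping: one buffer tracks tokens seen in the prompt or output (consumed by the repetition penalty), the other tracks tokens generated during decode (consumed by the presence penalty). Buffer names and update rules are assumptions.

```python
# Rough illustration of the two [batch_size, vocab_size] scratch buffers;
# names and update rules are assumptions, not the PR's actual kernels.
import numpy as np

batch_size, vocab_size = 16, 128256  # example sizes

# Tracks tokens seen in the prompt or generated so far (repetition penalty).
repetition_mask = np.zeros((batch_size, vocab_size), dtype=bool)
# Tracks tokens generated during decode only (presence penalty).
presence_mask = np.zeros((batch_size, vocab_size), dtype=bool)

def on_prefill(prompt_ids):
    # prompt_ids: [batch_size, seq_len] token ids from the prompts
    np.put_along_axis(repetition_mask, prompt_ids, True, axis=1)

def on_decode_step(next_tokens):
    # next_tokens: [batch_size, 1] ids returned by the on-device sampler
    np.put_along_axis(repetition_mask, next_tokens, True, axis=1)
    np.put_along_axis(presence_mask, next_tokens, True, axis=1)
```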


🧪 Usage

```python
from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM

# Load model with On Device Sampler enabled
qeff_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    include_sampler=True,
    return_pdfs=False,
)

# Compile as usual
qeff_model.compile(
    prefill_seq_length=128,
    ctx_len=256,
    full_batch_size=16,
    num_devices=4,
    num_speculative_tokens=0,
    mxint8_kv_cache=True,
    mxfp6_matmul=True,
)
```

To disable On Device Sampling and revert to host-side sampling, simply set include_sampler=False.
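
Continuing the snippet above, generation then runs through the usual QEfficient generate call. The tokenizer and prompt below are placeholders, and how per-request sampling parameters (temperature, top-k, and so on) are supplied at generate time is not covered in this description.

```python
from transformers import AutoTokenizer

# Placeholder prompt; continues from the qeff_model created above.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
qeff_model.generate(
    tokenizer=tokenizer,
    prompts=["Hello, my name is"],
)
```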

Signed-off-by: quic-sanising <quic_sanising@quicinc.com>
@quic-sanising quic-sanising mentioned this pull request Jun 11, 2025
@quic-sanising quic-sanising marked this pull request as ready for review June 11, 2025 05:18
Signed-off-by: quic-sanising <quic_sanising@quicinc.com>
@quic-rishinr quic-rishinr changed the base branch from main to release/v1.20.0_dev June 12, 2025 03:46
Signed-off-by: quic-sanising <quic_sanising@quicinc.com>
@quic-rishinr quic-rishinr merged commit 350a859 into quic:release/v1.20.0_dev Jun 12, 2025
3 checks passed
quic-rishinr pushed a commit that referenced this pull request Jun 12, 2025

quic-hemagnih (Contributor)

Closing this PR as its changes have been migrated to another PR: #447

@quic-sanising quic-sanising deleted the on-device-sampling-pr branch June 18, 2025 17:04