On Device Sampling #350


Draft

quic-sanising wants to merge 27 commits into base: main

Conversation

quic-sanising

No description provided.

quic-sanising marked this pull request as ready for review on April 9, 2025 at 04:48.
quic-amitraj marked this pull request as draft on April 11, 2025 at 08:49.
@@ -75,7 +76,7 @@ def __repr__(self) -> str:

     @classmethod
     @with_replaced_quantizers
-    def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, *args, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, *args, **kwargs):
Contributor: Should we make them optional parameters?
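
For context, a minimal usage sketch of the extended signature (illustrative only: the class name QEFFAutoModelForCausalLM and the model card are assumptions, and the flag behaviour is inferred from the PR title rather than confirmed by this hunk):

from QEfficient import QEFFAutoModelForCausalLM

# Both new flags default to False, so existing callers are unaffected.
# include_sampler=True is assumed to build the on-device sampler into the
# exported graph; return_pdfs=True is assumed to also return probability
# distributions alongside the sampled next tokens.
model = QEFFAutoModelForCausalLM.from_pretrained(
    "gpt2",  # placeholder model card
    include_sampler=True,
    return_pdfs=False,
)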

@@ -1317,8 +1322,14 @@ def __init__(
         if is_tlm:
             # TODO: It is possible to always apply this transform and make value of indices as last indices by default in PyTorch
             self.model, transformed = SpDTransform.apply(self.model)
+            self.model.return_pdfs = True
quic-hemagnih (Contributor) commented on Apr 23, 2025: Where is the code that handles the is_tlm == False case when populating return_pdfs?
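
One possible shape for covering the non-TLM case (a sketch only; it reuses the names from the hunk above plus the include_sampler / return_pdfs flags from from_pretrained, and may not match the fix that lands in this PR):

if is_tlm:
    # Speculative-decoding target model: apply the SpD transform and always
    # return probability distributions so draft tokens can be verified.
    self.model, transformed = SpDTransform.apply(self.model)
    self.model.return_pdfs = True
elif include_sampler:
    # Plain on-device sampling: let the caller choose via return_pdfs whether
    # the sampler also emits probability distributions.
    self.model.return_pdfs = return_pdfs
else:
    # No sampler fused into the graph; probability distributions are not
    # produced on device.
    self.model.return_pdfs = False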

dynamic_axes["top_ks"] = {0: "batch_size"}

example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * 0.80
dynamic_axes["top_ps"] = {0: "batch_size"}
Contributor: Can we define constants for 0.80 and 0.99?
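
One way this could look (a sketch; the constant name is illustrative, and the 0.99 value is assumed to come from a neighbouring hunk that is not shown here):

# Illustrative module-level constant for the example top-p export input; the
# 0.99 value mentioned by the reviewer would get a named constant of its own.
EXAMPLE_TOP_P = 0.80

example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * EXAMPLE_TOP_P
dynamic_axes["top_ps"] = {0: "batch_size"}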


    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        transformed = False
Contributor: Please add a docstring.
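
For instance, a docstring along these lines (a sketch of what is being asked for; the wording is illustrative, not the PR's final text):

@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
    """
    Apply the speculative-decoding (SpD) transform to ``model`` if an
    implementation is registered for its architecture.

    Args:
        model: The PyTorch module to transform.

    Returns:
        A tuple of the (possibly modified) module and a bool indicating
        whether any transform was actually applied.
    """
    transformed = False
    ...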

* Initial commit

* Reformat code

* Fix bug

* Add Gumbel-Max trick based random sampling (a brief sketch of the trick follows this commit list)

* Bring up to date

* Use Gumbel-Max Trick based Random Sampling as default

* Clip k to max value

* Add docstring for sampling parameters

* Fix bug

* Add support for continuous batching

* Fix ONNX error for batch_size 1 treated as a Constant

* Undo docstring deletion

* Remove device and unnecessary reshapes

* Revert batch_size to 1

* Remove vocab_size from dynamic axes

* Change condition

* Change size of each sampling parameter to (batch_size, 1)

* Reformat code

* Add optimizations

* Identify optimizations

* Fix bug

* Fix merge issue

* Optimizations:
  - Perform random sampling only on topk_values_asc
  - Only need logits for probs when self.return_pdfs is True

* Remove where clause for temperature

* Remove boolean type casting for retain state

* Always return next_tokens

* Fix bug

* Reformat code

* Initialize retain states

* Optimize imports

* Remove torch.index_select()

* Change dtype of penalty buffers to bool

---------

Signed-off-by: quic-sanising <quic_sanising@quicinc.com>
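
Several of the commits above switch the sampler to Gumbel-Max based random sampling; for readers unfamiliar with it, here is a minimal PyTorch sketch of the trick itself (illustrative only, not the code added by this PR):

import torch

def gumbel_max_sample(logits: torch.Tensor) -> torch.Tensor:
    # Drawing a token from softmax(logits) is equivalent to adding independent
    # Gumbel(0, 1) noise to the logits and taking the argmax. This avoids an
    # explicit softmax and torch.multinomial call, which is convenient when the
    # whole sampler has to be exported to ONNX and run on device.
    uniform = torch.rand_like(logits)                          # U ~ Uniform(0, 1)
    gumbel = -torch.log(-torch.log(uniform + 1e-10) + 1e-10)   # G ~ Gumbel(0, 1)
    return torch.argmax(logits + gumbel, dim=-1)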
3 participants