On Device Sampling #350
Conversation
Signed-off-by: quic-sanising <quic_sanising@quicinc.com>
Force-pushed from 9fab549 to 3b63ecb
Signed-off-by: quic-sanising <quic_sanising@quicinc.com>
@@ -75,7 +76,7 @@ def __repr__(self) -> str:

     @classmethod
     @with_replaced_quantizers
-    def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, *args, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, *args, **kwargs):
Should we make these optional parameters?
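As a starting point for this, a minimal sketch (hypothetical helper name and resolution rules, not the PR's actual code) of how the new flags could stay optional: both default to None and are resolved from existing behavior when not supplied.

    from typing import Optional, Tuple

    def resolve_sampler_flags(
        is_tlm: bool = False,
        include_sampler: Optional[bool] = None,
        return_pdfs: Optional[bool] = None,
    ) -> Tuple[bool, bool]:
        # Unset include_sampler means the on-device sampler stays disabled.
        include_sampler = bool(include_sampler)
        # Assumption: when return_pdfs is not given, fall back to the TLM behavior.
        if return_pdfs is None:
            return_pdfs = is_tlm
        return include_sampler, return_pdfs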
@@ -1317,8 +1322,14 @@ def __init__(
        if is_tlm:
            # TODO: It is possible to always apply this transform and make value of indices as last indices by default in PyTorch
            self.model, transformed = SpDTransform.apply(self.model)
            self.model.return_pdfs = True
Where is the code that handles the is_tlm == False case for populating return_pdfs?
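One possible shape of the missing branch, as a small self-contained sketch (hypothetical helper; the real code would live in __init__ and also apply SpDTransform on the TLM path):

    import torch.nn as nn

    def set_return_pdfs(model: nn.Module, is_tlm: bool, return_pdfs: bool) -> nn.Module:
        if is_tlm:
            # TLM / speculative-decoding path: full probability distributions are required.
            model.return_pdfs = True
        else:
            # Non-TLM path: honor the caller-supplied flag (assumed behavior).
            model.return_pdfs = return_pdfs
        return model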
dynamic_axes["top_ks"] = {0: "batch_size"} | ||
|
||
example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * 0.80 | ||
dynamic_axes["top_ps"] = {0: "batch_size"} |
Can we define constants for 0.80 and 0.99?
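A minimal sketch of what that could look like (hypothetical constant names; 0.99 is assumed to belong to another sampling parameter in the same hunk, and bs/example_inputs stand in for the PR's export setup):

    import torch

    EXAMPLE_TOP_P = 0.80  # default top-p value for the dummy export inputs
    EXAMPLE_MIN_P = 0.99  # the other magic number the comment refers to (assumption)

    bs = 1  # placeholder batch size; taken from the export configuration in the PR
    example_inputs, dynamic_axes = {}, {}
    example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * EXAMPLE_TOP_P
    dynamic_axes["top_ps"] = {0: "batch_size"}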
    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        transformed = False
Please add a docstring.
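A sketch of what the requested docstring might say (wording is only a suggestion; the return semantics are inferred from the Tuple[nn.Module, bool] signature):

    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        """Apply the transform to ``model`` if it is supported.

        Args:
            model: The PyTorch model to transform.

        Returns:
            A tuple of the (possibly modified) model and a bool indicating
            whether the transform was applied.
        """
        transformed = False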
Signed-off-by: quic-sanising <quic_sanising@quicinc.com>
* Initial commit
* Reformat code
* Fix bug
* Add Gumbel-Max trick based random sampling
* Bring up to date
* Use Gumbel-Max Trick based Random Sampling as default
* Clip k to max value
* Add docstring for sampling parameters
* Fix bug
* Add support for continuous batching
* Fix ONNX error for batch_size 1 treated as a Constant
* Undo docstring deletion
* Remove device and unnecessary reshapes
* Revert batch_size to 1
* Remove vocab_size from dynamic axes
* Change condition
* Change size of each sampling parameter to (batch_size, 1)
* Reformat code
* Add optimizations
* Identify optimizations
* Fix bug
* Fix merge issue
* Optimizations: Perform random sampling only on topk_values_asc; only need logits for probs when self.return_pdfs is True
* Remove where clause for temperature
* Remove boolean type casting for retain state
* Always return next_tokens
* Fix bug
* Reformat code
* Initialize retain states
* Optimize imports
* Remove torch.index_select()
* Change dtype of penalty buffers to bool

---------

Signed-off-by: quic-sanising <quic_sanising@quicinc.com>
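For context, a minimal sketch of the Gumbel-Max trick the commits above refer to (an illustration of the general technique, not the PR's actual implementation): sampling from a categorical distribution by adding Gumbel noise to the logits and taking the argmax, which avoids computing an explicit softmax.

    import torch

    def gumbel_max_sample(logits: torch.Tensor) -> torch.Tensor:
        """Sample token ids from logits of shape (batch, vocab) via the Gumbel-Max trick."""
        # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1); epsilons avoid log(0).
        u = torch.rand_like(logits)
        gumbel_noise = -torch.log(-torch.log(u + 1e-20) + 1e-20)
        # argmax over (logits + noise) is distributed as softmax(logits).
        return torch.argmax(logits + gumbel_noise, dim=-1)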
Force-pushed from 50953e2 to d48d084
Signed-off-by: quic-sanising <quic_sanising@quicinc.com>
No description provided.