TF: XLA logits processors - minimum length, forced eos, and forced bos #16912
LGTM! I have a couple of nits that shouldn't actually affect the output, and one of them might be me being totally wrong anyway.
lambda: self._apply_eos_token_mask(scores),
lambda: tf.identity(scores),
Suggested change:
- lambda: self._apply_eos_token_mask(scores),
- lambda: tf.identity(scores),
+ self._apply_eos_token_mask(scores),
+ scores,
Would this work without the lambdas and identity call? I feel like it should but I'm not sure if I'm missing something obvious.
Nope, it doesn't. Super unintuitive, but tf.cond expects a callable for each branch, not the output of each branch 😬 (docs). It fails if we remove the lambda.
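For anyone reading along, here is a minimal sketch of the pattern being discussed. It is not the code from this PR: `EOS_TOKEN_ID`, `MIN_LENGTH`, `apply_eos_token_mask` and `min_length_processor` are hypothetical stand-ins, and the masking logic is only an assumption about what a minimum-length processor would do. The point is that both branches are passed to `tf.cond` as zero-argument callables.

```python
import tensorflow as tf

# Hypothetical values, chosen only for illustration.
EOS_TOKEN_ID = 2
MIN_LENGTH = 3


def apply_eos_token_mask(scores):
    # Set the EOS column to -inf so that EOS cannot be picked yet.
    eos_mask = tf.equal(tf.range(scores.shape[-1]), EOS_TOKEN_ID)
    return tf.where(eos_mask, float("-inf"), scores)


@tf.function(jit_compile=True)
def min_length_processor(scores, cur_len):
    # tf.cond takes a callable per branch, not the branch's result.
    return tf.cond(
        tf.less(cur_len, MIN_LENGTH),
        lambda: apply_eos_token_mask(scores),
        lambda: tf.identity(scores),
    )


scores = tf.zeros((1, 4))
early = min_length_processor(scores, tf.constant(1))  # below MIN_LENGTH: EOS masked
late = min_length_processor(scores, tf.constant(5))   # past MIN_LENGTH: unchanged
```

Wrapping each branch in a lambda lets `tf.cond` trace the two branches itself, instead of receiving tensors that were already computed before the condition was evaluated.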
Looks good to me!
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
What does this PR do?
(Review after #16899)
A few more XLA-compatible logits processors: minimum length, forced eos, and forced bos. Only the first one needed changes, mostly to avoid needless retracing (it actually compiled without changes, but would trigger a retrace at each iteration, which would be super slow).
After this PR, the only remaining processors are the bad words and ngrams ones.
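To illustrate the retracing concern mentioned above, here is a small sketch of one common cause, which is an assumption on my part and not necessarily the exact trigger in this PR: passing the current length as a Python int makes `tf.function` trace once per distinct value, while passing it as a tensor reuses a single trace. The names `step` and `trace_count` are hypothetical.

```python
import tensorflow as tf

trace_count = 0


@tf.function
def step(scores, cur_len):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return tf.cond(
        tf.less(cur_len, 3),
        lambda: scores - 1.0,
        lambda: tf.identity(scores),
    )


scores = tf.zeros((1, 4))

# Python ints are treated as constants, so every new value is a new trace.
for i in range(5):
    step(scores, i)
python_int_traces = trace_count

# Scalar int32 tensors share one input signature, so one trace is enough.
trace_count = 0
for i in range(5):
    step(scores, tf.constant(i))
tensor_traces = trace_count

print(python_int_traces, tensor_traces)
```

In a generation loop, a retrace per step means paying the tracing (and, with XLA, compilation) cost at every iteration, which is why keeping the length as a tensor matters.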