Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

TF: XLA logits processors - minimum length, forced eos, and forced bos #16912

Merged
merged 2 commits into from
Apr 25, 2022

Conversation

gante
Copy link
Member

@gante gante commented Apr 24, 2022

What does this PR do?

(Review after #16899)

A few more XLA-compatible logits processors -- minimum length, forced eos, and forced bos. Only the first one needed changes, mostly to avoid needless retracing (it actually compiled without changes but would trigger a retrace at iteration, which would be super slow).

After this PR, the only remaining processors are the bad words and ngrams ones.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Apr 24, 2022

The documentation is not available anymore as the PR was closed or merged.

Copy link
Member

@Rocketknight1 Rocketknight1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! I have a couple of nits that shouldn't actually affect the output, and one of them might be me being totally wrong anyway.

src/transformers/generation_tf_logits_process.py Outdated Show resolved Hide resolved
Comment on lines +227 to +228
lambda: self._apply_eos_token_mask(scores),
lambda: tf.identity(scores),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
lambda: self._apply_eos_token_mask(scores),
lambda: tf.identity(scores),
self._apply_eos_token_mask(scores),
scores,

Would this work without the lambdas and identity call? I feel like it should but I'm not sure if I'm missing something obvious.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, it doesn't. Super unintuitive, but tf.cond expects a callable, not the output of each branch 😬 (docs)

It fails if we remove the lambda.

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me!

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
@gante gante merged commit 809dac4 into huggingface:main Apr 25, 2022
@gante gante deleted the xla_min_len branch April 25, 2022 18:27
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
huggingface#16912)

* XLA min len, forced eos, and forced bos

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants