(TF) model.generate to tf.function for tf serving #16823
Comments
Hey @piEsposito, the function should now be usable with:

```python
#!/usr/bin/env python3
from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
import tensorflow as tf
physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
tf.config.experimental.set_memory_growth(device, True)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tokenizer("hello there can you continue", return_tensors="tf").input_ids
xla_generate = tf.function(model.generate, jit_compile=True)
outputs = xla_generate(input_ids)
print("Output", tokenizer.batch_decode(outputs)) |
cc @gante
Hey @piEsposito 👋 As @patrickvonplaten mentioned, we have some generation functionality that can be wrapped by `tf.function`. The example provided should be functional and XLA-accelerated. However, some advanced features are not yet XLA-compatible, including:
All these should be solved in the next 1-2 months. Keep an eye on our releases, and let us know if you run into problems :)
Hey @gante, thanks for the quick reply.
@gante thanks. Do you know how I can use the generate method with fully padded sequences? It always throws an error here :(
Pardon me, I wrote a half-truth above :) For encoder-decoder (a.k.a. sequence-to-sequence) models like T5, you can do as I wrote above. For decoder-only models like GPT-2 you can left-pad to a constant length -- see this test as an example.
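For reference, a minimal sketch of the encoder-decoder case described above; the checkpoint (t5-small), prompt, and padding length are illustrative choices, not taken from the thread:

```python
# Sketch: pad the encoder inputs to a constant length and run XLA-compiled generation.
from transformers import T5Tokenizer, TFT5ForConditionalGeneration
import tensorflow as tf

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

# Padding to a fixed length keeps the input shape constant, so the compiled
# function is not retraced for every new prompt length.
inputs = tokenizer(
    ["translate English to German: I love pizza."],
    return_tensors="tf",
    padding="max_length",
    max_length=64,
)

xla_generate = tf.function(model.generate, jit_compile=True)
outputs = xla_generate(inputs.input_ids, attention_mask=inputs.attention_mask)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```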
Sorry, but still, when I do pad it to `max_length`:

```python
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel
import tensorflow as tf
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
text = "Replace me by any text you'd like."
encoded_input = tokenizer([text],
                          return_tensors='tf',
                          padding="max_length")
model.generate(
    encoded_input.input_ids,
    max_length=1024
)
```

Throws me a:

And of course I can't set `max_length` any higher. Am I doing something wrong?
The constant length in decoder-only models has to be smaller than `max_length`, otherwise there is no room left for newly generated tokens.
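For concreteness, a sketch of the approach described in the last two comments: left-pad the prompt to a constant length that is smaller than `max_length`, and pass the attention mask so the pad tokens are ignored. The lengths 64 and 128 are arbitrary illustrative values:

```python
# Sketch: left-pad a GPT-2 prompt to a fixed length below max_length, then generate.
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel
import tensorflow as tf

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # decoder-only models are padded on the left
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

# Pad the prompt to 64 tokens, leaving room up to max_length=128 for new tokens.
encoded_input = tokenizer(
    ["Replace me by any text you'd like."],
    return_tensors="tf",
    padding="max_length",
    max_length=64,
)

preds = model.generate(
    encoded_input.input_ids,
    attention_mask=encoded_input.attention_mask,
    max_length=128,
    pad_token_id=tokenizer.pad_token_id,
)
print(tokenizer.batch_decode(preds, skip_special_tokens=True))
```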
When I pad and leave a few tokens for new generation, it still won't generate my text, but rather some random stuff after about 1000 eos tokens:

```python
text = "Replace me by any text you'd like."
encoded_input = tokenizer([text],
                          return_tensors='tf',
                          padding="max_length")
preds = model.generate(
    encoded_input.input_ids[:, 50:],
    max_length=1024,
    pad_token_id=tokenizer.pad_token_id
)
tokenizer.batch_decode(preds)
```

And I get something like:
This result stays the same even when I explicitly mask the padded tokens:

```python
preds = model.generate(
    encoded_input.input_ids[:, 50:],
    max_length=1024,
    attention_mask=encoded_input.attention_mask[:, 50:]
)
```

When we try the same input with greedy decoding, it makes sense.
It seems to be related to transformers/src/transformers/models/gpt2/modeling_tf_gpt2.py, lines 816 to 842 at commit 3104036, where the behavior differs when we are not passing `use_xla`. But it could be something else, as just passing `use_xla=True` changes the result but won't fix it.
@piEsposito it seems like we still have a couple of bugs to fix :D I'm afraid I can't be of much further help -- I'm actively developing XLA + `generate`.
@gante if you have an open-sourced branch I would love to help with that generate stuff. If not, thank you for your time and for trying to help me out with this.
@piEsposito that would be lovely :) The step I will work on next, as I mentioned above, is to make the logit processors XLA-compatible. In other words, rewrite them such that the tests here pass when the function is compiled with `jit_compile=True`.
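As an illustration of what "XLA-compatible" means here, the sketch below is a standalone toy and does not use the actual transformers logit-processor interface: the banning logic is expressed with `tf.where` on tensors instead of Python-level conditionals, so the function can be traced and compiled with `jit_compile=True`:

```python
# Toy example: a min-length penalty written with tensor ops only, so it compiles under XLA.
import tensorflow as tf

def min_length_penalty(scores, cur_len, min_length, eos_token_id):
    """Push the EOS score to -inf while the sequence is shorter than min_length."""
    # Boolean mask selecting the EOS column of the vocabulary.
    eos_mask = tf.range(tf.shape(scores)[-1]) == eos_token_id
    blocked = tf.where(eos_mask, tf.fill(tf.shape(scores), scores.dtype.min), scores)
    # A tensor-level select instead of a Python `if`, so both branches stay in the graph.
    return tf.where(cur_len < min_length, blocked, scores)

@tf.function(jit_compile=True)
def apply_penalty(scores, cur_len):
    return min_length_penalty(scores, cur_len, min_length=5, eos_token_id=50256)

scores = tf.random.normal((1, 50257))
print(apply_penalty(scores, tf.constant(3)))
```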
@gante hacking TensorFlow to make stuff serializable is kind of a hobby and has been paying my bills for a long time, so I can work on that. I just need a bit more context:
Thanks, let's do it.
Awesome @piEsposito! I will open a PR today, so you can have an example, and post here a more detailed guide 💪
Thanks!
@piEsposito After the PR above gets approved, the process would be:
If you run into issues along the way, let me know. I will let you know here when the PR gets approved, so we can start on the next processors. |
(The PR got approved and merged. Working on the guide mentioned above.)
Let's do it, man.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
(Beam search is being worked on at the moment; it's the last missing piece.)
Feature request
It would be nice if you wrapped the generate method of autoregressive models into a `tf.function`. That way we could export and serve it with the whole TensorFlow production stack. It's kind of a revival of #5443.
It would enable us to do something like:
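The code example from the original issue body was lost in extraction; as a stand-in, here is a hedged sketch of the kind of export being asked for: wrap `generate` in a `tf.function` with a fixed input signature and save it as a SavedModel that TF Serving can load. The paths, shapes, and lengths are illustrative, and whether `generate` actually traces cleanly this way is precisely what the issue requests:

```python
# Sketch: export GPT-2 generation as a SavedModel signature for TF Serving.
from transformers import TFGPT2LMHeadModel
import tensorflow as tf

model = TFGPT2LMHeadModel.from_pretrained("gpt2")

class ServableGPT2(tf.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    # Fixed sequence length (64) so the exported graph has a static input shape.
    @tf.function(input_signature=[tf.TensorSpec(shape=(None, 64), dtype=tf.int32, name="input_ids")])
    def serving(self, input_ids):
        return {"sequences": self.model.generate(input_ids, max_length=128)}

servable = ServableGPT2(model)
tf.saved_model.save(
    servable,
    "saved_model/gpt2_generate/1",
    signatures={"serving_default": servable.serving.get_concrete_function()},
)
```

The exported directory could then be loaded by TensorFlow Serving like any other SavedModel.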
And then serve it on the TensorFlow production stack.
Motivation
It would be nice if you wrapped the generate method of autoregressive models into a `tf.function`. That way we could export and serve it with the whole TensorFlow production stack. It is frustrating to have to write generate by hand or move to PyTorch to serve generative language models.
Your contribution
I could write a PR, though it would be nice if HF could share what they have done when trying it, as @Rocketknight1 and @patrickvonplaten said in #5443 (comment), so I would have a starting point.