
(TF) model.generate to tf.function for tf serving #16823


Closed
piEsposito opened this issue Apr 18, 2022 · 22 comments · Fixed by #18372
Assignees
Labels
WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress

Comments

@piEsposito
Contributor

Feature request

It would be nice if you wrapped the generate method of autoregressive models in a tf.function. That way we could export and serve it with the whole TensorFlow production stack.

It's kind of a revival of #5443.

It would enable us to do something like:

from transformers import AutoTokenizer, TFAutoModelForCausalLM
import tensorflow as tf

model = TFAutoModelForCausalLM.from_pretrained("gpt2")
model.save(
    "some_place",
    signatures={
        "serving_default": model.generate.get_concrete_function(tf.TensorSpec([None, None], tf.int32))
    }
)

And then serve it on the TF production stack.
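For context, serving would then amount to loading the exported SavedModel (or pointing TF Serving at it) and calling its signature. A rough sketch, assuming the export above works and reusing the imports from the snippet above; the exported path and the keyword name of the signature input are assumptions:

# Load the SavedModel exported above and grab its serving signature.
loaded = tf.saved_model.load("some_place")
serving_fn = loaded.signatures["serving_default"]

tokenizer = AutoTokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello there", return_tensors="tf").input_ids

# The keyword name depends on how the signature was traced ("input_ids" assumed here).
generated = serving_fn(input_ids=input_ids)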

Motivation

It would be nice if you wrapped the generate method of autoregressive models in a tf.function. That way we could export and serve it with the whole TensorFlow production stack.

It is frustrating to have to write generate by hand or move to PyTorch to serve generative language models.

Your contribution

I could write a PR, though it would be nice if HF could share what they have done when trying this, as @Rocketknight1 and @patrickvonplaten said in #5443 (comment), so I would have somewhere to start from.

@patrickvonplaten
Contributor

Hey @piEsposito,

The function should now be usable with tf.function, I think. We don't want to wrap generate in tf.function automatically ourselves, but you should be able to do the following now:

#!/usr/bin/env python3
from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
import tensorflow as tf

physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer("hello there can you continue", return_tensors="tf").input_ids

# Wrap generate in a tf.function and JIT-compile it with XLA
xla_generate = tf.function(model.generate, jit_compile=True)
outputs = xla_generate(input_ids)

print("Output", tokenizer.batch_decode(outputs))

@patrickvonplaten
Contributor

cc @gante

@gante
Member

gante commented Apr 19, 2022

Hey @piEsposito 👋 As @patrickvonplaten mentioned, we have some generation functionality that can be wrapped in tf.function to be highly accelerated -- our tests point to a >30x speedup when an NVIDIA T4 is used.

The example provided should be functional and XLA-accelerated. However, some advanced features are not yet XLA-compatible, including:

  • accelerated serving of different lengths (changing input length triggers recompilation at the moment)
  • Beam Search (the num_beams option in generate)
  • generate options like bad_words_ids or no_repeat_ngram_size

All these should be solved in the next 1-2 months. Keep an eye on our releases, and let us know if you run into problems :)

@piEsposito
Contributor Author

Hey @gante , thanks for the quick reply.
Actually, my problem is specifically creating a serving signature that receives an input with variable length so I can use it with TF Serving in production. Do you have anything on that?

@gante
Member

gante commented Apr 19, 2022

tf.function has an experimental_relax_shapes argument, which may help there. I can't confirm, as I haven't tested :) An alternative would be to pad all inputs to the maximum length accepted by the model, but that might waste memory/compute.
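For reference, the relax-shapes route would look something like this (an untested sketch, reusing the model and input_ids from the example above):

# Relax shape constraints so that new input lengths can reuse a generalized
# trace instead of triggering a brand-new one each time.
relaxed_generate = tf.function(model.generate, experimental_relax_shapes=True)
outputs = relaxed_generate(input_ids)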

@piEsposito
Contributor Author

@gante thanks. Do you know how I can use the generate method with fully padded sequences? It always throws an error here :(

@gante
Member

gante commented Apr 19, 2022

Pardon me, I wrote a half-truth above :) For encoder-decoder (aka sequence-to-sequence) models like T5, you can do as I wrote above. For decoder-only models like GPT-2, you can left-pad to a constant length -- see this test as an example.

@piEsposito
Contributor Author

Sorry, but when I do pad it to max_length (if we set padding to True it won't pad to the maximum accepted length), it still throws an error:

from transformers import GPT2Tokenizer, TFGPT2LMHeadModel
import tensorflow as tf

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

model = TFGPT2LMHeadModel.from_pretrained("gpt2")

text = "Replace me by any text you'd like."
encoded_input = tokenizer([text],
                          return_tensors='tf',
                          padding="max_length")

model.generate(
    encoded_input.input_ids,
    max_length=1024
)

Throws me a:

ValueError: The context has 1024 number of tokens, but `max_length` is only 1024.

And of course I can't set max_length to anything more than 1024.

Am I doing something wrong?

@gante
Member

gante commented Apr 19, 2022

The constant length in decoder-only models has to be smaller than max_length (as opposed to encoder-decoder models, where it can be padded to max_length), and the difference between your constant length and generate's max_length corresponds to the maximum number of tokens generate can produce.
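In code, that advice translates to something like this sketch, with the tokenizer and model set up as in the snippet above (pad token set, left padding); the lengths 64 and 128 are hypothetical, and whether the attention mask is actually respected here is what the next comments dig into:

# Left-pad the prompt to a constant 64 tokens; with max_length=128 generate
# can then produce up to 128 - 64 = 64 new tokens.
encoded = tokenizer(["Replace me by any text you'd like."],
                    return_tensors="tf",
                    padding="max_length",
                    max_length=64)

preds = model.generate(
    encoded.input_ids,
    attention_mask=encoded.attention_mask,  # mask out the left padding
    max_length=128,
    pad_token_id=tokenizer.pad_token_id,
)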

@piEsposito
Contributor Author

When I pad and leave a few tokens of room for new generation, it still won't continue my text, but rather generates some random stuff after about 1000 eos tokens:

text = "Replace me by any text you'd like."

encoded_input = tokenizer([text],
                          return_tensors='tf',
                          padding="max_length")

preds = model.generate(
    encoded_input.input_ids[:, 50:],
    max_length=1024,
    pad_token_id=tokenizer.pad_token_id
)

tokenizer.batch_decode(preds)

And I get something like

[
"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>...
Replace me by any text you'd like.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n"
]

This result stays the same even when I explicitly mask the padded tokens:

preds = model.generate(
    encoded_input.input_ids[:, 50:],
    max_length=1024,
    attention_mask=encoded_input.attention_mask[:,50:]
)

When we try the same input and do greedy decoding, it makes sense.

@piEsposito
Contributor Author

piEsposito commented Apr 19, 2022

It seems to be related to

def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs):
    # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
    # tests will need to be fixed after the change

    # only last token for inputs_ids if past is defined in kwargs
    if past:
        inputs = tf.expand_dims(inputs[:, -1], -1)

    # TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
    # for a future PR to not change too many things for now.
    # All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
    position_ids = None
    attention_mask = None
    if use_xla:
        attention_mask = kwargs.get("attention_mask", None)
        if past is not None and attention_mask is not None:
            position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1
        elif attention_mask is not None:
            position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)

    return {
        "input_ids": inputs,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
        "past": past,
        "use_cache": use_cache,
    }

where, when we are not passing use_xla=True, the attention mask will be set to None.

But it could be something else, as just passing use_xla=True changes the result but won't fix it.

@gante
Member

gante commented Apr 20, 2022

@piEsposito it seems like we still have a couple of bugs to fix :D

I'm afraid I can't be of much further help -- I'm actively developing XLA + generate, but I don't expect to be able to sort out your particular issue within the next month. The roadmap is approximately XLA logits processors -> XLA beam search -> efficient XLA batching (your issue) -> XLA on more models beyond GPT-2 and T5. When all this is sorted, we will make a big announcement and publish some tutorials. Until then, feel free to ping me to ask about the state of the XLA changes :)

@gante gante self-assigned this Apr 20, 2022
@piEsposito
Contributor Author

@gante if you have an open-sourced branch I would love to help with that generate stuff. If not, thank you for your time and for trying to help me out with this.

@gante
Member

gante commented Apr 20, 2022

@piEsposito that would be lovely :)

The step I will work on next, as I mentioned above, is to make the logit processors XLA-compatible. In other words, rewrite them such that the tests here pass if you compile the function with tf.function(jit_compile=True). Some of them may already work -- feel free to claim one (or more) to work on, excluding the repetition_penalty (which I've already rewritten for XLA in a branch).
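To make "XLA-compatible" concrete, here is a toy illustration of the pattern (not the actual transformers processor or test, just the idea): tensor-ops-only logic that gives the same result eagerly and under tf.function(jit_compile=True).

import tensorflow as tf

def toy_min_length_processor(input_ids, scores, eos_token_id=0, min_length=5):
    # Push the EOS logit to a very large negative value while the sequence is
    # still shorter than min_length, using tensor ops only (no Python branching
    # on tensor values), so the function can be compiled with XLA.
    cur_len = tf.shape(input_ids)[-1]
    vocab_size = tf.shape(scores)[-1]
    eos_mask = tf.cast(tf.equal(tf.range(vocab_size), eos_token_id), scores.dtype)
    too_short = tf.cast(cur_len < min_length, scores.dtype)
    return scores - 1e9 * too_short * eos_mask[tf.newaxis, :]

xla_processor = tf.function(toy_min_length_processor, jit_compile=True)

input_ids = tf.constant([[11, 12, 13]])   # batch of 1, three tokens so far
scores = tf.random.normal((1, 50257))     # fake GPT-2-sized logits

# The XLA-compiled version should match the eager one.
tf.debugging.assert_near(toy_min_length_processor(input_ids, scores),
                         xla_processor(input_ids, scores))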

@piEsposito
Contributor Author

@gante hacking TensorFlow to make stuff serializable is kind of a hobby and has also been paying my bills for a long time, so I can work on that.

I just need a bit more context:

  • How do I "claim" the logit processors I want to work on?
  • Should I rewrite those tests, but using the compiled tf functions?
  • Can you point me to your branch so I can check how you are adding the new tests (to keep the same style)?

Thanks, let's do it.

@gante
Member

gante commented Apr 21, 2022

Awesome @piEsposito! I will open a PR today, so you can have an example, and post here a more detailed guide 💪

@piEsposito
Contributor Author

Thanks!

@gante
Member

gante commented Apr 21, 2022

@piEsposito
This is the PR for an XLA-compatible repetition penalty logits processor. I've just opened it, so I'd suggest waiting until the review process is complete before starting on a new logit processor.

After the PR above gets approved, the process would be:

  • write here which logit processor you would like to work on, so we don't work on the same one (this is what I meant by "claim" :) );
  • write the XLA test, as in the PR linked above (feel free to make the tests stricter, as I did in the PR);
  • make modifications until it passes -- I suspect that a few of them are already XLA-compatible.

If you run into issues along the way, let me know. I will let you know here when the PR gets approved, so we can start on the next processors.

@gante
Member

gante commented Apr 22, 2022

(The PR got approved and merged. Working on the TFLogitsWarper subclasses now.)

@piEsposito
Contributor Author

(The PR got approved and merged. Working on the TFLogitsWarper subclasses now.)

Let's do it man.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@gante
Member

gante commented May 19, 2022

(beam search being worked on atm, last missing piece)

@gante gante added the WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress label Jun 8, 2022