Emu3: add model #33770

Merged: @zucchini-nlp merged 61 commits into huggingface:main on Jan 10, 2025

Conversation

@zucchini-nlp (Member) commented Sep 27, 2024

What does this PR do?

As per the title. The code works for generating text in single-batch scenarios, but the generated text doesn't match the input image. As for batched generation, the original implementation doesn't seem to support it either, mostly because the image features returned by the processor have different shapes (a smart resize that preserves as much of the original image size as possible). We could try padding similar to llava-next, but I'm not sure it will just work; I'll contact the authors.
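
For reference, a minimal sketch of the llava-next-style padding idea mentioned above (a hypothetical helper, not part of this PR): pad each variable-sized pixel-values tensor to the largest height/width in the batch so they can be stacked.

import torch
import torch.nn.functional as F

def pad_pixel_values(pixel_values_list):
    # pixel_values_list: tensors of shape (channels, height, width) with varying H/W
    max_h = max(t.shape[1] for t in pixel_values_list)
    max_w = max(t.shape[2] for t in pixel_values_list)
    padded = []
    for t in pixel_values_list:
        # F.pad pads the last two dims as (left, right, top, bottom)
        padded.append(F.pad(t, (0, max_w - t.shape[2], 0, max_h - t.shape[1])))
    return torch.stack(padded)  # (batch, channels, max_h, max_w)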

TODO:

  • Batched generation
  • Upload chat template and change the image-placeholder token from <|extra_0|> to something like <image>
  • Match the original implementation at the logit level
  • Tests, many more tests
  • Check out image generation and see how we can enable interleaved image+text generation as in Chameleon. Maybe not natively with transformers, but we can provide scripts with external libraries for structured generation -> not possible, because text generation and image generation are two different checkpoints with different weights

Example for text generation:

from PIL import Image
import torch
import requests

from transformers import Emu3ForConditionalGeneration, Emu3Processor

output_dir = "/raid/raushan/emu3"
processor = Emu3Processor.from_pretrained(output_dir)
model = Emu3ForConditionalGeneration.from_pretrained(output_dir, torch_dtype="bfloat16", device_map="auto")
processor.tokenizer.padding_side = "left"  # left-pad so batched generation continues from the end of each prompt

text = "You are a helpful assistant. USER: <|extra_0|>Please describe the image. ASSISTANT:"
image = Image.open("/raid/raushan/image.png")
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)

inputs = processor(
    text=[text, text],
    images=[image2, image],
    return_tensors="pt",
    padding=True,
)

inputs = inputs.to(device="cuda:0", dtype=torch.bfloat16)  # only floating-point tensors are cast; input_ids stay integer

out = model.generate(**inputs, max_new_tokens=100)
text_out = processor.batch_decode(out, skip_special_tokens=True)
print(text_out)

And for image generation:

import torch

from transformers import Emu3ForConditionalGeneration, Emu3Processor

output_dir = "/raid/raushan/emu3-gen"  # image generation is a separate checkpoint with different weights
processor = Emu3Processor.from_pretrained(output_dir)
model = Emu3ForConditionalGeneration.from_pretrained(output_dir, torch_dtype="bfloat16", device_map="auto")  # optionally: attn_implementation="flash_attention_2"


inputs = processor(
    text=["a portrait of young girl. masterpiece, film grained, best quality.", "a dog running under the rain"],
    padding=True,
    return_tensors="pt",
    return_for_image_generation=True,
)
inputs = inputs.to(device="cuda:0", dtype=torch.bfloat16)

image_sizes = inputs.pop("image_sizes")
HEIGHT, WIDTH = image_sizes[0]  # token-grid size of the image to generate
VISUAL_TOKENS = model.model.vocabulary_mapping.image_tokens  # ids of all visual tokens

def prefix_allowed_tokens_fn(batch_id, input_ids):
    height, width = HEIGHT, WIDTH
    visual_tokens = VISUAL_TOKENS
    # special-token ids, looked up from the tokenizer
    # (equivalently: torch.tensor([processor.tokenizer.image_token_id], device=model.device), etc.)
    image_token_id = processor.tokenizer.encode("<|image token|>", return_tensors="pt")[0].to(model.device)
    eoi_token_id = processor.tokenizer.encode("<|image end|>", return_tensors="pt")[0]
    eos_token_id = processor.tokenizer.encode("<|extra_204|>", return_tensors="pt")[0]
    pad_token_id = processor.tokenizer.encode("<|endoftext|>", return_tensors="pt")[0]
    eol_token_id = processor.tokenizer.encode("<|extra_200|>", return_tensors="pt")[0]
    eof_token_id = processor.tokenizer.encode("<|extra_201|>", return_tensors="pt")[0]

    # how far past the image-start token the current step is
    position = torch.nonzero(input_ids == image_token_id, as_tuple=True)[0][0]
    offset = input_ids.shape[0] - position
    if offset % (width + 1) == 0:
        return (eol_token_id, )  # close the current row of visual tokens
    elif offset == (width + 1) * height + 1:
        return (eof_token_id, )  # end of frame
    elif offset == (width + 1) * height + 2:
        return (eoi_token_id, )  # end of image
    elif offset == (width + 1) * height + 3:
        return (eos_token_id, )  # end of sequence
    elif offset > (width + 1) * height + 3:
        return (pad_token_id, )  # pad out any remaining steps
    else:
        return visual_tokens  # anywhere else, any visual token is allowed


out = model.generate(
    **inputs,
    max_new_tokens=50_000,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    do_sample=True,
    top_k=2048,
    return_dict_in_generate=True,
)

print(out.sequences.shape, inputs.input_ids.shape)

image = model.model.decode_image_tokens(out.sequences[:, inputs.input_ids.shape[1]: ], height=HEIGHT, width=WIDTH)
images = processor.postprocess(list(image.float()), return_tensors="PIL.Image.Image")  # postprocessing converts to numpy internally, which doesn't support bf16, hence the .float()
for i, image in enumerate(images['pixel_values']):
    image.save(f"result_{i}.png")
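
As a sanity check on the constraint logic above: the image stream is laid out as HEIGHT rows of WIDTH visual tokens, each row closed by an EOL token, then one EOF, EOI and EOS token. A quick helper to compute the expected number of generated tokens (illustrative only, not part of this PR):

def expected_image_stream_length(height: int, width: int) -> int:
    # each of the `height` rows holds `width` visual tokens plus one EOL,
    # then EOF, EOI and EOS close the image
    return (width + 1) * height + 3

# e.g. a 90x90 token grid needs (90 + 1) * 90 + 3 = 8193 generated tokens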

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines 1382 to 1387
@add_start_docstrings(
    "The Emu3 Text Model which consists of transformer with self attention layers.",
    EMU3_START_DOCSTRING,
)
class Emu3TextModel(Emu3PreTrainedModel):
    config_class = Emu3TextConfig
@zucchini-nlp (Member, Author):

Adding LlamaModel to bases messes up the auto-generated modeling file by adding new classes like Emu3TextAttention and so on, while we have Emu3Attention
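
For context, a minimal sketch of the modular pattern being discussed (illustrative only, not the actual file contents):

# modular_emu3.py (illustrative)
from transformers.models.llama.modeling_llama import LlamaModel

class Emu3TextModel(LlamaModel):
    # with LlamaModel in the bases, the code generator emits Llama-derived
    # classes such as Emu3TextAttention alongside the existing Emu3Attention
    config_class = Emu3TextConfig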

@Cyrilvallez (Member):
I think this should be solved by #34487!

Collaborator:
To try again!

@zucchini-nlp (Member, Author) commented Oct 28, 2024

I think this is ready for review. @ArthurZucker will you be reviewing or is there anyone I can tag for initial review?

Btw, the repo-consistency tests will fail because the modular file doesn't import Emu3TextConfig. I found that modular imports everything specified in the modular file's import section plus everything in the old model file's import section, but Emu3TextConfig is in neither of them, so we probably also need to check imports between the files of a single model. I'll think about how to fix that.
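
To make the failure mode concrete, a hypothetical minimal repro (not the actual file contents): a name used in the modular file that appears in neither of the two collected import sections ends up missing from the generated modeling file.

# modular_emu3.py (hypothetical repro)
from transformers.models.llama.modeling_llama import LlamaForCausalLM  # collected by the converter

class Emu3ForCausalLM(LlamaForCausalLM):
    # Emu3TextConfig lives in configuration_emu3.py and is imported neither here
    # nor in the old model file, so the generated file misses the import
    config_class = Emu3TextConfig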

@ArthurZucker (Collaborator):
You can tag @Cyrilvallez !

@Cyrilvallez (Member) left a comment:
Thanks a lot, great work! With the new modular version #34487, I think we can still improve a bit! It should be merged very soon, but this is already very nice IMO if you don't want to wait 🤗


@ArthurZucker (Collaborator) left a comment:
Waiting for the updates regarding @Cyrilvallez's PR, will review again once updated

@qubvel (Member) commented Jan 8, 2025
heh, is something wrong with code owners?

@stevhliu (Member) left a comment:
Thanks :)

@zucchini-nlp (Member, Author):
Yeah, it seems like it automatically tags all code owners depending on the files touched/created...

@ArthurZucker, it would be nice to not tag that many people at once

zucchini-nlp and others added 8 commits January 9, 2025 10:56
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> (on each of the 8 commits)
@ArthurZucker (Collaborator) left a comment:
Nice, thanks for iterating! My only comment is that I haven't personally looked closely enough at MIMI or the VQ-VAE from Chameleon, you would know better, but the more standard the better!
A few nits, but good to go IMO.


# autoregressively complete prompt
output = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(output[0], skip_special_tokens=True))
Collaborator:
nice to have some expected outputs!

@zucchini-nlp (Member, Author):
Let's merge 🚀

@zucchini-nlp zucchini-nlp merged commit 52e1f87 into huggingface:main Jan 10, 2025
25 checks passed
ArthurZucker added a commit that referenced this pull request Jan 10, 2025
* model can convert to HF and be loaded back

* nit

* works in single batch generation but hallucinates

* use the image tokens

* add image generation

* now it works

* add tests

* update

* add modulare but it doesn't work for porting docstring :(

* skip some tests

* add slow tests

* modular removed the import?

* guess this works

* update

* update

* fix copies

* fix test

* fix copies

* update

* docs

* fix tests

* last fix tests?

* pls

* repo consistency

* more style

* style

* remove file

* address comments

* tiny bits

* update after the new modular

* fix tests

* add one more cond in check attributes

* decompose down/up/mid blocks

* allow static cache generation in VLMs

* nit

* fix copies

* Update docs/source/en/model_doc/emu3.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/model_doc/emu3.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/model_doc/emu3.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/model_doc/emu3.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/model_doc/emu3.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/model_doc/emu3.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/model_doc/emu3.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/model_doc/emu3.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* fix VAE upsampling

* Update src/transformers/models/emu3/modular_emu3.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* address comments

* state overwritten stuff explicitly

* fix copies

* add the flag for flex attn

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@zucchini-nlp zucchini-nlp changed the title [WIP] Emu3: add model Emu3: add model Jan 13, 2025