Commit 9838867

PommesPeter, zhuole1025, and yiyixuxu authored

[Alpha-VLLM Team] Add Lumina-T2X to diffusers (#8652)

Co-authored-by: zhuole1025 <zhuole1025@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>

1 parent 9e9ed35 · commit 9838867

22 files changed: +2478 −17 lines

docs/source/en/_toctree.yml (+6)

@@ -249,6 +249,8 @@
     title: DiTTransformer2DModel
   - local: api/models/hunyuan_transformer2d
     title: HunyuanDiT2DModel
+  - local: api/models/lumina_nextdit2d
+    title: LuminaNextDiT2DModel
   - local: api/models/transformer_temporal
     title: TransformerTemporalModel
   - local: api/models/sd3_transformer2d
@@ -324,6 +326,8 @@
     title: Latent Diffusion
   - local: api/pipelines/ledits_pp
     title: LEDITS++
+  - local: api/pipelines/lumina
+    title: Lumina-T2X
   - local: api/pipelines/marigold
     title: Marigold
   - local: api/pipelines/panorama
@@ -435,6 +439,8 @@
     title: EulerDiscreteScheduler
   - local: api/schedulers/flow_match_euler_discrete
     title: FlowMatchEulerDiscreteScheduler
+  - local: api/schedulers/flow_match_heun_discrete
+    title: FlowMatchHeunDiscreteScheduler
   - local: api/schedulers/heun
     title: HeunDiscreteScheduler
   - local: api/schedulers/ipndm

New file (+20 lines): documentation for `LuminaNextDiT2DModel` (added to the toctree above as `api/models/lumina_nextdit2d`).

<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# LuminaNextDiT2DModel

A Next-DiT Diffusion Transformer model for 2D data from [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X).
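
The transformer can also be loaded on its own from the `transformer` subfolder of a converted checkpoint. This is a minimal sketch; the repository id comes from the pipeline documentation added in this commit, and the `subfolder` layout follows the conversion script's output, so treat both as assumptions:

```python
import torch
from diffusers import LuminaNextDiT2DModel

# Load just the denoiser weights from the diffusers-format checkpoint.
transformer = LuminaNextDiT2DModel.from_pretrained(
    "Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
)
```
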
## LuminaNextDiT2DModel

[[autodoc]] LuminaNextDiT2DModel

New file (+88 lines): documentation for the Lumina-T2X pipeline (added to the toctree above as `api/pipelines/lumina`).

<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Lumina-T2X

![concepts](https://github.com/Alpha-VLLM/Lumina-T2X/assets/54879512/9f52eabb-07dc-4881-8257-6d8a5f2a0a5a)

[Lumina-Next: Making Lumina-T2X Stronger and Faster with Next-DiT](https://github.com/Alpha-VLLM/Lumina-T2X/blob/main/assets/lumina-next.pdf) from Alpha-VLLM, OpenGVLab, Shanghai AI Laboratory.

The abstract from the paper is:

*Lumina-T2X is a nascent family of Flow-based Large Diffusion Transformers (Flag-DiT) that establishes a unified framework for transforming noise into various modalities, such as images and videos, conditioned on text instructions. Despite its promising capabilities, Lumina-T2X still encounters challenges including training instability, slow inference, and extrapolation artifacts. In this paper, we present Lumina-Next, an improved version of Lumina-T2X, showcasing stronger generation performance with increased training and inference efficiency. We begin with a comprehensive analysis of the Flag-DiT architecture and identify several suboptimal components, which we address by introducing the Next-DiT architecture with 3D RoPE and sandwich normalizations. To enable better resolution extrapolation, we thoroughly compare different context extrapolation methods applied to text-to-image generation with 3D RoPE, and propose Frequency- and Time-Aware Scaled RoPE tailored for diffusion transformers. Additionally, we introduce a sigmoid time discretization schedule to reduce sampling steps in solving the Flow ODE and the Context Drop method to merge redundant visual tokens for faster network evaluation, effectively boosting the overall sampling speed. Thanks to these improvements, Lumina-Next not only improves the quality and efficiency of basic text-to-image generation but also demonstrates superior resolution extrapolation capabilities and multilingual generation using decoder-based LLMs as the text encoder, all in a zero-shot manner. To further validate Lumina-Next as a versatile generative framework, we instantiate it on diverse tasks including visual recognition, multi-view, audio, music, and point cloud generation, showcasing strong performance across these domains. By releasing all codes and model weights at https://github.com/Alpha-VLLM/Lumina-T2X, we aim to advance the development of next-generation generative AI capable of universal modeling.*

**Highlights**: Lumina-Next is a next-generation Diffusion Transformer that significantly enhances text-to-image generation, multilingual generation, and multitask performance by introducing the Next-DiT architecture, 3D RoPE, and frequency- and time-aware RoPE, among other improvements.

Lumina-Next has the following components:
* It improves sampling efficiency with fewer and faster sampling steps.
* It uses Next-DiT as the transformer backbone, with sandwich normalization, 3D RoPE, and grouped-query attention (a minimal sketch of the sandwich-normalization idea follows this list).
* It uses Frequency- and Time-Aware Scaled RoPE.
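
"Sandwich normalization" wraps each sublayer in a normalization both before and after it, which helps keep activations stable as the model scales. The snippet below is only an illustrative sketch of that wrapping, not the diffusers implementation: the real Next-DiT block uses RMSNorm and combines this with 3D RoPE, grouped-query attention, and AdaLN gating, all of which are omitted here.

```python
import torch
import torch.nn as nn


class SandwichSublayer(nn.Module):
    """Illustrative sketch: normalization before and after an arbitrary sublayer."""

    def __init__(self, dim: int, sublayer: nn.Module):
        super().__init__()
        # The actual Next-DiT blocks use RMSNorm; LayerNorm keeps this sketch dependency-free.
        self.pre_norm = nn.LayerNorm(dim)
        self.post_norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize the input, run the sublayer, normalize its output, then add the residual.
        return x + self.post_norm(self.sublayer(self.pre_norm(x)))


# Usage sketch: wrap a plain MLP as the sublayer.
block = SandwichSublayer(64, nn.Sequential(nn.Linear(64, 256), nn.SiLU(), nn.Linear(256, 64)))
out = block(torch.randn(2, 16, 64))
```
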
---

[Lumina-T2X: Transforming Text into Any Modality, Resolution, and Duration via Flow-based Large Diffusion Transformers](https://arxiv.org/abs/2405.05945) from Alpha-VLLM, OpenGVLab, Shanghai AI Laboratory.

The abstract from the paper is:

*Sora unveils the potential of scaling Diffusion Transformer for generating photorealistic images and videos at arbitrary resolutions, aspect ratios, and durations, yet it still lacks sufficient implementation details. In this technical report, we introduce the Lumina-T2X family - a series of Flow-based Large Diffusion Transformers (Flag-DiT) equipped with zero-initialized attention, as a unified framework designed to transform noise into images, videos, multi-view 3D objects, and audio clips conditioned on text instructions. By tokenizing the latent spatial-temporal space and incorporating learnable placeholders such as [nextline] and [nextframe] tokens, Lumina-T2X seamlessly unifies the representations of different modalities across various spatial-temporal resolutions. This unified approach enables training within a single framework for different modalities and allows for flexible generation of multimodal data at any resolution, aspect ratio, and length during inference. Advanced techniques like RoPE, RMSNorm, and flow matching enhance the stability, flexibility, and scalability of Flag-DiT, enabling models of Lumina-T2X to scale up to 7 billion parameters and extend the context window to 128K tokens. This is particularly beneficial for creating ultra-high-definition images with our Lumina-T2I model and long 720p videos with our Lumina-T2V model. Remarkably, Lumina-T2I, powered by a 5-billion-parameter Flag-DiT, requires only 35% of the training computational costs of a 600-million-parameter naive DiT. Our further comprehensive analysis underscores Lumina-T2X's preliminary capability in resolution extrapolation, high-resolution editing, generating consistent 3D views, and synthesizing videos with seamless transitions. We expect that the open-sourcing of Lumina-T2X will further foster creativity, transparency, and diversity in the generative AI community.*

You can find the original codebase at [Alpha-VLLM](https://github.com/Alpha-VLLM/Lumina-T2X) and all the available checkpoints at [Alpha-VLLM Lumina Family](https://huggingface.co/collections/Alpha-VLLM/lumina-family-66423205bedb81171fd0644b).

**Highlights**: Lumina-T2X supports any modality, resolution, and duration.

Lumina-T2X has the following components:
* It uses a Flow-based Large Diffusion Transformer (Flag-DiT) as the backbone.
* It supports any modality with one backbone and a corresponding encoder and decoder.

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>
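
For example, because this commit also adds `FlowMatchHeunDiscreteScheduler`, you can swap it in for the pipeline's default flow-matching Euler scheduler using the standard `from_config` pattern. This is a minimal sketch; no claim is made here that the Heun variant improves Lumina's results.

```python
import torch
from diffusers import FlowMatchHeunDiscreteScheduler, LuminaText2ImgPipeline

pipeline = LuminaText2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# Reuse the existing scheduler config so resolution- and shift-related settings carry over.
pipeline.scheduler = FlowMatchHeunDiscreteScheduler.from_config(pipeline.scheduler.config)
```
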
### Inference (Text-to-Image)

Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.

First, load the pipeline:

```python
from diffusers import LuminaText2ImgPipeline
import torch

pipeline = LuminaText2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
).to("cuda")
```

Then change the memory layout of the pipeline's `transformer` and `vae` components to `torch.channels_last`:

```python
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
```

Finally, compile the components and run inference:

```python
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)

image = pipeline(prompt="Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures").images[0]
```
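
Compilation is optional; a plain call works as well. Continuing from the pipeline loaded above, the sketch below adds an explicit step count, guidance scale, and seed. The parameter names follow the usual diffusers text-to-image convention (see the [[autodoc]] reference below) and the values shown are only illustrative.

```python
generator = torch.Generator("cuda").manual_seed(0)

image = pipeline(
    prompt="A watercolor painting of a lighthouse on a stormy coast",
    num_inference_steps=30,  # Lumina-Next is designed to work well with relatively few steps
    guidance_scale=4.0,
    generator=generator,
).images[0]
image.save("lumina_lighthouse.png")
```
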

## LuminaText2ImgPipeline

[[autodoc]] LuminaText2ImgPipeline
  - all
  - __call__

New file (+18 lines): documentation for `FlowMatchHeunDiscreteScheduler` (added to the toctree above as `api/schedulers/flow_match_heun_discrete`).

<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# FlowMatchHeunDiscreteScheduler

`FlowMatchHeunDiscreteScheduler` is a Heun (second-order) sampler for the flow-matching formulation introduced in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206).
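
As a rough orientation (not taken from the scheduler's source), Heun's second-order update for the flow ODE \\(\mathrm{d}x/\mathrm{d}\sigma = v_\theta(x, \sigma)\\) over a step from \\(\sigma_i\\) to \\(\sigma_{i+1}\\) is a predictor-corrector pair:

$$
\tilde{x} = x_i + (\sigma_{i+1} - \sigma_i)\, v_\theta(x_i, \sigma_i),
\qquad
x_{i+1} = x_i + \frac{\sigma_{i+1} - \sigma_i}{2}\,\bigl(v_\theta(x_i, \sigma_i) + v_\theta(\tilde{x}, \sigma_{i+1})\bigr)
$$

Averaging the slope at both ends of the step is what makes the update second-order, at the cost of two model evaluations per step.
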
## FlowMatchHeunDiscreteScheduler

[[autodoc]] FlowMatchHeunDiscreteScheduler

New file (+142 lines): a script that converts the original Alpha-VLLM Lumina-Next checkpoints into the diffusers `LuminaNextDiT2DModel` / `LuminaText2ImgPipeline` format.

import argparse
import os

import torch
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer

from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline


def main(args):
    # checkpoint from https://huggingface.co/Alpha-VLLM/Lumina-Next-SFT or https://huggingface.co/Alpha-VLLM/Lumina-Next-T2I
    all_sd = load_file(args.origin_ckpt_path, device="cpu")
    converted_state_dict = {}
    # pad token
    converted_state_dict["pad_token"] = all_sd["pad_token"]

    # patch embed
    converted_state_dict["patch_embedder.weight"] = all_sd["x_embedder.weight"]
    converted_state_dict["patch_embedder.bias"] = all_sd["x_embedder.bias"]

    # time and caption embed
    converted_state_dict["time_caption_embed.timestep_embedder.linear_1.weight"] = all_sd["t_embedder.mlp.0.weight"]
    converted_state_dict["time_caption_embed.timestep_embedder.linear_1.bias"] = all_sd["t_embedder.mlp.0.bias"]
    converted_state_dict["time_caption_embed.timestep_embedder.linear_2.weight"] = all_sd["t_embedder.mlp.2.weight"]
    converted_state_dict["time_caption_embed.timestep_embedder.linear_2.bias"] = all_sd["t_embedder.mlp.2.bias"]
    converted_state_dict["time_caption_embed.caption_embedder.0.weight"] = all_sd["cap_embedder.0.weight"]
    converted_state_dict["time_caption_embed.caption_embedder.0.bias"] = all_sd["cap_embedder.0.bias"]
    converted_state_dict["time_caption_embed.caption_embedder.1.weight"] = all_sd["cap_embedder.1.weight"]
    converted_state_dict["time_caption_embed.caption_embedder.1.bias"] = all_sd["cap_embedder.1.bias"]

    for i in range(24):
        # adaln
        converted_state_dict[f"layers.{i}.gate"] = all_sd[f"layers.{i}.attention.gate"]
        converted_state_dict[f"layers.{i}.adaLN_modulation.1.weight"] = all_sd[f"layers.{i}.adaLN_modulation.1.weight"]
        converted_state_dict[f"layers.{i}.adaLN_modulation.1.bias"] = all_sd[f"layers.{i}.adaLN_modulation.1.bias"]

        # qkv (self-attention)
        converted_state_dict[f"layers.{i}.attn1.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
        converted_state_dict[f"layers.{i}.attn1.to_k.weight"] = all_sd[f"layers.{i}.attention.wk.weight"]
        converted_state_dict[f"layers.{i}.attn1.to_v.weight"] = all_sd[f"layers.{i}.attention.wv.weight"]

        # cap (caption cross-attention; queries are shared with the self-attention wq)
        converted_state_dict[f"layers.{i}.attn2.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
        converted_state_dict[f"layers.{i}.attn2.to_k.weight"] = all_sd[f"layers.{i}.attention.wk_y.weight"]
        converted_state_dict[f"layers.{i}.attn2.to_v.weight"] = all_sd[f"layers.{i}.attention.wv_y.weight"]

        # output projection
        converted_state_dict[f"layers.{i}.attn2.to_out.0.weight"] = all_sd[f"layers.{i}.attention.wo.weight"]

        # attention
        # qk norm
        converted_state_dict[f"layers.{i}.attn1.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
        converted_state_dict[f"layers.{i}.attn1.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]

        converted_state_dict[f"layers.{i}.attn1.norm_k.weight"] = all_sd[f"layers.{i}.attention.k_norm.weight"]
        converted_state_dict[f"layers.{i}.attn1.norm_k.bias"] = all_sd[f"layers.{i}.attention.k_norm.bias"]

        converted_state_dict[f"layers.{i}.attn2.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
        converted_state_dict[f"layers.{i}.attn2.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]

        converted_state_dict[f"layers.{i}.attn2.norm_k.weight"] = all_sd[f"layers.{i}.attention.ky_norm.weight"]
        converted_state_dict[f"layers.{i}.attn2.norm_k.bias"] = all_sd[f"layers.{i}.attention.ky_norm.bias"]

        # attention norm
        converted_state_dict[f"layers.{i}.attn_norm1.weight"] = all_sd[f"layers.{i}.attention_norm1.weight"]
        converted_state_dict[f"layers.{i}.attn_norm2.weight"] = all_sd[f"layers.{i}.attention_norm2.weight"]
        converted_state_dict[f"layers.{i}.norm1_context.weight"] = all_sd[f"layers.{i}.attention_y_norm.weight"]

        # feed forward
        converted_state_dict[f"layers.{i}.feed_forward.linear_1.weight"] = all_sd[f"layers.{i}.feed_forward.w1.weight"]
        converted_state_dict[f"layers.{i}.feed_forward.linear_2.weight"] = all_sd[f"layers.{i}.feed_forward.w2.weight"]
        converted_state_dict[f"layers.{i}.feed_forward.linear_3.weight"] = all_sd[f"layers.{i}.feed_forward.w3.weight"]

        # feed forward norm
        converted_state_dict[f"layers.{i}.ffn_norm1.weight"] = all_sd[f"layers.{i}.ffn_norm1.weight"]
        converted_state_dict[f"layers.{i}.ffn_norm2.weight"] = all_sd[f"layers.{i}.ffn_norm2.weight"]

    # final layer
    converted_state_dict["final_layer.linear.weight"] = all_sd["final_layer.linear.weight"]
    converted_state_dict["final_layer.linear.bias"] = all_sd["final_layer.linear.bias"]

    converted_state_dict["final_layer.adaLN_modulation.1.weight"] = all_sd["final_layer.adaLN_modulation.1.weight"]
    converted_state_dict["final_layer.adaLN_modulation.1.bias"] = all_sd["final_layer.adaLN_modulation.1.bias"]

    # Lumina-Next-SFT 2B
    transformer = LuminaNextDiT2DModel(
        sample_size=128,
        patch_size=2,
        in_channels=4,
        hidden_size=2304,
        num_layers=24,
        num_attention_heads=32,
        num_kv_heads=8,
        multiple_of=256,
        ffn_dim_multiplier=None,
        norm_eps=1e-5,
        learn_sigma=True,
        qk_norm=True,
        cross_attention_dim=2048,
        scaling_factor=1.0,
    )
    transformer.load_state_dict(converted_state_dict, strict=True)

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    if args.only_transformer:
        transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
    else:
        scheduler = FlowMatchEulerDiscreteScheduler()

        vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.float32)

        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
        text_encoder = AutoModel.from_pretrained("google/gemma-2b")

        pipeline = LuminaText2ImgPipeline(
            tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
        )
        pipeline.save_pretrained(args.dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--origin_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--image_size",
        default=1024,
        type=int,
        choices=[256, 512, 1024],
        required=False,
        help="Image size of pretrained model, either 512 or 1024.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
    # Note: argparse's type=bool turns any non-empty string into True, so only an empty string yields False.
    parser.add_argument("--only_transformer", default=True, type=bool, required=True)

    args = parser.parse_args()
    main(args)
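
Once converted with the full pipeline saved (rather than only the transformer), the dumped directory loads like any other diffusers checkpoint. A minimal sketch, where `./lumina-next-diffusers` is a hypothetical `--dump_path`:

```python
import torch
from diffusers import LuminaText2ImgPipeline

# "./lumina-next-diffusers" is a hypothetical output directory used only for illustration.
pipeline = LuminaText2ImgPipeline.from_pretrained("./lumina-next-diffusers", torch_dtype=torch.bfloat16).to("cuda")
image = pipeline(prompt="A snowy mountain village at dusk").images[0]
```
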

src/diffusers/__init__.py (+6)

@@ -88,6 +88,7 @@
         "HunyuanDiT2DMultiControlNetModel",
         "I2VGenXLUNet",
         "Kandinsky3UNet",
+        "LuminaNextDiT2DModel",
         "ModelMixin",
         "MotionAdapter",
         "MultiAdapter",
@@ -162,6 +163,7 @@
         "EulerAncestralDiscreteScheduler",
         "EulerDiscreteScheduler",
         "FlowMatchEulerDiscreteScheduler",
+        "FlowMatchHeunDiscreteScheduler",
         "HeunDiscreteScheduler",
         "IPNDMScheduler",
         "KarrasVeScheduler",
@@ -270,6 +272,7 @@
         "LDMTextToImagePipeline",
         "LEditsPPPipelineStableDiffusion",
         "LEditsPPPipelineStableDiffusionXL",
+        "LuminaText2ImgPipeline",
         "MarigoldDepthPipeline",
         "MarigoldNormalsPipeline",
         "MusicLDMPipeline",
@@ -509,6 +512,7 @@
         HunyuanDiT2DMultiControlNetModel,
         I2VGenXLUNet,
         Kandinsky3UNet,
+        LuminaNextDiT2DModel,
         ModelMixin,
         MotionAdapter,
         MultiAdapter,
@@ -580,6 +584,7 @@
         EulerAncestralDiscreteScheduler,
         EulerDiscreteScheduler,
         FlowMatchEulerDiscreteScheduler,
+        FlowMatchHeunDiscreteScheduler,
         HeunDiscreteScheduler,
         IPNDMScheduler,
         KarrasVeScheduler,
@@ -669,6 +674,7 @@
         LDMTextToImagePipeline,
         LEditsPPPipelineStableDiffusion,
         LEditsPPPipelineStableDiffusionXL,
+        LuminaText2ImgPipeline,
         MarigoldDepthPipeline,
         MarigoldNormalsPipeline,
         MusicLDMPipeline,
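
These registrations expose the new classes at the top level of the package through diffusers' lazy `_import_structure` machinery. A quick import check, assuming a build from this commit:

```python
from diffusers import FlowMatchHeunDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline

print(FlowMatchHeunDiscreteScheduler.__name__, LuminaNextDiT2DModel.__name__, LuminaText2ImgPipeline.__name__)
```
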

src/diffusers/models/__init__.py (+2)

@@ -41,6 +41,7 @@
     _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
     _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
     _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
+    _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
     _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
     _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
@@ -85,6 +86,7 @@
         DiTTransformer2DModel,
         DualTransformer2DModel,
         HunyuanDiT2DModel,
+        LuminaNextDiT2DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
         SD3Transformer2DModel,
