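"""Benchmark text-to-image diffusion pipelines with optional torchao
quantization/sparsity and torch.compile, recording latency and memory use.

Example invocation (a sketch; all flags are defined in the argparse section at
the bottom of this file):

    python benchmark_image.py --ckpt_id black-forest-labs/FLUX.1-dev --quantization fp8dq --compile --batch_size 1

Writes a `<prefix>_info.json` file with the measurements and saves a sample
`<prefix>.png` image.
"""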
import argparse
import json

import torch
from diffusers import DiffusionPipeline
from torchao.quantization import autoquant, quantize_

from utils import benchmark_fn, bytes_to_giga_bytes, cleanup_tmp_directory, pretty_print_results, reset_memory

# Allow reduced-precision (e.g., TF32) kernels for float32 matrix multiplications.
# This speeds up matmuls on NVIDIA GPUs with Ampere architecture (e.g., A100,
# RTX 30 series) or newer.
torch.set_float32_matmul_precision("high")

PROMPT = "Eiffel Tower was Made up of more than 2 million translucent straws to look like a cloud, with the bell tower at the top of the building, Michel installed huge foam-making machines in the forest to blow huge amounts of unpredictable wet clouds in the building's classic architecture."
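
# Short model names used to build artifact filenames in serialize_artifacts().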
PREFIXES = {
    "stabilityai/stable-diffusion-3-medium-diffusers": "sd3",
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "pixart",
    "fal/AuraFlow": "auraflow",
    "black-forest-labs/FLUX.1-dev": "flux-dev",
}


def load_pipeline(
    ckpt_id: str,
    fuse_attn_projections: bool,
    compile: bool,
    quantization: str,
    sparsify: bool,
    compile_vae: bool = False,
) -> DiffusionPipeline:
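    """Load `ckpt_id` and apply the requested fusion, quantization/sparsity, and compilation options."""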
    pipeline = DiffusionPipeline.from_pretrained(ckpt_id, torch_dtype=torch.bfloat16).to("cuda")

    if fuse_attn_projections:
        pipeline.transformer.fuse_qkv_projections()
        if compile_vae:
            pipeline.vae.fuse_qkv_projections()
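
    # For autoquant, compile *before* quantizing: autoquant is applied to the
    # compiled transformer further below.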
    if quantization == "autoquant" and compile:
        pipeline.transformer.to(memory_format=torch.channels_last)
        pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
        if compile_vae:
            pipeline.vae.to(memory_format=torch.channels_last)
            pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)
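
    # Unless sparsifying, apply the torchao quantization recipe selected via
    # `--quantization` to the transformer (and optionally the VAE).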
    if not sparsify:
        if quantization == "int8dq":
            from torchao.quantization import int8_dynamic_activation_int8_weight

            quantize_(pipeline.transformer, int8_dynamic_activation_int8_weight())
            if compile_vae:
                quantize_(pipeline.vae, int8_dynamic_activation_int8_weight())
        elif quantization == "int8wo":
            from torchao.quantization import int8_weight_only

            quantize_(pipeline.transformer, int8_weight_only())
            if compile_vae:
                quantize_(pipeline.vae, int8_weight_only())
        elif quantization == "int4wo":
            from torchao.quantization import int4_weight_only

            quantize_(pipeline.transformer, int4_weight_only())
            if compile_vae:
                quantize_(pipeline.vae, int4_weight_only())
        elif quantization == "fp6_e3m2":
            from torchao.quantization import fpx_weight_only

            quantize_(pipeline.transformer, fpx_weight_only(3, 2))
            if compile_vae:
                quantize_(pipeline.vae, fpx_weight_only(3, 2))
        elif quantization == "fp5_e2m2":
            from torchao.quantization import fpx_weight_only

            quantize_(pipeline.transformer, fpx_weight_only(2, 2))
            if compile_vae:
                quantize_(pipeline.vae, fpx_weight_only(2, 2))
        elif quantization == "fp4_e2m1":
            from torchao.quantization import fpx_weight_only

            quantize_(pipeline.transformer, fpx_weight_only(2, 1))
            if compile_vae:
                quantize_(pipeline.vae, fpx_weight_only(2, 1))
        elif quantization == "fp8wo":
            from torchao.quantization import float8_weight_only

            quantize_(pipeline.transformer, float8_weight_only())
            if compile_vae:
                quantize_(pipeline.vae, float8_weight_only())
        elif quantization == "fp8dq":
            from torchao.quantization import float8_dynamic_activation_float8_weight

            quantize_(pipeline.transformer, float8_dynamic_activation_float8_weight())
            if compile_vae:
                quantize_(pipeline.vae, float8_dynamic_activation_float8_weight())
        elif quantization == "fp8dqrow":
            from torchao.quantization import float8_dynamic_activation_float8_weight
            from torchao.quantization.quant_api import PerRow

            quantize_(pipeline.transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()))
            if compile_vae:
                quantize_(pipeline.vae, float8_dynamic_activation_float8_weight(granularity=PerRow()))
        elif quantization == "autoquant":
            pipeline.transformer = autoquant(pipeline.transformer, error_on_unseen=False)
            if compile_vae:
                pipeline.vae = autoquant(pipeline.vae, error_on_unseen=False)
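
    # Alternatively, apply int8 dynamic activation quantization with 2:4
    # semi-structured weight sparsity.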
    if sparsify:
        from torchao.sparsity import sparsify_, int8_dynamic_activation_int8_semi_sparse_weight

        sparsify_(pipeline.transformer, int8_dynamic_activation_int8_semi_sparse_weight())
        if compile_vae:
            sparsify_(pipeline.vae, int8_dynamic_activation_int8_semi_sparse_weight())
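
    # For every recipe except autoquant, quantize/sparsify first, then compile.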
    if quantization != "autoquant" and compile:
        pipeline.transformer.to(memory_format=torch.channels_last)
        pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
        if compile_vae:
            pipeline.vae.to(memory_format=torch.channels_last)
            pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)

    pipeline.set_progress_bar_config(disable=True)
    return pipeline


def run_inference(pipe, batch_size):
    _ = pipe(
        prompt=PROMPT,
        num_images_per_prompt=batch_size,
        generator=torch.manual_seed(2024),
    )


def run_benchmark(pipeline, args):
    model_memory = bytes_to_giga_bytes(torch.cuda.memory_allocated())  # in GB.

    # Warm-up runs so torch.compile/autotuning overhead is excluded from the timing.
    for _ in range(5):
        run_inference(pipeline, batch_size=args.batch_size)
    time = benchmark_fn(run_inference, pipeline, args.batch_size)

    torch.cuda.empty_cache()
    inference_memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GB.

    info = dict(
        ckpt_id=args.ckpt_id,
        batch_size=args.batch_size,
        fuse=args.fuse_attn_projections,
        compile=args.compile,
        compile_vae=args.compile_vae,
        quantization=args.quantization,
        sparsify=args.sparsify,
        model_memory=model_memory,
        inference_memory=inference_memory,
        time=time,
    )
    pretty_print_results(info)
    return info


def serialize_artifacts(info: dict, pipeline, args):
    ckpt_id = PREFIXES[args.ckpt_id]
    prefix = f"ckpt@{ckpt_id}-bs@{args.batch_size}-fuse@{args.fuse_attn_projections}-compile@{args.compile}-compile_vae@{args.compile_vae}-quant@{args.quantization}-sparsify@{args.sparsify}"
    info_file = f"{prefix}_info.json"
    with open(info_file, "w") as f:
        json.dump(info, f)

    # Also save a sample image so output quality can be inspected alongside the numbers.
    image = pipeline(
        prompt=PROMPT,
        num_images_per_prompt=args.batch_size,
        generator=torch.manual_seed(0),
    ).images[0]
    image.save(f"{prefix}.png")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt_id",
        default="black-forest-labs/FLUX.1-dev",
        type=str,
        help="Model ID on the Hub or path to a local checkpoint to benchmark.",
    )
    parser.add_argument(
        "--fuse_attn_projections",
        action="store_true",
        help="Whether or not to fuse the QKV projection layers into one larger layer.",
    )
    parser.add_argument("--compile", action="store_true", help="Whether or not to torch.compile the models.")
    parser.add_argument(
        "--compile_vae",
        action="store_true",
        help="Whether to also compile the VAE (only relevant with --compile).",
    )
    parser.add_argument(
        "--quantization",
        default="None",
        choices=[
            "int8dq",
            "int8wo",
            "int4wo",
            "autoquant",
            "fp8wo",
            "fp8dq",
            "fp8dqrow",
            "fp6_e3m2",
            "fp5_e2m2",
            "fp4_e2m1",
            "None",
        ],
        help="Which quantization technique to apply.",
    )
    parser.add_argument(
        "--sparsify",
        action="store_true",
        help="Whether to apply int8 dynamic activation quantization with semi-structured (2:4) weight sparsity instead of quantization.",
    )
    parser.add_argument(
        "--batch_size",
        default=1,
        type=int,
        choices=[1, 4, 8, 16],
        help="Number of images to generate for the testing prompt.",
    )
    args = parser.parse_args()

    # Start from a clean CUDA memory state so model memory is measured accurately.
    reset_memory("cuda")
    pipeline = load_pipeline(
        ckpt_id=args.ckpt_id,
        fuse_attn_projections=args.fuse_attn_projections,
        compile=args.compile,
        compile_vae=args.compile_vae,
        quantization=args.quantization,
        sparsify=args.sparsify,
    )

    info = run_benchmark(pipeline, args)
    serialize_artifacts(info, pipeline, args)
    cleanup_tmp_directory()