diff --git a/fp8/flux_pipeline.py b/fp8/flux_pipeline.py
index c208abf..55322e8 100644
--- a/fp8/flux_pipeline.py
+++ b/fp8/flux_pipeline.py
@@ -615,15 +615,16 @@ def generate(
         )
 
         # prepare inputs
-        img, img_ids, vec, txt, txt_ids = map(
-            lambda x: x,  # x.contiguous(),
-            self.prepare(
-                img=img,
-                prompt=prompt,
-                target_device=self.device_flux,
-                target_dtype=self.dtype,
-            ),
-        )
+        with torch.profiler.record_function("prepare"):
+            img, img_ids, vec, txt, txt_ids = map(
+                lambda x: x,  # x.contiguous(),
+                self.prepare(
+                    img=img,
+                    prompt=prompt,
+                    target_device=self.device_flux,
+                    target_dtype=self.dtype,
+                ),
+            )
 
         # dispatch to gpu if offloaded
         if self.offload_flow:
@@ -634,16 +635,17 @@ def generate(
         output_imgs = []
 
         for i in range(batch_size):
-            denoised_img = self.denoise_single_item(
-                img[i],
-                img_ids[i],
-                txt[i],
-                txt_ids[i],
-                vec[i],
-                timesteps,
-                guidance,
-                compiling
-            )
+            with torch.profiler.record_function("denoise-single-item"):
+                denoised_img = self.denoise_single_item(
+                    img[i],
+                    img_ids[i],
+                    txt[i],
+                    txt_ids[i],
+                    vec[i],
+                    timesteps,
+                    guidance,
+                    compiling
+                )
             output_imgs.append(denoised_img)
             compiling = False
 
@@ -655,7 +657,8 @@ def generate(
             torch.cuda.empty_cache()
 
         # decode latents to pixel space
-        img = self.vae_decode(img, height, width)
+        with torch.profiler.record_function("vae-decode"):
+            img = self.vae_decode(img, height, width)
 
         return self.as_img_tensor(img)
 
diff --git a/predict.py b/predict.py
index 98cf382..dc6be74 100644
--- a/predict.py
+++ b/predict.py
@@ -1,3 +1,4 @@
+import contextlib
 import os
 import time
 from typing import Any, Tuple, Optional
@@ -136,8 +137,10 @@ def base_setup(
         compile_fp8: bool = False,
         compile_bf16: bool = False,
         disable_fp8: bool = False,
+        enable_profiling: bool = False,
     ) -> None:
         self.flow_model_name = flow_model_name
+        self.enable_profiling = enable_profiling
         print(f"Booting model {self.flow_model_name}")
 
         gpu_name = (
@@ -477,6 +480,7 @@ def postprocess(
         output_format: str,
         output_quality: int,
         np_images: Optional[List[Image]] = None,
+        profile: Optional[Path] = None,
     ) -> List[Path]:
         has_nsfw_content = [False] * len(images)
 
@@ -513,6 +517,8 @@ def postprocess(
         )
 
         print(f"Total safe images: {len(output_paths)} out of {len(images)}")
+        if profile:
+            output_paths.append(profile)
         return output_paths
 
     def run_safety_checker(self, images, np_images):
@@ -547,32 +553,48 @@ def shared_predict(
         seed: int = None,
         width: int = 1024,
         height: int = 1024,
-    ):
-        if go_fast and not self.disable_fp8:
-            return self.fp8_predict(
-                prompt=prompt,
-                num_outputs=num_outputs,
-                num_inference_steps=num_inference_steps,
-                guidance=guidance,
-                image=image,
-                prompt_strength=prompt_strength,
-                seed=seed,
-                width=width,
-                height=height,
+    ) -> Tuple[List[Image.Image], Optional[List[np.ndarray]], Optional[Path]]:
+        if self.enable_profiling:
+            profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ]
             )
-        if self.disable_fp8:
-            print("running bf16 model, fp8 disabled")
-        return self.base_predict(
-            prompt=prompt,
-            num_outputs=num_outputs,
-            num_inference_steps=num_inference_steps,
-            guidance=guidance,
-            image=image,
-            prompt_strength=prompt_strength,
-            seed=seed,
-            width=width,
-            height=height,
-        )
+        else:
+            profiler = contextlib.nullcontext()
+
+        with profiler:
+            if go_fast and not self.disable_fp8:
+                imgs, np_imgs = self.fp8_predict(
+                    prompt=prompt,
+                    num_outputs=num_outputs,
+                    num_inference_steps=num_inference_steps,
+                    guidance=guidance,
+                    image=image,
+                    prompt_strength=prompt_strength,
+                    seed=seed,
+                    width=width,
+                    height=height,
+                )
+            else:
+                if self.disable_fp8:
+                    print("running bf16 model, fp8 disabled")
+                imgs, np_imgs = self.base_predict(
+                    prompt=prompt,
+                    num_outputs=num_outputs,
+                    num_inference_steps=num_inference_steps,
+                    guidance=guidance,
+                    image=image,
+                    prompt_strength=prompt_strength,
+                    seed=seed,
+                    width=width,
+                    height=height,
+                )
+        if isinstance(profiler, torch.profiler.profile):
+            profiler.export_chrome_trace("chrome-trace.json")
+            return imgs, np_imgs, Path("chrome-trace.json")
+        return imgs, np_imgs, None
 
 
 class SchnellPredictor(Predictor):
@@ -598,7 +620,7 @@ def predict(
         megapixels: str = SHARED_INPUTS.megapixels,
     ) -> List[Path]:
         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        imgs, np_imgs, profile = self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
@@ -614,6 +636,7 @@ def predict(
             output_format,
             output_quality,
             np_images=np_imgs,
+            profile=profile,
         )
 
 
@@ -656,7 +679,7 @@ def predict(
             print("img2img not supported with fp8 quantization; running with bf16")
             go_fast = False
         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        imgs, np_imgs, profile = self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
@@ -675,6 +698,7 @@ def predict(
             output_format,
             output_quality,
             np_images=np_imgs,
+            profile=profile,
         )
 
 
@@ -706,7 +730,7 @@ def predict(
         self.handle_loras(go_fast, lora_weights, lora_scale)
 
         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        imgs, np_imgs, profile = self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
@@ -722,6 +746,7 @@ def predict(
             output_format,
             output_quality,
             np_images=np_imgs,
+            profile=profile,
         )
 
 
@@ -770,7 +795,7 @@ def predict(
         self.handle_loras(go_fast, lora_weights, lora_scale)
 
         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        imgs, np_imgs, profile = self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
@@ -789,6 +814,7 @@ def predict(
             output_format,
             output_quality,
             np_images=np_imgs,
+            profile=profile,
         )
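
For reference, a minimal self-contained sketch of the `record_function` pattern used in the `flux_pipeline.py` hunks: each `with torch.profiler.record_function("label")` block shows up as a named span in the captured profile, so the prepare/denoise/decode phases can be told apart from the raw op-level events. The `toy_step` matmul here is a stand-in for the real model work, not code from this repo.

```python
import torch

def toy_step(x: torch.Tensor) -> torch.Tensor:
    # The label becomes a named span in the profiler output and in any
    # exported trace, wrapping whatever ops run inside the block.
    with torch.profiler.record_function("toy-step"):
        return x @ x

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
) as prof:
    toy_step(torch.randn(256, 256))

# The "toy-step" span appears alongside op-level events such as aten::mm.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```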
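The `predict.py` changes follow the same shape as the sketch below: construct a `torch.profiler.profile` only when the flag is set, fall back to `contextlib.nullcontext()` otherwise, and export a Chrome-format trace after the `with` block exits. `enable_profiling` here is a local stand-in for the new `base_setup` flag, and the matmul stands in for the prediction work; the guard on `torch.cuda.is_available()` is an extra assumption to keep the sketch runnable on CPU-only machines.

```python
import contextlib
from pathlib import Path

import torch

enable_profiling = True  # stand-in for the new base_setup flag

# Profile CPU always; add CUDA activity only when a GPU is present.
activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(torch.profiler.ProfilerActivity.CUDA)

profiler = (
    torch.profiler.profile(activities=activities)
    if enable_profiling
    else contextlib.nullcontext()
)

with profiler:
    x = torch.randn(512, 512)
    y = x @ x  # stand-in for the prepare/denoise/decode work

if isinstance(profiler, torch.profiler.profile):
    # Chrome-format trace; open it in chrome://tracing or https://ui.perfetto.dev
    profiler.export_chrome_trace("chrome-trace.json")
    print(Path("chrome-trace.json").stat().st_size, "bytes")
```

Threading the trace path through `postprocess` via the new `profile` argument means the `chrome-trace.json` file is returned as a prediction output alongside the generated images, so no separate channel is needed to retrieve it.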