diff --git a/examples/community/README.md b/examples/community/README.md
index f467ee38de3b..652d65f900fe 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -1641,18 +1641,18 @@
 from io import BytesIO
 from PIL import Image
 import torch
 from diffusers import DDIMScheduler
-from diffusers.pipelines.stable_diffusion import StableDiffusionImg2ImgPipeline
+from diffusers import DiffusionPipeline

 # Use the DDIMScheduler scheduler here instead
 scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler")

-pipe = StableDiffusionImg2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
-    custom_pipeline="stable_diffusion_tensorrt_img2img",
-    variant='fp16',
-    torch_dtype=torch.float16,
-    scheduler=scheduler,)
+pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
+    custom_pipeline="stable_diffusion_tensorrt_img2img",
+    variant='fp16',
+    torch_dtype=torch.float16,
+    scheduler=scheduler,)

 # re-use cached folder to save ONNX models and TensorRT Engines
 pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", variant='fp16',)
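For context, the README example above continues by loading an init image and invoking the pipeline. A minimal, illustrative continuation sketch (the URL, prompt, and output filename are placeholders; `requests` must be imported alongside the snippet above):

```python
import requests

# Fetch any RGB image to use as the img2img initialization.
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
init_image = Image.open(BytesIO(requests.get(url).content)).convert("RGB")

# Moving the pipeline to CUDA triggers the ONNX export and TensorRT engine
# build (or reuses the cached folder configured above).
pipe = pipe.to("cuda")

prompt = "photorealistic new zealand hills"
image = pipe(prompt, image=init_image, strength=0.75).images[0]
image.save("tensorrt_img2img_output.png")
```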
diff --git a/examples/community/stable_diffusion_tensorrt_img2img.py b/examples/community/stable_diffusion_tensorrt_img2img.py
index 16a8b803cc29..40ad38bfe903 100755
--- a/examples/community/stable_diffusion_tensorrt_img2img.py
+++ b/examples/community/stable_diffusion_tensorrt_img2img.py
@@ -18,8 +18,7 @@
 import gc
 import os
 from collections import OrderedDict
-from copy import copy
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import onnx
@@ -27,9 +26,11 @@
 import PIL.Image
 import tensorrt as trt
 import torch
+from cuda import cudart
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils import validate_hf_hub_args
 from onnx import shape_inference
+from packaging import version
 from polygraphy import cuda
 from polygraphy.backend.common import bytes_from_path
 from polygraphy.backend.onnx.loader import fold_constants
@@ -41,12 +42,13 @@
     network_from_onnx_path,
     save_engine,
 )
-from polygraphy.backend.trt import util as trt_util
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict, deprecate
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import (
-    StableDiffusionImg2ImgPipeline,
     StableDiffusionPipelineOutput,
     StableDiffusionSafetyChecker,
 )
@@ -58,7 +60,7 @@
 """
 Installation instructions
 python3 -m pip install --upgrade transformers diffusers>=0.16.0
-python3 -m pip install --upgrade tensorrt>=8.6.1
+python3 -m pip install --upgrade tensorrt-cu12==10.2.0
 python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
 python3 -m pip install onnxruntime
 """
@@ -88,10 +90,6 @@
 torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}


-def device_view(t):
-    return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype])
-
-
 def preprocess_image(image):
     """
     image: torch.Tensor
@@ -125,10 +123,8 @@
         onnx_path,
         fp16,
         input_profile=None,
-        enable_preview=False,
         enable_all_tactics=False,
         timing_cache=None,
-        workspace_size=0,
     ):
         logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
         p = Profile()
@@ -137,20 +133,13 @@ def build(
                 assert len(dims) == 3
                 p.add(name, min=dims[0], opt=dims[1], max=dims[2])

-        config_kwargs = {}
-
-        config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
-        if enable_preview:
-            # Faster dynamic shapes made optional since it increases engine build time.
-            config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)
-        if workspace_size > 0:
-            config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
+        extra_build_args = {}
         if not enable_all_tactics:
-            config_kwargs["tactic_sources"] = []
+            extra_build_args["tactic_sources"] = []

         engine = engine_from_network(
             network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
-            config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs),
+            config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args),
             save_timing_cache=timing_cache,
         )
         save_engine(engine, path=self.engine_path)
@@ -163,28 +152,24 @@ def activate(self):
         self.context = self.engine.create_execution_context()

     def allocate_buffers(self, shape_dict=None, device="cuda"):
-        for idx in range(trt_util.get_bindings_per_profile(self.engine)):
-            binding = self.engine[idx]
-            if shape_dict and binding in shape_dict:
-                shape = shape_dict[binding]
+        for binding in range(self.engine.num_io_tensors):
+            name = self.engine.get_tensor_name(binding)
+            if shape_dict and name in shape_dict:
+                shape = shape_dict[name]
             else:
-                shape = self.engine.get_binding_shape(binding)
-            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
-            if self.engine.binding_is_input(binding):
-                self.context.set_binding_shape(idx, shape)
+                shape = self.engine.get_tensor_shape(name)
+            dtype = trt.nptype(self.engine.get_tensor_dtype(name))
+            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
+                self.context.set_input_shape(name, shape)
             tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
-            self.tensors[binding] = tensor
-            self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)
+            self.tensors[name] = tensor

     def infer(self, feed_dict, stream):
-        start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
-        # shallow copy of ordered dict
-        device_buffers = copy(self.buffers)
         for name, buf in feed_dict.items():
-            assert isinstance(buf, cuda.DeviceView)
-            device_buffers[name] = buf
-        bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
-        noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
+            self.tensors[name].copy_(buf)
+        for name, tensor in self.tensors.items():
+            self.context.set_tensor_address(name, tensor.data_ptr())
+        noerror = self.context.execute_async_v3(stream)
         if not noerror:
             raise ValueError("ERROR: inference failed.")
@@ -325,10 +310,8 @@ def build_engines(
     force_engine_rebuild=False,
     static_batch=False,
     static_shape=True,
-    enable_preview=False,
     enable_all_tactics=False,
     timing_cache=None,
-    max_workspace_size=0,
 ):
     built_engines = {}
     if not os.path.isdir(onnx_dir):
@@ -393,9 +376,7 @@
                     static_batch=static_batch,
                     static_shape=static_shape,
                 ),
-                enable_preview=enable_preview,
                 timing_cache=timing_cache,
-                workspace_size=max_workspace_size,
             )
         built_engines[model_name] = engine
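The `Engine` changes above are the heart of this port: TensorRT 10 drops the index-based binding API (`get_binding_shape`, `execute_async_v2` with a bindings list) in favor of named I/O tensors. A standalone sketch of the new calling convention, assuming an already-deserialized `engine` and `context`, contiguous CUDA `torch` tensors in `feed_dict`, and a raw stream handle from `cudart`:

```python
import numpy as np
import tensorrt as trt
import torch

# Minimal numpy -> torch dtype mapping (extend as needed).
NP_TO_TORCH = {np.float32: torch.float32, np.float16: torch.float16, np.int32: torch.int32}

def run_trt10(engine, context, feed_dict, stream):
    """Run one inference with TensorRT 10's name-based tensor API."""
    # 1) Set input shapes first so output shapes can be inferred.
    for i in range(engine.num_io_tensors):  # replaces binding-index iteration
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            context.set_input_shape(name, tuple(feed_dict[name].shape))  # was set_binding_shape
    # 2) Register every tensor address: inputs from feed_dict, outputs freshly allocated.
    outputs = {}
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            tensor = feed_dict[name]
        else:
            dtype = NP_TO_TORCH[trt.nptype(engine.get_tensor_dtype(name))]
            tensor = torch.empty(tuple(context.get_tensor_shape(name)), dtype=dtype, device="cuda")
            outputs[name] = tensor
        context.set_tensor_address(name, tensor.data_ptr())  # replaces the bindings list
    # 3) Enqueue on the given CUDA stream handle (an int from cudart).
    if not context.execute_async_v3(stream):  # replaces execute_async_v2
        raise RuntimeError("TensorRT inference failed")
    return outputs
```

The pipeline's `Engine.infer` above does the same thing, except it copies inputs into the buffers pre-sized by `allocate_buffers` instead of binding caller tensors directly.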
@@ -674,7 +655,7 @@ def make_VAEEncoder(model, device, max_batch_size, embedding_dim, inpaint=False)
     return VAEEncoder(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)


-class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
+class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
     r"""
     Pipeline for image-to-image generation using TensorRT accelerated Stable Diffusion.

@@ -702,6 +683,8 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """

+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+
     def __init__(
         self,
         vae: AutoencoderKL,
@@ -722,24 +705,86 @@ def __init__(
         onnx_dir: str = "onnx",
         # TensorRT engine build parameters
         engine_dir: str = "engine",
-        build_preview_features: bool = True,
         force_engine_rebuild: bool = False,
         timing_cache: str = "timing_cache",
     ):
-        super().__init__(
-            vae,
-            text_encoder,
-            tokenizer,
-            unet,
-            scheduler,
+        super().__init__()
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file"
+            )
+            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if safety_checker is None and requires_safety_checker:
+            logger.warning(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
+        if safety_checker is not None and feature_extractor is None:
+            raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+            )
+
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
-            requires_safety_checker=requires_safety_checker,
         )
-        self.vae.forward = self.vae.decode
-
         self.stages = stages
         self.image_height, self.image_width = image_height, image_width
         self.inpaint = False
@@ -750,7 +795,6 @@ def __init__(
         self.timing_cache = timing_cache
         self.build_static_batch = False
         self.build_dynamic_shape = False
-        self.build_preview_features = build_preview_features
         self.max_batch_size = max_batch_size
         # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
@@ -761,6 +805,11 @@
         self.models = {}  # loaded in __loadModels()
         self.engine = {}  # loaded in build_engines()

+        self.vae.forward = self.vae.decode
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
+
     def __loadModels(self):
         # Load pipeline models
         self.embedding_dim = self.text_encoder.config.hidden_size
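Since the pipeline no longer inherits from `StableDiffusionImg2ImgPipeline`, it now registers its own `VaeImageProcessor`; `vae_scale_factor` is the VAE's total spatial downsampling, derived from its config. A quick illustration of the arithmetic (channel counts are typical Stable Diffusion values, shown only for illustration):

```python
# A Stable Diffusion VAE typically has 4 down blocks, e.g. (128, 256, 512, 512),
# giving 3 stride-2 downsamples, so the image shrinks by 2 ** (4 - 1) = 8.
block_out_channels = (128, 256, 512, 512)
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
assert vae_scale_factor == 8  # a 768x768 image maps to 96x96 latents
```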
@@ -779,6 +828,33 @@ def __loadModels(self):
         if "vae_encoder" in self.stages:
             self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+    def run_safety_checker(
+        self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
+    ) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
+        r"""
+        Runs the safety checker on the given image.
+
+        Args:
+            image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.
+            device (torch.device): The device to run the safety checker on.
+            dtype (torch.dtype): The data type of the input image.
+
+        Returns:
+            (image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and
+            a boolean indicating whether the image has a NSFW (Not Safe for Work) concept.
+        """
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        return image, has_nsfw_concept
+
     @classmethod
     @validate_hf_hub_args
     def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
@@ -826,7 +902,6 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dt
             force_engine_rebuild=self.force_engine_rebuild,
             static_batch=self.build_static_batch,
             static_shape=not self.build_dynamic_shape,
-            enable_preview=self.build_preview_features,
             timing_cache=self.timing_cache,
         )
@@ -850,9 +925,7 @@ def __preprocess_images(self, batch_size, images=()):
         return tuple(init_images)

     def __encode_image(self, init_image):
-        init_latents = runEngine(self.engine["vae_encoder"], {"images": device_view(init_image)}, self.stream)[
-            "latent"
-        ]
+        init_latents = runEngine(self.engine["vae_encoder"], {"images": init_image}, self.stream)["latent"]
         init_latents = 0.18215 * init_latents
         return init_latents
@@ -881,9 +954,8 @@ def __encode_prompt(self, prompt, negative_prompt):
             .to(self.torch_device)
         )

-        text_input_ids_inp = device_view(text_input_ids)
         # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
-        text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[
+        text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids}, self.stream)[
             "text_embeddings"
         ].clone()
@@ -899,8 +971,7 @@ def __encode_prompt(self, prompt, negative_prompt):
             .input_ids.type(torch.int32)
             .to(self.torch_device)
         )
-        uncond_input_ids_inp = device_view(uncond_input_ids)
-        uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[
+        uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids}, self.stream)[
             "text_embeddings"
         ]
@@ -924,18 +995,15 @@ def __denoise_latent(
             # Predict the noise residual
             timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep

-            sample_inp = device_view(latent_model_input)
-            timestep_inp = device_view(timestep_float)
-            embeddings_inp = device_view(text_embeddings)
             noise_pred = runEngine(
                 self.engine["unet"],
-                {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp},
+                {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings},
                 self.stream,
             )["latent"]

             # Perform guidance
             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+            noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)

             latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
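With `device_view` gone, every `runEngine` call site above now passes plain CUDA `torch.Tensor`s. The guidance step itself is unchanged standard classifier-free guidance: the UNet runs once on a batched `[uncond, cond]` input and the two halves are recombined. A tiny self-contained illustration with dummy tensors:

```python
import torch

guidance_scale = 7.5
# Batched UNet output: row 0 unconditional, row 1 text-conditioned.
noise_pred = torch.randn(2, 4, 96, 96)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
# Push the prediction along the text-conditioned direction, scaled by guidance_scale.
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert guided.shape == (1, 4, 96, 96)
```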
{"latent": device_view(latents)}, self.stream)["images"] + images = runEngine(self.engine["vae"], {"latent": latents}, self.stream)["images"] images = (images / 2 + 0.5).clamp(0, 1) return images.cpu().permute(0, 2, 3, 1).float().numpy() def __loadResources(self, image_height, image_width, batch_size): - self.stream = cuda.Stream() + self.stream = cudart.cudaStreamCreate()[1] # Allocate buffers for TensorRT engine bindings for model_name, obj in self.models.items(): @@ -1061,5 +1129,6 @@ def __call__( # VAE decode latent images = self.__decode_latent(latents) + images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype) images = self.numpy_to_pil(images) - return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=None) + return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)