From 3b42a922602562f618497348a0751646251a16d0 Mon Sep 17 00:00:00 2001 From: zsxkib Date: Fri, 26 Apr 2024 15:51:19 +0000 Subject: [PATCH] batching added --- cog/predict.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/cog/predict.py b/cog/predict.py index e839f88b..61f343b0 100644 --- a/cog/predict.py +++ b/cog/predict.py @@ -30,9 +30,6 @@ from transformers import CLIPImageProcessor from controlnet_util import openpose, get_depth_map, get_canny_image -from pipeline_stable_diffusion_xl_instantid_full import ( - StableDiffusionXLInstantIDPipeline, -) from diffusers.pipelines.stable_diffusion.safety_checker import ( StableDiffusionSafetyChecker, ) @@ -137,13 +134,16 @@ "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--protovision-xl-high-fidel.tar", "path": "checkpoints/models--stablediffusionapi--protovision-xl-high-fidel", }, - # TODO: Get safetensors working w LCM (for a future piece of work) - also removed from `choices` list in `sdxl_weights` (predict input param) - # These are non-huggingface models, e.g. .safetensors files - # "RealVisXL_V3.0": { - # "url": "https://weights.replicate.delivery/default/comfy-ui/checkpoints/RealVisXL_V3.0.safetensors.tar", - # "path": "checkpoints/RealVisXL_V3.0", - # "file": "RealVisXL_V3.0.safetensors", - # }, + "RealVisXL_V3.0_Turbo": { + "slug": "SG161222/RealVisXL_V3.0_Turbo", + "url": "https://weights.replicate.delivery/default/InstantID/models--SG161222--RealVisXL_V3.0_Turbo.tar", + "path": "checkpoints/models--SG161222--RealVisXL_V3.0_Turbo", + }, + "RealVisXL_V4.0_Lightning": { + "slug": "SG161222/RealVisXL_V4.0_Lightning", + "url": "https://weights.replicate.delivery/default/InstantID/models--SG161222--RealVisXL_V4.0_Lightning.tar", + "path": "checkpoints/models--SG161222--RealVisXL_V4.0_Lightning", + }, } @@ -388,6 +388,7 @@ def generate_image( scheduler, enable_LCM, enhance_face_region, + num_images_per_prompt, ): if enable_LCM: self.pipe.enable_lora() @@ -502,6 +503,7 @@ def generate_image( height=height, width=width, generator=generator, + num_images_per_prompt=num_images_per_prompt, ).images return images @@ -542,6 +544,8 @@ def predict( "omnigen-xl", "pony-diffusion-v6-xl", "protovision-xl-high-fidel", + "RealVisXL_V3.0_Turbo", + "RealVisXL_V4.0_Lightning", ], ), scheduler: str = Input( @@ -644,6 +648,12 @@ def predict( description="Random seed. Leave blank to randomize the seed", default=None, ), + num_outputs: int = Input( + description="Number of images to output", + default=1, + ge=1, + le=8, + ), disable_safety_checker: bool = Input( description="Disable safety checker for generated images", default=False, @@ -692,6 +702,7 @@ def predict( seed=seed, enable_LCM=enable_lcm, enhance_face_region=enhance_nonface_region, + num_images_per_prompt=num_outputs, ) # Save the generated images and check for NSFW content @@ -705,7 +716,7 @@ def predict( raise Exception( "NSFW content detected. Try running it again, or try a different prompt." ) - + extension = output_format.lower() extension = "jpeg" if extension == "jpg" else extension output_path = f"/tmp/out_{i}.{extension}"