Skip to content

Commit

Permalink
batching added
Browse files Browse the repository at this point in the history
  • Loading branch information
zsxkib committed Apr 26, 2024
1 parent dd7a750 commit 3b42a92
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions cog/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@
from transformers import CLIPImageProcessor
from controlnet_util import openpose, get_depth_map, get_canny_image

from pipeline_stable_diffusion_xl_instantid_full import (
StableDiffusionXLInstantIDPipeline,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
Expand Down Expand Up @@ -137,13 +134,16 @@
"url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--protovision-xl-high-fidel.tar",
"path": "checkpoints/models--stablediffusionapi--protovision-xl-high-fidel",
},
# TODO: Get safetensors working w LCM (for a future piece of work) - also removed from `choices` list in `sdxl_weights` (predict input param)
# These are non-huggingface models, e.g. .safetensors files
# "RealVisXL_V3.0": {
# "url": "https://weights.replicate.delivery/default/comfy-ui/checkpoints/RealVisXL_V3.0.safetensors.tar",
# "path": "checkpoints/RealVisXL_V3.0",
# "file": "RealVisXL_V3.0.safetensors",
# },
"RealVisXL_V3.0_Turbo": {
"slug": "SG161222/RealVisXL_V3.0_Turbo",
"url": "https://weights.replicate.delivery/default/InstantID/models--SG161222--RealVisXL_V3.0_Turbo.tar",
"path": "checkpoints/models--SG161222--RealVisXL_V3.0_Turbo",
},
"RealVisXL_V4.0_Lightning": {
"slug": "SG161222/RealVisXL_V4.0_Lightning",
"url": "https://weights.replicate.delivery/default/InstantID/models--SG161222--RealVisXL_V4.0_Lightning.tar",
"path": "checkpoints/models--SG161222--RealVisXL_V4.0_Lightning",
},
}


Expand Down Expand Up @@ -388,6 +388,7 @@ def generate_image(
scheduler,
enable_LCM,
enhance_face_region,
num_images_per_prompt,
):
if enable_LCM:
self.pipe.enable_lora()
Expand Down Expand Up @@ -502,6 +503,7 @@ def generate_image(
height=height,
width=width,
generator=generator,
num_images_per_prompt=num_images_per_prompt,
).images

return images
Expand Down Expand Up @@ -542,6 +544,8 @@ def predict(
"omnigen-xl",
"pony-diffusion-v6-xl",
"protovision-xl-high-fidel",
"RealVisXL_V3.0_Turbo",
"RealVisXL_V4.0_Lightning",
],
),
scheduler: str = Input(
Expand Down Expand Up @@ -644,6 +648,12 @@ def predict(
description="Random seed. Leave blank to randomize the seed",
default=None,
),
num_outputs: int = Input(
description="Number of images to output",
default=1,
ge=1,
le=8,
),
disable_safety_checker: bool = Input(
description="Disable safety checker for generated images",
default=False,
Expand Down Expand Up @@ -692,6 +702,7 @@ def predict(
seed=seed,
enable_LCM=enable_lcm,
enhance_face_region=enhance_nonface_region,
num_images_per_prompt=num_outputs,
)

# Save the generated images and check for NSFW content
Expand All @@ -705,7 +716,7 @@ def predict(
raise Exception(
"NSFW content detected. Try running it again, or try a different prompt."
)

extension = output_format.lower()
extension = "jpeg" if extension == "jpg" else extension
output_path = f"/tmp/out_{i}.{extension}"
Expand Down

0 comments on commit 3b42a92

Please # to comment.