
Commit 9941b3f

haofanwang, ResearcherXman, and sayakpaul authored
Add InstantID Pipeline (#6673)
* add instantid pipeline
* format
* Update README.md
* Update README.md
* format

Co-authored-by: ResearcherXman <xhs.research@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 16b9f98 commit 9941b3f

File tree

2 files changed: +1130 -2 lines changed

examples/community/README.md

+72 -2
@@ -60,8 +60,8 @@

 | Rerender A Video Pipeline | Implementation of [[SIGGRAPH Asia 2023] Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation](https://arxiv.org/abs/2306.07954) | [Rerender A Video Pipeline](#Rerender_A_Video) | - | [Yifan Zhou](https://github.com/SingleZombie) |
 | StyleAligned Pipeline | Implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133) | [StyleAligned Pipeline](#stylealigned-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/15X2E0jFPTajUIjS0FzX50OaHsCbP2lQ0/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
 | AnimateDiff Image-To-Video Pipeline | Experimental Image-To-Video support for AnimateDiff (open to improvements) | [AnimateDiff Image To Video Pipeline](#animatediff-image-to-video-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/1TvzCDPHhfFtdcJZe4RLloAwyoLKuttWK/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
-
-| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) |
+| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) |
+| InstantID Pipeline | Stable Diffusion XL Pipeline that supports InstantID | [InstantID Pipeline](#instantid-pipeline) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/InstantX/InstantID) | [Haofan Wang](https://github.com/haofanwang) |

 To load a custom pipeline, just pass the `custom_pipeline` argument to `DiffusionPipeline`, pointing at one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines; we will merge them quickly.
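Community files like this one can also be loaded through that `custom_pipeline` argument rather than by importing the file locally. A minimal sketch, assuming the file name `pipeline_stable_diffusion_xl_instantid` added by this commit and a placeholder SDXL base checkpoint; note that the pipeline still expects its ControlNet at construction time, as in the full example below:

```py
import torch
from diffusers import ControlNetModel, DiffusionPipeline

# Sketch only: `custom_pipeline` takes the community file name (without .py);
# both model paths below are placeholders.
controlnet = ControlNetModel.from_pretrained(
    "./checkpoints/ControlNetModel", torch_dtype=torch.float16
)
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed base model
    controlnet=controlnet,
    custom_pipeline="pipeline_stable_diffusion_xl_instantid",
    torch_dtype=torch.float16,
)
```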
@@ -3533,3 +3533,73 @@

### InstantID Pipeline

InstantID is a state-of-the-art, tuning-free method for ID-preserving generation from only a single reference image, supporting various downstream tasks. For any usage question, please refer to the [official implementation](https://github.com/InstantID/InstantID).

```py
# !pip install opencv-python transformers accelerate insightface
import cv2
import torch
import numpy as np

from diffusers.utils import load_image
from diffusers.models import ControlNetModel

from insightface.app import FaceAnalysis
from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps

# prepare 'antelopev2' under ./models
# https://github.com/deepinsight/insightface/issues/1896#issuecomment-1023867304
app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))

# download the IdentityNet ControlNet and the face adapter under ./checkpoints
# https://huggingface.co/InstantX/InstantID
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")

face_adapter = './checkpoints/ip-adapter.bin'
controlnet_path = './checkpoints/ControlNetModel'

# load IdentityNet
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)

base_model = 'wangqixun/YamerMIX_v8'
pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
    base_model,
    controlnet=controlnet,
    torch_dtype=torch.float16
)
pipe.cuda()

# load the face IP-Adapter weights
pipe.load_ip_adapter_instantid(face_adapter)

# load an image containing a face
face_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png")

# prepare the face embedding: insightface expects BGR, so convert from PIL's RGB
face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]  # keep the largest detected face
face_emb = face_info['embedding']
face_kps = draw_kps(face_image, face_info['kps'])

# prompts
prompt = "film noir style, ink sketch|vector, male man, highly detailed, sharp focus, ultra sharpness, monochrome, high contrast, dramatic shadows, 1940s style, mysterious, cinematic"
negative_prompt = "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, vibrant, colorful"

# generate the image
pipe.set_ip_adapter_scale(0.8)
image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    image_embeds=face_emb,
    image=face_kps,
    controlnet_conditioning_scale=0.8,
).images[0]
```
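Here `image` is a standard PIL image, so it can be written straight to disk. A minimal follow-up (the file name is illustrative):

```py
# Persist the generated image; the file name is just an example.
image.save("instantid_result.png")
```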
