
Commit 7f65e1c

Support ControlNet (#153)
* add controlnet tentatively
* add controlnet in python code
* implement swift part
* support 8-bit quantization
* add controlnet unload when reduce memory
* remove irrelevant changes
* add more description about controlnet option in swift
* fix some for pr and update README
* pre-allocate zero shapedArray + make multi-controlnet faster
1 parent d1a6888 commit 7f65e1c

14 files changed, +966 -66 lines

README.md

+14
@@ -139,6 +139,10 @@ This generally takes 15-20 minutes on an M1 MacBook Pro. Upon successful executi
- `--check-output-correctness`: Compares original PyTorch model's outputs to final Core ML model's outputs. This flag increases RAM consumption significantly so it is recommended only for debugging purposes.

- `--convert-controlnet`: Converts the ControlNet models specified after this option. Multiple models can be converted at once, for example `--convert-controlnet lllyasviel/sd-controlnet-mlsd lllyasviel/sd-controlnet-depth` (see the example command after this list).

- `--unet-support-controlnet`: Enables the converted UNet model to receive additional inputs from ControlNet. This is required for image generation with ControlNet, and the resulting model is saved under a different name, `*_control-unet.mlpackage`, to distinguish it from the regular UNet. Note that this UNet model cannot be used without ControlNet; use the regular UNet for plain txt2img.

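For example, a conversion run that also produces a ControlNet model and a ControlNet-capable UNet could look like the following illustrative command (the output directory is a placeholder, and the exact set of flags depends on which models you want to convert):

    python -m python_coreml_stable_diffusion.torch2coreml --convert-unet --unet-support-controlnet --convert-controlnet lllyasviel/sd-controlnet-mlsd -o <output-mlpackages-directory>
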
</details>
143147

144148
## <a name="image-generation-with-python"></a> Image Generation with Python
@@ -157,6 +161,8 @@ Please refer to the help menu for all available arguments: `python -m python_cor
- `--model-version`: If you overrode the default model version while converting models to Core ML, you will need to specify the same model version here.
- `--compute-unit`: Note that the most performant compute unit for this particular implementation may differ across different hardware. `CPU_AND_GPU` or `CPU_AND_NE` may be faster than `ALL`. Please refer to the [Performance Benchmark](#performance-benchmark) section for further guidance.
- `--scheduler`: If you would like to experiment with different schedulers, you may specify one here. For available options, please see the help menu. You may also specify a custom number of inference steps with `--num-inference-steps`, which defaults to 50.
- `--controlnet`: ControlNet models specified with this option are used in image generation. Use this option in the format `--controlnet lllyasviel/sd-controlnet-mlsd lllyasviel/sd-controlnet-depth` and make sure to use `--controlnet-inputs` in conjunction (see the example command after this list).
- `--controlnet-inputs`: Image inputs corresponding to each ControlNet model. Provide the image paths in the same order as the models listed in `--controlnet`, for example: `--controlnet-inputs image_mlsd image_depth`.

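For example, an illustrative generation command using a single ControlNet might look like the following (the prompt, directories, and conditioning image path are placeholders):

    python -m python_coreml_stable_diffusion.pipeline --prompt "a high quality photo of a house" -i <coreml-mlpackages-directory> -o <output-directory> --controlnet lllyasviel/sd-controlnet-mlsd --controlnet-inputs image_mlsd.png
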
</details>
162168

@@ -228,6 +234,14 @@ Optionally, it may also include the safety checker model that some versions of S
- `SafetyChecker.mlmodelc`

Optionally, for ControlNet:

- `ControlledUnet.mlmodelc` or `ControlledUnetChunk1.mlmodelc` & `ControlledUnetChunk2.mlmodelc` (enabled to receive ControlNet values)
- `controlnet/` (directory containing the ControlNet models; the naming convention is sketched after this list)
  - `LllyasvielSdControlnetMlsd.mlmodelc` (for example, from lllyasviel/sd-controlnet-mlsd)
  - `LllyasvielSdControlnetDepth.mlmodelc` (for example, from lllyasviel/sd-controlnet-depth)
  - Other models you converted

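The `.mlmodelc` names above appear to be derived from the Hugging Face model identifier by dropping the separators and capitalizing each word. The following is a small sketch of that presumed mapping; the exact rule used by the Swift package is an assumption inferred from the two examples above, not verified against the implementation.

    import re

    def presumed_controlnet_resource_name(model_version):
        # Presumed mapping, inferred from the README examples and not verified
        # against the Swift implementation:
        # "lllyasviel/sd-controlnet-mlsd" -> "LllyasvielSdControlnetMlsd"
        words = re.split(r"[/\-_.]+", model_version)
        return "".join(word.capitalize() for word in words if word)
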
Note that the chunked version of Unet is checked for first. Only if it is not present will the full `Unet.mlmodelc` be loaded. Chunking is required for iOS and iPadOS and not necessary for macOS.
</details>
python_coreml_stable_diffusion/controlnet.py

+244
@@ -0,0 +1,244 @@
#
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers import ModelMixin

import torch
import torch.nn as nn
import torch.nn.functional as F

from .unet import Timesteps, TimestepEmbedding, get_down_block, UNetMidBlock2DCrossAttn, linear_to_conv2d_map


class ControlNetConditioningEmbedding(nn.Module):

    def __init__(
        self,
        conditioning_embedding_channels,
        conditioning_channels=3,
        block_out_channels=(16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding


class ControlNetModel(ModelMixin, ConfigMixin):

    @register_to_config
    def __init__(
        self,
        in_channels=4,
        flip_sin_to_cos=True,
        freq_shift=0,
        down_block_types=(
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        only_cross_attention=False,
        block_out_channels=(320, 640, 1280, 1280),
        layers_per_block=2,
        downsample_padding=1,
        mid_block_scale_factor=1,
        act_fn="silu",
        norm_num_groups=32,
        norm_eps=1e-5,
        cross_attention_dim=1280,
        attention_head_dim=8,
        use_linear_projection=False,
        upcast_attention=False,
        resnet_time_scale_shift="default",
        conditioning_embedding_out_channels=(16, 32, 96, 256),
        **kwargs,
    ):
        super().__init__()

        # Check inputs
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        self._register_load_state_dict_pre_hook(linear_to_conv2d_map)

        # input
        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        time_embed_dim = block_out_channels[0] * 4

        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
        )

        # control net conditioning embedding
        self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=conditioning_embedding_out_channels,
        )

        self.down_blocks = nn.ModuleList([])
        self.controlnet_down_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]

        controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
        self.controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
            )
            self.down_blocks.append(down_block)

            for _ in range(layers_per_block):
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                self.controlnet_down_blocks.append(controlnet_block)

            if not is_final_block:
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                self.controlnet_down_blocks.append(controlnet_block)

        # mid
        mid_block_channel = block_out_channels[-1]

        controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
        self.controlnet_mid_block = controlnet_block

        self.mid_block = UNetMidBlock2DCrossAttn(
            in_channels=mid_block_channel,
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attention_head_dim[-1],
            resnet_groups=norm_num_groups,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )

    def get_num_residuals(self):
        num_res = 2  # initial sample + mid block
        for down_block in self.down_blocks:
            num_res += len(down_block.resnets)
            if hasattr(down_block, "downsamplers") and down_block.downsamplers is not None:
                num_res += len(down_block.downsamplers)
        return num_res

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states,
        controlnet_cond,
    ):
        # 1. time
        t_emb = self.time_proj(timestep)
        emb = self.time_embedding(t_emb)

        # 2. pre-process
        sample = self.conv_in(sample)

        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)

        sample += controlnet_cond

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
            )

        # 5. Control net blocks
        controlnet_down_block_res_samples = ()

        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples += (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        return down_block_res_samples, mid_block_res_sample
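
For context, the residuals returned above are consumed by a UNet converted with `--unet-support-controlnet`, which adds them to its own skip connections. The following is a minimal illustrative sketch of how a caller might combine the outputs of one or more ControlNets before passing them on; `controlled_unet`, its keyword argument names (borrowed from the diffusers convention), and the surrounding variable names are assumptions, not this repository's exact API.

    def run_controlled_step(controlled_unet, controlnets, sample, timestep,
                            encoder_hidden_states, cond_images):
        # Illustrative sketch: accumulate residuals from each ControlNet, then
        # hand them to a UNet that accepts additional residual inputs.
        combined_down_res = None
        combined_mid_res = None

        for controlnet, cond in zip(controlnets, cond_images):
            down_res, mid_res = controlnet(
                sample, timestep, encoder_hidden_states, controlnet_cond=cond)

            if combined_down_res is None:
                combined_down_res = list(down_res)
                combined_mid_res = mid_res
            else:
                # Multi-ControlNet: residuals are combined by elementwise addition
                combined_down_res = [a + b for a, b in zip(combined_down_res, down_res)]
                combined_mid_res = combined_mid_res + mid_res

        # Keyword argument names follow the diffusers convention and are an
        # assumption here; the ported UNet in this repo may name them differently.
        return controlled_unet(
            sample, timestep, encoder_hidden_states,
            down_block_additional_residuals=combined_down_res,
            mid_block_additional_residual=combined_mid_res,
        )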

python_coreml_stable_diffusion/coreml_model.py

+17
@@ -98,5 +98,22 @@ def _load_mlpackage(submodule_name, mlpackages_dir, model_version,
    return CoreMLModel(mlpackage_path, compute_unit)


def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):
    """ Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)
    """
    model_name = model_version.replace("/", "_")

    logger.info(f"Loading controlnet_{model_name} mlpackage")

    fname = f"ControlNet_{model_name}.mlpackage"

    mlpackage_path = os.path.join(mlpackages_dir, fname)

    if not os.path.exists(mlpackage_path):
        raise FileNotFoundError(
            f"controlnet_{model_name} CoreML model doesn't exist at {mlpackage_path}")

    return CoreMLModel(mlpackage_path, compute_unit)


def get_available_compute_units():
    return tuple(cu for cu in ct.ComputeUnit._member_names_)
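
As a usage note, a caller such as the generation pipeline would presumably load one Core ML ControlNet per model version passed on the command line. A minimal sketch under that assumption follows; `args.controlnet`, `mlpackages_dir`, and `compute_unit` are illustrative names, not the pipeline's actual variables.

    # Illustrative sketch only: load one Core ML ControlNet per requested version,
    # in the same order as the --controlnet (and --controlnet-inputs) arguments.
    controlnet_models = []
    for version in (args.controlnet or []):
        controlnet_models.append(
            _load_mlpackage_controlnet(mlpackages_dir, version, compute_unit))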
