Skip to content

Commit 55463f7

Browse files
Z-Image-Turbo ControlNet (#12792)
* init

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent f9c1e61 commit 55463f7

File tree

13 files changed

+2409
-6
lines changed

13 files changed

+2409
-6
lines changed

src/diffusers/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -279,6 +279,7 @@
279279
"WanAnimateTransformer3DModel",
280280
"WanTransformer3DModel",
281281
"WanVACETransformer3DModel",
282+
"ZImageControlNetModel",
282283
"ZImageTransformer2DModel",
283284
"attention_backend",
284285
]
@@ -670,6 +671,8 @@
670671
"WuerstchenCombinedPipeline",
671672
"WuerstchenDecoderPipeline",
672673
"WuerstchenPriorPipeline",
674+
"ZImageControlNetInpaintPipeline",
675+
"ZImageControlNetPipeline",
673676
"ZImageImg2ImgPipeline",
674677
"ZImagePipeline",
675678
]
@@ -1017,6 +1020,7 @@
10171020
WanAnimateTransformer3DModel,
10181021
WanTransformer3DModel,
10191022
WanVACETransformer3DModel,
1023+
ZImageControlNetModel,
10201024
ZImageTransformer2DModel,
10211025
attention_backend,
10221026
)
@@ -1377,6 +1381,8 @@
13771381
WuerstchenCombinedPipeline,
13781382
WuerstchenDecoderPipeline,
13791383
WuerstchenPriorPipeline,
1384+
ZImageControlNetInpaintPipeline,
1385+
ZImageControlNetPipeline,
13801386
ZImageImg2ImgPipeline,
13811387
ZImagePipeline,
13821388
)

src/diffusers/loaders/single_file_model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
convert_stable_cascade_unet_single_file_to_diffusers,
5050
convert_wan_transformer_to_diffusers,
5151
convert_wan_vae_to_diffusers,
52+
convert_z_image_controlnet_checkpoint_to_diffusers,
5253
convert_z_image_transformer_checkpoint_to_diffusers,
5354
create_controlnet_diffusers_config_from_ldm,
5455
create_unet_diffusers_config_from_ldm,
@@ -172,11 +173,18 @@
172173
"checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
173174
"default_subfolder": "transformer",
174175
},
176+
"ZImageControlNetModel": {
177+
"checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers,
178+
},
175179
}
176180

177181

178182
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
179-
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
183+
model_state_dict_keys = set(model_state_dict.keys())
184+
checkpoint_state_dict_keys = set(checkpoint_state_dict.keys())
185+
is_subset = model_state_dict_keys.issubset(checkpoint_state_dict_keys)
186+
is_match = model_state_dict_keys == checkpoint_state_dict_keys
187+
return not (is_subset and is_match)
180188

181189

182190
def _get_single_file_loadable_mapping_class(cls):

src/diffusers/loaders/single_file_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@
121121
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
122122
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
123123
"z-image-turbo": "cap_embedder.0.weight",
124+
"z-image-turbo-controlnet": "control_all_x_embedder.2-1.weight",
125+
"z-image-turbo-controlnet-2.x": "control_layers.14.adaLN_modulation.0.weight",
124126
"sana": [
125127
"blocks.0.cross_attn.q_linear.weight",
126128
"blocks.0.cross_attn.q_linear.bias",
@@ -220,6 +222,8 @@
220222
"cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
221223
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
222224
"z-image-turbo": {"pretrained_model_name_or_path": "Tongyi-MAI/Z-Image-Turbo"},
225+
"z-image-turbo-controlnet": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union"},
226+
"z-image-turbo-controlnet-2.x": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
223227
}
224228

225229
# Use to configure model sample size when original config is provided
@@ -779,6 +783,12 @@ def infer_diffusers_model_type(checkpoint):
779783
else:
780784
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
781785

786+
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet-2.x"] in checkpoint:
787+
model_type = "z-image-turbo-controlnet-2.x"
788+
789+
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint:
790+
model_type = "z-image-turbo-controlnet"
791+
782792
else:
783793
model_type = "v1"
784794

@@ -3885,3 +3895,17 @@ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str)
38853895
handler_fn_inplace(key, converted_state_dict)
38863896

38873897
return converted_state_dict
3898+
3899+
3900+
def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, config, **kwargs):
    """Adapt a Z-Image Turbo ControlNet checkpoint to the given model config.

    Args:
        checkpoint: Mutable mapping of parameter name to value. It is modified
            in place when ``config["add_control_noise_refiner"]`` is
            ``"control_layers"`` (see Returns).
        config: Mapping that must contain the key ``"add_control_noise_refiner"``.
        **kwargs: Accepted and ignored.

    Returns:
        * ``checkpoint`` itself, untouched, when the config value is ``None``
          or ``"control_noise_refiner"``.
        * A new dict holding every entry whose key does not start with
          ``"control_noise_refiner."`` when the config value is
          ``"control_layers"``. Those entries are popped out of
          ``checkpoint``, which afterwards retains only the
          ``"control_noise_refiner."`` entries.

    Raises:
        KeyError: If ``config`` has no ``"add_control_noise_refiner"`` key.
        ValueError: If the config value is none of the values listed above.
    """
    refiner_mode = config["add_control_noise_refiner"]

    # Both of these settings use the checkpoint keys exactly as they are.
    if refiner_mode is None or refiner_mode == "control_noise_refiner":
        return checkpoint

    if refiner_mode == "control_layers":
        # Snapshot the keys first: popping while iterating the live key view
        # would raise "dictionary changed size during iteration".
        return {
            key: checkpoint.pop(key)
            for key in list(checkpoint.keys())
            if not key.startswith("control_noise_refiner.")
        }

    raise ValueError("Unknown Z-Image Turbo ControlNet type.")

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
6767
_import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
6868
_import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
69+
_import_structure["controlnets.controlnet_z_image"] = ["ZImageControlNetModel"]
6970
_import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
7071
_import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"]
7172
_import_structure["embeddings"] = ["ImageProjection"]
@@ -181,6 +182,7 @@
181182
SD3MultiControlNetModel,
182183
SparseControlNetModel,
183184
UNetControlNetXSModel,
185+
ZImageControlNetModel,
184186
)
185187
from .embeddings import ImageProjection
186188
from .modeling_utils import ModelMixin

src/diffusers/models/controlnets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
)
2020
from .controlnet_union import ControlNetUnionModel
2121
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
22+
from .controlnet_z_image import ZImageControlNetModel
2223
from .multicontrolnet import MultiControlNetModel
2324
from .multicontrolnet_union import MultiControlNetUnionModel
2425

0 commit comments

Comments
 (0)