4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -206,6 +206,7 @@
"BriaTransformer2DModel",
"CacheMixin",
"ChromaTransformer2DModel",
"ChromaRadianceTransformer2DModel",
"ChronoEditTransformer3DModel",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
@@ -451,6 +452,7 @@
"BriaFiboPipeline",
"BriaPipeline",
"ChromaImg2ImgPipeline",
"ChromaRadiancePipeline",
"ChromaPipeline",
"ChronoEditPipeline",
"CLIPImageProjection",
@@ -943,6 +945,7 @@
BriaFiboTransformer2DModel,
BriaTransformer2DModel,
CacheMixin,
ChromaRadianceTransformer2DModel,
ChromaTransformer2DModel,
ChronoEditTransformer3DModel,
CogVideoXTransformer3DModel,
@@ -1160,6 +1163,7 @@
BriaPipeline,
ChromaImg2ImgPipeline,
ChromaPipeline,
ChromaRadiancePipeline,
ChronoEditPipeline,
CLIPImageProjection,
CogVideoXFunControlPipeline,
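With these four additions, the new classes are exposed from the top-level `diffusers` namespace. A minimal import check, assuming a build of diffusers that includes this PR:

```python
# Minimal import check; assumes diffusers is installed from this PR's branch.
from diffusers import ChromaRadiancePipeline, ChromaRadianceTransformer2DModel

print(ChromaRadianceTransformer2DModel.__name__)
print(ChromaRadiancePipeline.__name__)
```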
1 change: 1 addition & 0 deletions src/diffusers/loaders/peft.py
@@ -63,6 +63,7 @@
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
"ChromaRadianceTransformer2DModel": lambda model_cls, weights: weights,
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
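The new entry mirrors the existing `ChromaTransformer2DModel` one: the registered callable is the identity, so LoRA state dicts targeting the new transformer are handed to PEFT without any key remapping. A small sketch of what that callable does (the LoRA key below is illustrative only):

```python
# Same behavior as the lambda registered above: weights pass through untouched.
def identity_map(model_cls, weights):
    return weights

lora_weights = {"transformer_blocks.0.attn.to_q.lora_A.weight": None}  # illustrative key
assert identity_map(None, lora_weights) is lora_weights
```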
5 changes: 5 additions & 0 deletions src/diffusers/loaders/single_file_model.py
@@ -31,6 +31,7 @@
convert_animatediff_checkpoint_to_diffusers,
convert_auraflow_transformer_checkpoint_to_diffusers,
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_chroma_radiance_transformer_checkpoint_to_diffusers,
convert_chroma_transformer_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_cosmos_transformer_checkpoint_to_diffusers,
@@ -111,6 +112,10 @@
"checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"ChromaRadianceTransformer2DModel": {
"checkpoint_mapping_fn": convert_chroma_radiance_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"LTXVideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
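Registering the mapping function here is what lets `from_single_file` route Chroma Radiance checkpoints through the new converter. A hedged usage sketch; the checkpoint path is hypothetical, since (per the comment in `single_file_utils.py` below) no official repo exists yet:

```python
import torch

from diffusers import ChromaRadianceTransformer2DModel

# Hypothetical path; replace with a real Chroma Radiance single-file checkpoint once one is published.
transformer = ChromaRadianceTransformer2DModel.from_single_file(
    "chroma-radiance.safetensors",
    torch_dtype=torch.bfloat16,
)
```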
197 changes: 196 additions & 1 deletion src/diffusers/loaders/single_file_utils.py
@@ -192,6 +192,7 @@
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
"flux-2-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev"},
"chroma": {"pretrained_model_name_or_path": "lodestones/Chroma1-HD"},
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
"ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
@@ -664,7 +665,13 @@ def infer_diffusers_model_type(checkpoint):
model_type = "flux-2-dev"

elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
if any(
if any(c in checkpoint for c in ["distilled_guidance_layer.in_proj.bias"]):
# Should be updated once a repo exists
# if any(h in checkpoint for h in ["nerf_blocks.0.param_generator.bias"]):
# model_type = "chroma-radiance"
# else:
model_type = "chroma"
elif any(
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
):
if "model.diffusion_model.img_in.weight" in checkpoint:
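For reference, the commented-out branch above sketches how the two Chroma variants would eventually be told apart: both checkpoints carry `distilled_guidance_layer.in_proj.bias`, but only Radiance also has NeRF head keys. A sketch of that intended logic (not enabled in this PR, and "chroma-radiance" is not yet a registered pipeline-path key):

```python
# Sketch only; mirrors the commented-out branch above.
def _infer_chroma_variant(checkpoint):
    if "nerf_blocks.0.param_generator.bias" in checkpoint:
        return "chroma-radiance"
    return "chroma"
```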
@@ -3556,6 +3563,194 @@ def swap_scale_shift(weight):
return converted_state_dict


def convert_chroma_radiance_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())

for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

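# Infer block counts from the numbered checkpoint keys (double_blocks.N., single_blocks.N., distilled_guidance_layer.layers.N., nerf_blocks.N.).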
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
num_guidance_layers = (
list(set(int(k.split(".", 3)[2]) for k in checkpoint if "distilled_guidance_layer.layers." in k))[-1] + 1 # noqa: C401
)
num_nerf_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "nerf_blocks." in k))[-1] + 1 # noqa: C401
mlp_ratio = 4.0
inner_dim = 3072

# In the original SD3 implementation of AdaLayerNormContinuous, the linear projection output is split into (shift, scale),
# while diffusers splits it into (scale, shift). Swap the projection weights so the diffusers implementation can be used.
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight

# guidance
converted_state_dict["distilled_guidance_layer.in_proj.bias"] = checkpoint.pop(
"distilled_guidance_layer.in_proj.bias"
)
converted_state_dict["distilled_guidance_layer.in_proj.weight"] = checkpoint.pop(
"distilled_guidance_layer.in_proj.weight"
)
converted_state_dict["distilled_guidance_layer.out_proj.bias"] = checkpoint.pop(
"distilled_guidance_layer.out_proj.bias"
)
converted_state_dict["distilled_guidance_layer.out_proj.weight"] = checkpoint.pop(
"distilled_guidance_layer.out_proj.weight"
)
for i in range(num_guidance_layers):
block_prefix = f"distilled_guidance_layer.layers.{i}."
converted_state_dict[f"{block_prefix}linear_1.bias"] = checkpoint.pop(
f"distilled_guidance_layer.layers.{i}.in_layer.bias"
)
converted_state_dict[f"{block_prefix}linear_1.weight"] = checkpoint.pop(
f"distilled_guidance_layer.layers.{i}.in_layer.weight"
)
converted_state_dict[f"{block_prefix}linear_2.bias"] = checkpoint.pop(
f"distilled_guidance_layer.layers.{i}.out_layer.bias"
)
converted_state_dict[f"{block_prefix}linear_2.weight"] = checkpoint.pop(
f"distilled_guidance_layer.layers.{i}.out_layer.weight"
)
converted_state_dict[f"distilled_guidance_layer.norms.{i}.weight"] = checkpoint.pop(
f"distilled_guidance_layer.norms.{i}.scale"
)

# context_embedder
converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")

# double transformer blocks
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."
# Q, K, V
sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
context_q, context_k, context_v = torch.chunk(
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
)
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
)
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
# qk_norm
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.norm.key_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
)
# ff img_mlp
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_mlp.0.weight"
)
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.0.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.0.bias"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.2.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.2.bias"
)
# output projections.
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.proj.bias"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.proj.bias"
)

# single transformer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."
# Q, K, V, mlp
mlp_hidden_dim = int(inner_dim * mlp_ratio)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
q_bias, k_bias, v_bias, mlp_bias = torch.split(
checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
# qk norm
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
f"single_blocks.{i}.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
f"single_blocks.{i}.norm.key_norm.scale"
)
# output projections.
converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")

# nerf head: image embedder, final conv layer, and per-block parameters
converted_state_dict["nerf.nerf_embedder.embedder.0.bias"] = checkpoint.pop("nerf_image_embedder.embedder.0.bias")
converted_state_dict["nerf.nerf_embedder.embedder.0.weight"] = checkpoint.pop(
"nerf_image_embedder.embedder.0.weight"
)
converted_state_dict["nerf.final_layer.conv.bias"] = checkpoint.pop("nerf_final_layer_conv.conv.bias")
converted_state_dict["nerf.final_layer.conv.weight"] = checkpoint.pop("nerf_final_layer_conv.conv.weight")
converted_state_dict["nerf.final_layer.norm.weight"] = checkpoint.pop("nerf_final_layer_conv.norm.scale")

for i in range(num_nerf_layers):
block_prefix = f"nerf.blocks.{i}."
converted_state_dict[f"{block_prefix}norm.weight"] = checkpoint.pop(f"nerf_blocks.{i}.norm.scale")
converted_state_dict[f"{block_prefix}param_generator.bias"] = checkpoint.pop(
f"nerf_blocks.{i}.param_generator.bias"
)
converted_state_dict[f"{block_prefix}param_generator.weight"] = checkpoint.pop(
f"nerf_blocks.{i}.param_generator.weight"
)

# patch input embedder
converted_state_dict["x_embedder_patch.bias"] = checkpoint.pop("img_in_patch.bias")
converted_state_dict["x_embedder_patch.weight"] = checkpoint.pop("img_in_patch.weight")

return converted_state_dict


def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}

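A hedged sketch of exercising the new converter directly on a raw state dict; the file name is hypothetical, and in normal use the function is called for you by `from_single_file` via the mapping registered in `single_file_model.py`:

```python
from safetensors.torch import load_file

from diffusers.loaders.single_file_utils import (
    convert_chroma_radiance_transformer_checkpoint_to_diffusers,
)

# Hypothetical local file; no official Chroma Radiance single-file repo exists yet.
state_dict = load_file("chroma-radiance.safetensors")
converted = convert_chroma_radiance_transformer_checkpoint_to_diffusers(state_dict)
print(list(converted)[:5])  # diffusers-style keys, e.g. "distilled_guidance_layer.in_proj.bias"
```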
6 changes: 5 additions & 1 deletion src/diffusers/models/__init__.py
@@ -87,7 +87,10 @@
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
_import_structure["transformers.transformer_chroma"] = [
"ChromaTransformer2DModel",
"ChromaRadianceTransformer2DModel",
]
_import_structure["transformers.transformer_chronoedit"] = ["ChronoEditTransformer3DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
@@ -189,6 +192,7 @@
AuraFlowTransformer2DModel,
BriaFiboTransformer2DModel,
BriaTransformer2DModel,
ChromaRadianceTransformer2DModel,
ChromaTransformer2DModel,
ChronoEditTransformer3DModel,
CogVideoXTransformer3DModel,