@Ednaordinary commented Dec 17, 2025

What does this PR do?

Adds support for Chroma Radiance v0.4 and x0 (not v0.3). Note that a large portion of this is adapted from the flow repo and the ComfyUI PR.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6 @yiyixuxu

Ednaordinary commented Dec 17, 2025

This is currently close to working, but not quite there. I'm unsure whether the issue is in the scheduler/transformer config (which I've adjusted) or in my code.

Output (x0): [image: chroma0]

Reference with the same settings: [image]

scheduler config

{
  "_class_name": "FlowMatchHeunDiscreteScheduler",
  "_diffusers_version": "0.36.0.dev0",
  "base_image_seq_len": 256,
  "base_shift": 0.5,
  "invert_sigmas": false,
  "max_image_seq_len": 4096,
  "max_shift": 1.15,
  "num_train_timesteps": 1000,
  "shift": 3.0,
  "shift_terminal": null,
  "stochastic_sampling": false,
  "time_shift_type": "exponential",
  "use_beta_sigmas": true,
  "use_dynamic_shifting": true,
  "use_exponential_sigmas": false,
  "use_karras_sigmas": false
}
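
For reference, the config above can be loaded back in with diffusers' standard from_config. A minimal sketch, assuming the JSON is saved locally as scheduler_config.json (the filename is illustrative):

import json

from diffusers import FlowMatchHeunDiscreteScheduler

# Load the scheduler config shown above (local filename is hypothetical).
with open("scheduler_config.json") as f:
    scheduler_config = json.load(f)

# ConfigMixin.from_config ignores any keys the scheduler's __init__ does not
# accept, so extra shifting-related keys are only warned about, not fatal.
scheduler = FlowMatchHeunDiscreteScheduler.from_config(scheduler_config)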

transformer config

{
  "_class_name": "ChromaRadianceTransformer2DModel",
  "_diffusers_version": "0.36.0.dev0",
  "approximator_hidden_dim": 5120,
  "approximator_layers": 5,
  "approximator_num_channels": 64,
  "attention_head_dim": 128,
  "axes_dims_rope": [
    16,
    56,
    56
  ],
  "guidance_embeds": false,
  "in_channels": 3,
  "joint_attention_dim": 4096,
  "nerf_hidden_dim": 64,
  "nerf_layers": 4,
  "nerf_max_freqs": 8,
  "nerf_mlp_ratio": 4,
  "num_attention_heads": 24,
  "num_layers": 19,
  "num_single_layers": 38,
  "out_channels": null,
  "patch_size": 16,
  "pooled_projection_dim": 768,
  "x0": true
}
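
The same pattern works for the transformer config. A minimal sketch, assuming ChromaRadianceTransformer2DModel follows the usual diffusers ModelMixin/ConfigMixin pattern and the JSON is saved locally as transformer_config.json (the filename is illustrative):

import json

from diffusers import ChromaRadianceTransformer2DModel

# Build a randomly initialized transformer from the config above; handy for
# sanity-checking shapes and the x0/patch_size settings without loading weights.
with open("transformer_config.json") as f:
    transformer_config = json.load(f)

transformer = ChromaRadianceTransformer2DModel.from_config(transformer_config)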

save script (x0)

import torch

from diffusers import ChromaRadianceTransformer2DModel
from diffusers import ChromaRadiancePipeline

from transformers import T5EncoderModel, T5Tokenizer

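# Load the Chroma Radiance x0 checkpoint from a single safetensors file;
# in_channels=3 and patch_size=16 match the transformer config above.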
chroma_transformer = ChromaRadianceTransformer2DModel.from_single_file("latest_x0.safetensors", torch_dtype=torch.bfloat16, in_channels=3, patch_size=16)

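# Reuse the T5 text encoder and tokenizer shipped with FLUX.1-schnell.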
text_encoder = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2", torch_dtype=torch.bfloat16)
tokenizer = T5Tokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
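# Assemble the pipeline on top of FLUX.1-dev, swapping in the Chroma Radiance
# transformer and the T5 components, then save it in diffusers format below.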
pipe = ChromaRadiancePipeline.from_pretrained("black-forest-labs/FLUX.1-dev", transformer=chroma_transformer, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=torch.bfloat16)

pipe.save_pretrained("Chroma1-Radiance-x0")
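
For completeness, inference from the saved pipeline would look roughly like this. A minimal sketch, assuming ChromaRadiancePipeline follows the standard diffusers text-to-image call signature (prompt, num_inference_steps, guidance_scale, generator) and returns .images; the prompt and settings below are placeholders, not the exact ones used for the images above:

import torch

from diffusers import ChromaRadiancePipeline

pipe = ChromaRadiancePipeline.from_pretrained(
    "Chroma1-Radiance-x0", torch_dtype=torch.bfloat16
).to("cuda")

# Placeholder prompt and settings, not the exact ones used for the comparison.
image = pipe(
    "a photo of a cat holding a sign that says hello world",
    num_inference_steps=30,
    guidance_scale=4.0,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("chroma_radiance_x0_test.png")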

@Ednaordinary commented

[image: chroma0]
