Skip to content

Commit 23fce0b

Browse files
stduhpf and leejet authored
feat: add support for Chroma Radiance x0 (#1091)
* Add x0 Flux pred (+ prepare for others)
* Fix convert models with empty tensors
* patch_32 exp support attempt
* improve support for patch_32
* follow official pipeline

---------

Co-authored-by: leejet <leejet714@gmail.com>
1 parent 7c88c47 commit 23fce0b

File tree

3 files changed

+42
-1
lines changed

3 files changed

+42
-1
lines changed

flux.hpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,8 @@ namespace Flux {
744744
int64_t nerf_mlp_ratio = 4;
745745
int64_t nerf_depth = 4;
746746
int64_t nerf_max_freqs = 8;
747+
bool use_x0 = false;
748+
bool use_patch_size_32 = false;
747749
};
748750

749751
struct FluxParams {
@@ -781,7 +783,7 @@ namespace Flux {
781783
Flux(FluxParams params)
782784
: params(params) {
783785
if (params.version == VERSION_CHROMA_RADIANCE) {
784-
std::pair<int, int> kernel_size = {(int)params.patch_size, (int)params.patch_size};
786+
std::pair<int, int> kernel_size = {16, 16};
785787
std::pair<int, int> stride = kernel_size;
786788

787789
blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
@@ -1044,6 +1046,15 @@ namespace Flux {
10441046
return img;
10451047
}
10461048

1049+
struct ggml_tensor* _apply_x0_residual(GGMLRunnerContext* ctx,
1050+
struct ggml_tensor* predicted,
1051+
struct ggml_tensor* noisy,
1052+
struct ggml_tensor* timesteps) {
1053+
auto x = ggml_sub(ctx->ggml_ctx, noisy, predicted);
1054+
x = ggml_div(ctx->ggml_ctx, x, timesteps);
1055+
return x;
1056+
}
1057+
10471058
struct ggml_tensor* forward_chroma_radiance(GGMLRunnerContext* ctx,
10481059
struct ggml_tensor* x,
10491060
struct ggml_tensor* timestep,
@@ -1068,6 +1079,13 @@ namespace Flux {
10681079
auto img = pad_to_patch_size(ctx->ggml_ctx, x);
10691080
auto orig_img = img;
10701081

1082+
if (params.chroma_radiance_params.use_patch_size_32) {
1083+
// It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
1084+
// Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
1085+
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
1086+
img = ggml_interpolate(ctx->ggml_ctx, img, W / 2, H / 2, C, x->ne[3], GGML_SCALE_MODE_BILINEAR);
1087+
}
1088+
10711089
auto img_in_patch = std::dynamic_pointer_cast<Conv2d>(blocks["img_in_patch"]);
10721090

10731091
img = img_in_patch->forward(ctx, img); // [N, hidden_size, H/patch_size, W/patch_size]
@@ -1104,6 +1122,10 @@ namespace Flux {
11041122

11051123
out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W]
11061124

1125+
if (params.chroma_radiance_params.use_x0) {
1126+
out = _apply_x0_residual(ctx, out, orig_img, timestep);
1127+
}
1128+
11071129
return out;
11081130
}
11091131

@@ -1290,6 +1312,15 @@ namespace Flux {
12901312
// not schnell
12911313
flux_params.guidance_embed = true;
12921314
}
1315+
if (tensor_name.find("__x0__") != std::string::npos) {
1316+
LOG_DEBUG("using x0 prediction");
1317+
flux_params.chroma_radiance_params.use_x0 = true;
1318+
}
1319+
if (tensor_name.find("__32x32__") != std::string::npos) {
1320+
LOG_DEBUG("using patch size 32 prediction");
1321+
flux_params.chroma_radiance_params.use_patch_size_32 = true;
1322+
flux_params.patch_size = 32;
1323+
}
12931324
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
12941325
// Chroma
12951326
flux_params.is_chroma = true;

model.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1737,6 +1737,13 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
17371737
// tensor_storage.ne[0], tensor_storage.ne[1], tensor_storage.ne[2], tensor_storage.ne[3],
17381738
// tensor->n_dims, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
17391739

1740+
if (!tensor->data) {
1741+
GGML_ASSERT(ggml_nelements(tensor) == 0);
1742+
// avoid crashing the gguf writer by setting a dummy pointer for zero-sized tensors
1743+
LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str());
1744+
tensor->data = ggml_get_mem_buffer(ggml_ctx);
1745+
}
1746+
17401747
*dst_tensor = tensor;
17411748

17421749
gguf_add_tensor(gguf_ctx, tensor);

stable-diffusion.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,8 @@ class StableDiffusionGGML {
708708
if (stacked_id) {
709709
ignore_tensors.insert("pmid.unet.");
710710
}
711+
ignore_tensors.insert("model.diffusion_model.__x0__");
712+
ignore_tensors.insert("model.diffusion_model.__32x32__");
711713

712714
if (vae_decode_only) {
713715
ignore_tensors.insert("first_stage_model.encoder");
@@ -842,6 +844,7 @@ class StableDiffusionGGML {
842844
}
843845
} else if (sd_version_is_flux(version)) {
844846
pred_type = FLUX_FLOW_PRED;
847+
845848
if (flow_shift == INFINITY) {
846849
flow_shift = 1.0f; // TODO: validate
847850
for (const auto& [name, tensor_storage] : tensor_storage_map) {

0 commit comments

Comments (0)