@@ -744,6 +744,8 @@ namespace Flux {
744744 int64_t nerf_mlp_ratio = 4 ;
745745 int64_t nerf_depth = 4 ;
746746 int64_t nerf_max_freqs = 8 ;
747+ bool use_x0 = false ;
748+ bool use_patch_size_32 = false ;
747749 };
748750
749751 struct FluxParams {
@@ -781,7 +783,7 @@ namespace Flux {
781783 Flux (FluxParams params)
782784 : params(params) {
783785 if (params.version == VERSION_CHROMA_RADIANCE) {
784- std::pair<int , int > kernel_size = {( int )params. patch_size , ( int )params. patch_size };
786+ std::pair<int , int > kernel_size = {16 , 16 };
785787 std::pair<int , int > stride = kernel_size;
786788
787789 blocks[" img_in_patch" ] = std::make_shared<Conv2d>(params.in_channels ,
@@ -1044,6 +1046,15 @@ namespace Flux {
10441046 return img;
10451047 }
10461048
1049+ struct ggml_tensor * _apply_x0_residual (GGMLRunnerContext* ctx,
1050+ struct ggml_tensor * predicted,
1051+ struct ggml_tensor * noisy,
1052+ struct ggml_tensor * timesteps) {
1053+ auto x = ggml_sub (ctx->ggml_ctx , noisy, predicted);
1054+ x = ggml_div (ctx->ggml_ctx , x, timesteps);
1055+ return x;
1056+ }
1057+
10471058 struct ggml_tensor * forward_chroma_radiance (GGMLRunnerContext* ctx,
10481059 struct ggml_tensor * x,
10491060 struct ggml_tensor * timestep,
@@ -1068,6 +1079,13 @@ namespace Flux {
10681079 auto img = pad_to_patch_size (ctx->ggml_ctx , x);
10691080 auto orig_img = img;
10701081
1082+ if (params.chroma_radiance_params .use_patch_size_32 ) {
1083+ // It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
1084+ // Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
1085+ // img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
1086+ img = ggml_interpolate (ctx->ggml_ctx , img, W / 2 , H / 2 , C, x->ne [3 ], GGML_SCALE_MODE_BILINEAR);
1087+ }
1088+
10711089 auto img_in_patch = std::dynamic_pointer_cast<Conv2d>(blocks[" img_in_patch" ]);
10721090
10731091 img = img_in_patch->forward (ctx, img); // [N, hidden_size, H/patch_size, W/patch_size]
@@ -1104,6 +1122,10 @@ namespace Flux {
11041122
11051123 out = nerf_final_layer_conv->forward (ctx, img_dct); // [N, C, H, W]
11061124
1125+ if (params.chroma_radiance_params .use_x0 ) {
1126+ out = _apply_x0_residual (ctx, out, orig_img, timestep);
1127+ }
1128+
11071129 return out;
11081130 }
11091131
@@ -1290,6 +1312,15 @@ namespace Flux {
12901312 // not schnell
12911313 flux_params.guidance_embed = true ;
12921314 }
1315+ if (tensor_name.find (" __x0__" ) != std::string::npos) {
1316+ LOG_DEBUG (" using x0 prediction" );
1317+ flux_params.chroma_radiance_params .use_x0 = true ;
1318+ }
1319+ if (tensor_name.find (" __32x32__" ) != std::string::npos) {
1320+ LOG_DEBUG (" using patch size 32 prediction" );
1321+ flux_params.chroma_radiance_params .use_patch_size_32 = true ;
1322+ flux_params.patch_size = 32 ;
1323+ }
12931324 if (tensor_name.find (" distilled_guidance_layer.in_proj.weight" ) != std::string::npos) {
12941325 // Chroma
12951326 flux_params.is_chroma = true ;
0 commit comments