From e1cf6a3bb68fb2698bb750140b67081d5f570855 Mon Sep 17 00:00:00 2001 From: js1234567 Date: Thu, 18 Dec 2025 10:58:07 +0800 Subject: [PATCH 1/5] Bugfix for dreambooth flux2 img2img2 --- examples/dreambooth/train_dreambooth_lora_flux2_img2img.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 0b9b9f993094..f855add44a94 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1621,7 +1621,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device) - cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input).to( + cond_model_input_ids = Flux2Pipeline._prepare_image_ids([cond_model_input]).to( device=cond_model_input.device ) @@ -1650,6 +1650,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input) packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input) + noisy_len = packed_noisy_model_input.shape[1] + # concatenate the model inputs with the cond inputs packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1) model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1) @@ -1668,7 +1670,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): img_ids=model_input_ids, # B, image_seq_len, 4 return_dict=False, )[0] - model_pred = model_pred[:, : packed_noisy_model_input.size(1) :] + model_pred = model_pred[:, : noisy_len:] + model_input_ids = model_input_ids[:, : noisy_len:] model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids) From e3556db670044bfc2cb173a551b5aca8a029732b Mon Sep 17 00:00:00 2001 From: js1234567 Date: Thu, 18 Dec 2025 11:24:30 +0800 Subject: [PATCH 2/5] Bugfix for dreambooth flux2 img2img2 --- examples/dreambooth/train_dreambooth_lora_flux2_img2img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index f855add44a94..42be2cdcdfc2 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1670,8 +1670,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): img_ids=model_input_ids, # B, image_seq_len, 4 return_dict=False, )[0] - model_pred = model_pred[:, : noisy_len:] - model_input_ids = model_input_ids[:, : noisy_len:] + model_pred = model_pred[:, : noisy_len :] + model_input_ids = model_input_ids[:, : noisy_len :] model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids) From 79f704c53e061d9259b13ee5aa7ae319ac803bf8 Mon Sep 17 00:00:00 2001 From: js1234567 Date: Thu, 18 Dec 2025 11:27:26 +0800 Subject: [PATCH 3/5] Bugfix for dreambooth flux2 img2img2 --- examples/dreambooth/train_dreambooth_lora_flux2_img2img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 42be2cdcdfc2..2dc03ad8c05d 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1670,8 +1670,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): img_ids=model_input_ids, # B, image_seq_len, 4 return_dict=False, )[0] - model_pred = model_pred[:, : noisy_len :] - model_input_ids = model_input_ids[:, : noisy_len :] + model_pred = model_pred[:, :noisy_len:] + model_input_ids = model_input_ids[:, :noisy_len:] model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids) From e5229dab89b6b3d5ab5311b6e84a3a2e4a243523 Mon Sep 17 00:00:00 2001 From: js1234567 Date: Thu, 18 Dec 2025 16:48:17 +0800 Subject: [PATCH 4/5] Bugfix for dreambooth flux2 img2img2 --- .../dreambooth/train_dreambooth_lora_flux2_img2img.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 2dc03ad8c05d..e4400506aaeb 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1621,9 +1621,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device) - cond_model_input_ids = Flux2Pipeline._prepare_image_ids([cond_model_input]).to( + cond_model_input_list = [ + cond_model_input[i].unsqueeze(0) + for i in range(cond_model_input.shape[0]) + ] + cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to( device=cond_model_input.device ) + cond_model_input_ids = cond_model_input_ids.view( + cond_model_input.shape[0], + -1, + model_input_ids.shape[-1] + ) # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) From 86da067adb33b6a7e4ae4c433a8c7438a0331134 Mon Sep 17 00:00:00 2001 From: js1234567 Date: Thu, 18 Dec 2025 17:24:29 +0800 Subject: [PATCH 5/5] Bugfix for dreambooth flux2 img2img2 --- .../dreambooth/train_dreambooth_lora_flux2_img2img.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index e4400506aaeb..359e61cf54bc 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1621,17 +1621,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device) - cond_model_input_list = [ - cond_model_input[i].unsqueeze(0) - for i in range(cond_model_input.shape[0]) - ] + cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])] cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to( device=cond_model_input.device ) cond_model_input_ids = cond_model_input_ids.view( - cond_model_input.shape[0], - -1, - model_input_ids.shape[-1] + cond_model_input.shape[0], -1, model_input_ids.shape[-1] ) # Sample noise that we'll add to the latents