From c2d5309c7df31ece0a2e97633c18e8291188e612 Mon Sep 17 00:00:00 2001 From: Pscgylotti Date: Thu, 18 Dec 2025 23:42:36 +0800 Subject: [PATCH 1/5] Feature: Add Mambo-G Guidance to Qwen-Image Pipeline --- .../pipelines/qwenimage/pipeline_qwenimage.py | 39 ++++++++++++++++-- .../pipeline_qwenimage_controlnet.py | 41 ++++++++++++++++--- .../pipeline_qwenimage_controlnet_inpaint.py | 38 +++++++++++++++-- .../qwenimage/pipeline_qwenimage_edit.py | 41 ++++++++++++++++--- .../pipeline_qwenimage_edit_inpaint.py | 41 ++++++++++++++++--- .../qwenimage/pipeline_qwenimage_edit_plus.py | 41 ++++++++++++++++--- .../qwenimage/pipeline_qwenimage_img2img.py | 41 ++++++++++++++++--- .../qwenimage/pipeline_qwenimage_inpaint.py | 41 ++++++++++++++++--- 8 files changed, 286 insertions(+), 37 deletions(-) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 33dc2039b986..1c9fcfea4fc0 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -473,6 +473,8 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + cfg_type: Optional[str] = "original", + cfg_kwargs: Optional[Dict[str, Any]] = {}, ): r""" Function invoked when calling the pipeline for generation. @@ -548,6 +550,15 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + cfg_type (`str`, *optional*, defaults to `"original"`): + The specified classifier-free-guidance (CFG) type for inference. + `"original"` for the original cfg method in Classifier-Free Diffusion Guidance. + `"mambo_g"` for the optional cfg method in MAMBO-G: Magnitude-Aware Mitigation for Boosted Guidance. + cfg_kwargs (`dict`, *optional*): + A kwargs dictionary for additional cfg hyperparameters. + For `"mambo_g"` + - alpha : float, Decaying weight for MAMBO-G, defaults to `8.0` + - cfg_rescale : bool, Whether to use cfg rescaling, defaults to `True`. Examples: @@ -621,6 +632,11 @@ def __call__( max_sequence_length=max_sequence_length, ) + if cfg_type != "original" and not do_true_cfg: + logger.warning( + f"cfg_type is passed as {cfg_type}, but classifier-free guidance is not enabled since do_true_cfg is passed as False." + ) + # 4. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 latents = self.prepare_latents( @@ -713,11 +729,26 @@ def __call__( attention_kwargs=self.attention_kwargs, return_dict=False, )[0] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) + if cfg_type == "original": + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + elif cfg_type == "mambo_g": + mambo_g_alpha = cfg_kwargs.get("alpha", 8.0) + ratio = torch.norm(noise_pred - neg_noise_pred) / torch.norm(neg_noise_pred) + guidance_scale = 1 + (true_cfg_scale - 1.0) * torch.exp(- mambo_g_alpha * ratio) + comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + apply_cfg_rescale = cfg_kwargs.get("cfg_rescale", True) + if apply_cfg_rescale: + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + raise ValueError( + f"cfg_type given as {cfg_type} must be one of `original` or `mambo_g` for the QwenImagePipeline." + ) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py index 5111096d93c1..06784fada4f0 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -577,6 +577,8 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + cfg_type: Optional[str] = "original", + cfg_kwargs: Optional[Dict[str, Any]] = {}, ): r""" Function invoked when calling the pipeline for generation. @@ -652,6 +654,15 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + cfg_type (`str`, *optional*, defaults to `"original"`): + The specified classifier-free-guidance (CFG) type for inference. + `"original"` for the original cfg method in Classifier-Free Diffusion Guidance. + `"mambo_g"` for the optional cfg method in MAMBO-G: Magnitude-Aware Mitigation for Boosted Guidance. + cfg_kwargs (`dict`, *optional*): + A kwargs dictionary for additional cfg hyperparameters. + For `"mambo_g"` + - alpha : float, Decaying weight for MAMBO-G, defaults to `8.0` + - cfg_rescale : bool, Whether to use cfg rescaling, defaults to `True`. Examples: @@ -736,6 +747,11 @@ def __call__( max_sequence_length=max_sequence_length, ) + if cfg_type != "original" and not do_true_cfg: + logger.warning( + f"cfg_type is passed as {cfg_type}, but classifier-free guidance is not enabled since do_true_cfg is passed as False." + ) + # 3. 
Prepare control image num_channels_latents = self.transformer.config.in_channels // 4 if isinstance(self.controlnet, QwenImageControlNetModel): @@ -940,11 +956,26 @@ def __call__( attention_kwargs=self.attention_kwargs, return_dict=False, )[0] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) + + if cfg_type == "original": + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + elif cfg_type == "mambo_g": + mambo_g_alpha = cfg_kwargs.get("alpha", 8.0) + ratio = torch.norm(noise_pred - neg_noise_pred) / torch.norm(neg_noise_pred) + guidance_scale = 1 + (true_cfg_scale - 1.0) * torch.exp(- mambo_g_alpha * ratio) + comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + apply_cfg_rescale = cfg_kwargs.get("cfg_rescale", True) + if apply_cfg_rescale: + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + raise ValueError( + f"cfg_type given as {cfg_type} must be one of `original` or `mambo_g` for the QwenImagePipeline." + ) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py index 102a813ab582..5d21cb5bbee7 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -618,6 +618,8 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + cfg_type: Optional[str] = "original", + cfg_kwargs: Optional[Dict[str, Any]] = {}, ): r""" Function invoked when calling the pipeline for generation. @@ -684,6 +686,15 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + cfg_type (`str`, *optional*, defaults to `"original"`): + The specified classifier-free-guidance (CFG) type for inference. + `"original"` for the original cfg method in Classifier-Free Diffusion Guidance. + `"mambo_g"` for the optional cfg method in MAMBO-G: Magnitude-Aware Mitigation for Boosted Guidance. + cfg_kwargs (`dict`, *optional*): + A kwargs dictionary for additional cfg hyperparameters. + For `"mambo_g"` + - alpha : float, Decaying weight for MAMBO-G, defaults to `8.0` + - cfg_rescale : bool, Whether to use cfg rescaling, defaults to `True`. Examples: @@ -757,6 +768,11 @@ def __call__( num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) + + if cfg_type != "original" and not do_true_cfg: + logger.warning( + f"cfg_type is passed as {cfg_type}, but classifier-free guidance is not enabled since do_true_cfg is passed as False." + ) # 3. 
Prepare control image num_channels_latents = self.transformer.config.in_channels // 4 @@ -885,9 +901,25 @@ def __call__( )[0] comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) + if cfg_type == "original": + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + elif cfg_type == "mambo_g": + mambo_g_alpha = cfg_kwargs.get("alpha", 8.0) + ratio = torch.norm(noise_pred - neg_noise_pred) / torch.norm(neg_noise_pred) + guidance_scale = 1 + (true_cfg_scale - 1.0) * torch.exp(- mambo_g_alpha * ratio) + comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + apply_cfg_rescale = cfg_kwargs.get("cfg_rescale", True) + if apply_cfg_rescale: + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + raise ValueError( + f"cfg_type given as {cfg_type} must be one of `original` or `mambo_g` for the QwenImagePipeline." + ) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index ed37b238c8c9..2bd2c6dadd66 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -569,6 +569,8 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + cfg_type: Optional[str] = "original", + cfg_kwargs: Optional[Dict[str, Any]] = {}, ): r""" Function invoked when calling the pipeline for generation. @@ -650,6 +652,15 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + cfg_type (`str`, *optional*, defaults to `"original"`): + The specified classifier-free-guidance (CFG) type for inference. + `"original"` for the original cfg method in Classifier-Free Diffusion Guidance. + `"mambo_g"` for the optional cfg method in MAMBO-G: Magnitude-Aware Mitigation for Boosted Guidance. + cfg_kwargs (`dict`, *optional*): + A kwargs dictionary for additional cfg hyperparameters. + For `"mambo_g"` + - alpha : float, Decaying weight for MAMBO-G, defaults to `8.0` + - cfg_rescale : bool, Whether to use cfg rescaling, defaults to `True`. Examples: @@ -736,6 +747,11 @@ def __call__( max_sequence_length=max_sequence_length, ) + if cfg_type != "original" and not do_true_cfg: + logger.warning( + f"cfg_type is passed as {cfg_type}, but classifier-free guidance is not enabled since do_true_cfg is passed as False." + ) + # 4. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 latents, image_latents = self.prepare_latents( @@ -841,11 +857,26 @@ def __call__( return_dict=False, )[0] neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) + + if cfg_type == "original": + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + elif cfg_type == "mambo_g": + mambo_g_alpha = cfg_kwargs.get("alpha", 8.0) + ratio = torch.norm(noise_pred - neg_noise_pred) / torch.norm(neg_noise_pred) + guidance_scale = 1 + (true_cfg_scale - 1.0) * torch.exp(- mambo_g_alpha * ratio) + comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + apply_cfg_rescale = cfg_kwargs.get("cfg_rescale", True) + if apply_cfg_rescale: + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + raise ValueError( + f"cfg_type given as {cfg_type} must be one of `original` or `mambo_g` for the QwenImagePipeline." + ) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index d54d1881fa4e..425418fb480d 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -706,6 +706,8 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + cfg_type: Optional[str] = "original", + cfg_kwargs: Optional[Dict[str, Any]] = {}, ): r""" Function invoked when calling the pipeline for generation. @@ -810,6 +812,15 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + cfg_type (`str`, *optional*, defaults to `"original"`): + The specified classifier-free-guidance (CFG) type for inference. + `"original"` for the original cfg method in Classifier-Free Diffusion Guidance. + `"mambo_g"` for the optional cfg method in MAMBO-G: Magnitude-Aware Mitigation for Boosted Guidance. + cfg_kwargs (`dict`, *optional*): + A kwargs dictionary for additional cfg hyperparameters. + For `"mambo_g"` + - alpha : float, Decaying weight for MAMBO-G, defaults to `8.0` + - cfg_rescale : bool, Whether to use cfg rescaling, defaults to `True`. Examples: @@ -917,6 +928,11 @@ def __call__( max_sequence_length=max_sequence_length, ) + if cfg_type != "original" and not do_true_cfg: + logger.warning( + f"cfg_type is passed as {cfg_type}, but classifier-free guidance is not enabled since do_true_cfg is passed as False." + ) + # 4. 
Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) @@ -1055,11 +1071,26 @@ def __call__( return_dict=False, )[0] neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) + + if cfg_type == "original": + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + elif cfg_type == "mambo_g": + mambo_g_alpha = cfg_kwargs.get("alpha", 8.0) + ratio = torch.norm(noise_pred - neg_noise_pred) / torch.norm(neg_noise_pred) + guidance_scale = 1 + (true_cfg_scale - 1.0) * torch.exp(- mambo_g_alpha * ratio) + comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + apply_cfg_rescale = cfg_kwargs.get("cfg_rescale", True) + if apply_cfg_rescale: + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + raise ValueError( + f"cfg_type given as {cfg_type} must be one of `original` or `mambo_g` for the QwenImagePipeline." + ) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index ec203edf166c..9970b72db16e 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -538,6 +538,8 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + cfg_type: Optional[str] = "original", + cfg_kwargs: Optional[Dict[str, Any]] = {}, ): r""" Function invoked when calling the pipeline for generation. @@ -619,6 +621,15 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + cfg_type (`str`, *optional*, defaults to `"original"`): + The specified classifier-free-guidance (CFG) type for inference. + `"original"` for the original cfg method in Classifier-Free Diffusion Guidance. + `"mambo_g"` for the optional cfg method in MAMBO-G: Magnitude-Aware Mitigation for Boosted Guidance. + cfg_kwargs (`dict`, *optional*): + A kwargs dictionary for additional cfg hyperparameters. + For `"mambo_g"` + - alpha : float, Decaying weight for MAMBO-G, defaults to `8.0` + - cfg_rescale : bool, Whether to use cfg rescaling, defaults to `True`. Examples: @@ -717,6 +728,11 @@ def __call__( max_sequence_length=max_sequence_length, ) + if cfg_type != "original" and not do_true_cfg: + logger.warning( + f"cfg_type is passed as {cfg_type}, but classifier-free guidance is not enabled since do_true_cfg is passed as False." + ) + # 4. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 latents, image_latents = self.prepare_latents( @@ -825,11 +841,26 @@ def __call__( return_dict=False, )[0] neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) + + if cfg_type == "original": + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + elif cfg_type == "mambo_g": + mambo_g_alpha = cfg_kwargs.get("alpha", 8.0) + ratio = torch.norm(noise_pred - neg_noise_pred) / torch.norm(neg_noise_pred) + guidance_scale = 1 + (true_cfg_scale - 1.0) * torch.exp(- mambo_g_alpha * ratio) + comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + apply_cfg_rescale = cfg_kwargs.get("cfg_rescale", True) + if apply_cfg_rescale: + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + raise ValueError( + f"cfg_type given as {cfg_type} must be one of `original` or `mambo_g` for the QwenImagePipeline." + ) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index cb4c5d8016bb..db6528773469 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -549,6 +549,8 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + cfg_type: Optional[str] = "original", + cfg_kwargs: Optional[Dict[str, Any]] = {}, ): r""" Function invoked when calling the pipeline for generation. @@ -636,6 +638,15 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + cfg_type (`str`, *optional*, defaults to `"original"`): + The specified classifier-free-guidance (CFG) type for inference. + `"original"` for the original cfg method in Classifier-Free Diffusion Guidance. + `"mambo_g"` for the optional cfg method in MAMBO-G: Magnitude-Aware Mitigation for Boosted Guidance. + cfg_kwargs (`dict`, *optional*): + A kwargs dictionary for additional cfg hyperparameters. + For `"mambo_g"` + - alpha : float, Decaying weight for MAMBO-G, defaults to `8.0` + - cfg_rescale : bool, Whether to use cfg rescaling, defaults to `True`. Examples: @@ -714,6 +725,11 @@ def __call__( max_sequence_length=max_sequence_length, ) + if cfg_type != "original" and not do_true_cfg: + logger.warning( + f"cfg_type is passed as {cfg_type}, but classifier-free guidance is not enabled since do_true_cfg is passed as False." + ) + # 4. 
Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) @@ -815,11 +831,26 @@ def __call__( attention_kwargs=self.attention_kwargs, return_dict=False, )[0] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) + + if cfg_type == "original": + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + elif cfg_type == "mambo_g": + mambo_g_alpha = cfg_kwargs.get("alpha", 8.0) + ratio = torch.norm(noise_pred - neg_noise_pred) / torch.norm(neg_noise_pred) + guidance_scale = 1 + (true_cfg_scale - 1.0) * torch.exp(- mambo_g_alpha * ratio) + comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + apply_cfg_rescale = cfg_kwargs.get("cfg_rescale", True) + if apply_cfg_rescale: + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + raise ValueError( + f"cfg_type given as {cfg_type} must be one of `original` or `mambo_g` for the QwenImagePipeline." + ) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 1915c27eb2bb..e55c15915fd5 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -662,6 +662,8 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + cfg_type: Optional[str] = "original", + cfg_kwargs: Optional[Dict[str, Any]] = {}, ): r""" Function invoked when calling the pipeline for generation. @@ -766,6 +768,15 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + cfg_type (`str`, *optional*, defaults to `"original"`): + The specified classifier-free-guidance (CFG) type for inference. + `"original"` for the original cfg method in Classifier-Free Diffusion Guidance. + `"mambo_g"` for the optional cfg method in MAMBO-G: Magnitude-Aware Mitigation for Boosted Guidance. + cfg_kwargs (`dict`, *optional*): + A kwargs dictionary for additional cfg hyperparameters. + For `"mambo_g"` + - alpha : float, Decaying weight for MAMBO-G, defaults to `8.0` + - cfg_rescale : bool, Whether to use cfg rescaling, defaults to `True`. Examples: @@ -858,6 +869,11 @@ def __call__( max_sequence_length=max_sequence_length, ) + if cfg_type != "original" and not do_true_cfg: + logger.warning( + f"cfg_type is passed as {cfg_type}, but classifier-free guidance is not enabled since do_true_cfg is passed as False." + ) + # 4. 
Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) @@ -984,11 +1000,26 @@ def __call__( attention_kwargs=self.attention_kwargs, return_dict=False, )[0] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) + + if cfg_type == "original": + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + elif cfg_type == "mambo_g": + mambo_g_alpha = cfg_kwargs.get("alpha", 8.0) + ratio = torch.norm(noise_pred - neg_noise_pred) / torch.norm(neg_noise_pred) + guidance_scale = 1 + (true_cfg_scale - 1.0) * torch.exp(- mambo_g_alpha * ratio) + comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + apply_cfg_rescale = cfg_kwargs.get("cfg_rescale", True) + if apply_cfg_rescale: + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + raise ValueError( + f"cfg_type given as {cfg_type} must be one of `original` or `mambo_g` for the QwenImagePipeline." + ) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype From 59a1c42d98ca4e634e8c39ce625f56e06c6df0e7 Mon Sep 17 00:00:00 2001 From: Pscgylotti Date: Fri, 19 Dec 2025 12:42:58 +0800 Subject: [PATCH 2/5] change to guider implementation --- src/diffusers/guiders/__init__.py | 1 + .../guiders/magnitude_aware_guidance.py | 155 ++++++++++++++++++ .../pipelines/qwenimage/pipeline_qwenimage.py | 39 +---- .../pipeline_qwenimage_controlnet.py | 41 +---- .../pipeline_qwenimage_controlnet_inpaint.py | 38 +---- .../qwenimage/pipeline_qwenimage_edit.py | 41 +---- .../pipeline_qwenimage_edit_inpaint.py | 41 +---- .../qwenimage/pipeline_qwenimage_edit_plus.py | 41 +---- .../qwenimage/pipeline_qwenimage_img2img.py | 41 +---- .../qwenimage/pipeline_qwenimage_inpaint.py | 41 +---- 10 files changed, 193 insertions(+), 286 deletions(-) create mode 100644 src/diffusers/guiders/magnitude_aware_guidance.py diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 4e53c373c4f4..cdb80014194a 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -29,3 +29,4 @@ from .skip_layer_guidance import SkipLayerGuidance from .smoothed_energy_guidance import SmoothedEnergyGuidance from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance + from .magnitude_aware_guidance import MagnitudeAwareGuidance diff --git a/src/diffusers/guiders/magnitude_aware_guidance.py b/src/diffusers/guiders/magnitude_aware_guidance.py new file mode 100644 index 000000000000..8b4864bf63c3 --- /dev/null +++ b/src/diffusers/guiders/magnitude_aware_guidance.py @@ -0,0 +1,155 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+    from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class MagnitudeAwareGuidance(BaseGuidance):
+    """
+    Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G): https://huggingface.co/papers/2508.03442
+
+    Args:
+        guidance_scale (`float`, defaults to `10.0`):
+            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
+            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
+            saturation and deterioration of image quality.
+        alpha (`float`, defaults to `8.0`):
+            The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive suppression of
+            the guidance scale when the magnitude of the guidance update is large.
+        guidance_rescale (`float`, defaults to `0.0`):
+            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+            Flawed](https://huggingface.co/papers/2305.08891).
+        use_original_formulation (`bool`, defaults to `False`):
+            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+            we use the diffusers-native implementation that has been in the codebase for a long time. See
+            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+        start (`float`, defaults to `0.0`):
+            The fraction of the total number of denoising steps after which guidance starts.
+        stop (`float`, defaults to `1.0`):
+            The fraction of the total number of denoising steps after which guidance stops.
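+
+    Example (an illustrative sketch; the import path follows the file location above, the tensor shapes are
+    placeholders, and in practice the guider is driven by a modular pipeline rather than called directly):
+
+    ```py
+    >>> import torch
+    >>> from diffusers.guiders.magnitude_aware_guidance import MagnitudeAwareGuidance
+
+    >>> guider = MagnitudeAwareGuidance(guidance_scale=4.0, alpha=8.0)
+    >>> # Placeholder conditional / unconditional model predictions.
+    >>> pred_cond = torch.randn(1, 16, 64, 64)
+    >>> pred_uncond = torch.randn(1, 16, 64, 64)
+    >>> guided = guider.forward(pred_cond, pred_uncond).pred
+    ```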
+ """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + @register_to_config + def __init__( + self, + guidance_scale: float = 10.0, + alpha: float = 8.0, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + enabled: bool = True, + ): + super().__init__(start, stop, enabled) + + self.guidance_scale = guidance_scale + self.alpha = alpha + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + self.momentum_buffer = None + + def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch(data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: + pred = None + + if not self._is_mambo_g_enabled(): + pred = pred_cond + else: + pred = mambo_guidance( + pred_cond, + pred_uncond, + self.guidance_scale, + self.alpha, + self.use_original_formulation, + ) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond) + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_mambo_g_enabled(): + num_conditions += 1 + return num_conditions + + def _is_mambo_g_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +def mambo_guidance( + pred_cond: torch.Tensor, + pred_uncond: torch.Tensor, + guidance_scale: float, + alpha: float = 8.0, + use_original_formulation: bool = False, +): + dim = [i for i in range(1, len(pred_cond.shape))] + diff = pred_cond - pred_uncond + ratio = torch.norm(diff, dim=dim, keepdim=True) / torch.norm(pred_uncond, dim=dim, keepdim=True) + guidance_scale_final = guidance_scale * torch.exp(- alpha * ratio) if use_original_formulation else 1.0 + (guidance_scale - 1.0) * torch.exp(- alpha * ratio) + pred = pred_cond if use_original_formulation else pred_uncond + pred = pred + guidance_scale_final * diff + + return pred diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 1c9fcfea4fc0..33dc2039b986 100644 --- 
a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -473,8 +473,6 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - cfg_type: Optional[str] = "original", - cfg_kwargs: Optional[Dict[str, Any]] = {}, ): r""" Function invoked when calling the pipeline for generation. @@ -550,15 +548,6 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. - cfg_type (`str`, *optional*, defaults to `"original"`): - The specified classifier-free-guidance (CFG) type for inference. - `"original"` for the original cfg method in Classifier-Free Diffusion Guidance. - `"mambo_g"` for the optional cfg method in MAMBO-G: Magnitude-Aware Mitigation for Boosted Guidance. - cfg_kwargs (`dict`, *optional*): - A kwargs dictionary for additional cfg hyperparameters. - For `"mambo_g"` - - alpha : float, Decaying weight for MAMBO-G, defaults to `8.0` - - cfg_rescale : bool, Whether to use cfg rescaling, defaults to `True`. Examples: @@ -632,11 +621,6 @@ def __call__( max_sequence_length=max_sequence_length, ) - if cfg_type != "original" and not do_true_cfg: - logger.warning( - f"cfg_type is passed as {cfg_type}, but classifier-free guidance is not enabled since do_true_cfg is passed as False." - ) - # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 latents = self.prepare_latents( @@ -729,26 +713,11 @@ def __call__( attention_kwargs=self.attention_kwargs, return_dict=False, )[0] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - if cfg_type == "original": - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - elif cfg_type == "mambo_g": - mambo_g_alpha = cfg_kwargs.get("alpha", 8.0) - ratio = torch.norm(noise_pred - neg_noise_pred) / torch.norm(neg_noise_pred) - guidance_scale = 1 + (true_cfg_scale - 1.0) * torch.exp(- mambo_g_alpha * ratio) - comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) - apply_cfg_rescale = cfg_kwargs.get("cfg_rescale", True) - if apply_cfg_rescale: - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - else: - raise ValueError( - f"cfg_type given as {cfg_type} must be one of `original` or `mambo_g` for the QwenImagePipeline." 
- ) + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py index 06784fada4f0..5111096d93c1 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -577,8 +577,6 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - cfg_type: Optional[str] = "original", - cfg_kwargs: Optional[Dict[str, Any]] = {}, ): r""" Function invoked when calling the pipeline for generation. @@ -654,15 +652,6 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. - cfg_type (`str`, *optional*, defaults to `"original"`): - The specified classifier-free-guidance (CFG) type for inference. - `"original"` for the original cfg method in Classifier-Free Diffusion Guidance. - `"mambo_g"` for the optional cfg method in MAMBO-G: Magnitude-Aware Mitigation for Boosted Guidance. - cfg_kwargs (`dict`, *optional*): - A kwargs dictionary for additional cfg hyperparameters. - For `"mambo_g"` - - alpha : float, Decaying weight for MAMBO-G, defaults to `8.0` - - cfg_rescale : bool, Whether to use cfg rescaling, defaults to `True`. Examples: @@ -747,11 +736,6 @@ def __call__( max_sequence_length=max_sequence_length, ) - if cfg_type != "original" and not do_true_cfg: - logger.warning( - f"cfg_type is passed as {cfg_type}, but classifier-free guidance is not enabled since do_true_cfg is passed as False." - ) - # 3. Prepare control image num_channels_latents = self.transformer.config.in_channels // 4 if isinstance(self.controlnet, QwenImageControlNetModel): @@ -956,26 +940,11 @@ def __call__( attention_kwargs=self.attention_kwargs, return_dict=False, )[0] - - if cfg_type == "original": - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - elif cfg_type == "mambo_g": - mambo_g_alpha = cfg_kwargs.get("alpha", 8.0) - ratio = torch.norm(noise_pred - neg_noise_pred) / torch.norm(neg_noise_pred) - guidance_scale = 1 + (true_cfg_scale - 1.0) * torch.exp(- mambo_g_alpha * ratio) - comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) - apply_cfg_rescale = cfg_kwargs.get("cfg_rescale", True) - if apply_cfg_rescale: - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - else: - raise ValueError( - f"cfg_type given as {cfg_type} must be one of `original` or `mambo_g` for the QwenImagePipeline." 
- ) + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py index 5d21cb5bbee7..102a813ab582 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -618,8 +618,6 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - cfg_type: Optional[str] = "original", - cfg_kwargs: Optional[Dict[str, Any]] = {}, ): r""" Function invoked when calling the pipeline for generation. @@ -686,15 +684,6 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. - cfg_type (`str`, *optional*, defaults to `"original"`): - The specified classifier-free-guidance (CFG) type for inference. - `"original"` for the original cfg method in Classifier-Free Diffusion Guidance. - `"mambo_g"` for the optional cfg method in MAMBO-G: Magnitude-Aware Mitigation for Boosted Guidance. - cfg_kwargs (`dict`, *optional*): - A kwargs dictionary for additional cfg hyperparameters. - For `"mambo_g"` - - alpha : float, Decaying weight for MAMBO-G, defaults to `8.0` - - cfg_rescale : bool, Whether to use cfg rescaling, defaults to `True`. Examples: @@ -768,11 +757,6 @@ def __call__( num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) - - if cfg_type != "original" and not do_true_cfg: - logger.warning( - f"cfg_type is passed as {cfg_type}, but classifier-free guidance is not enabled since do_true_cfg is passed as False." - ) # 3. Prepare control image num_channels_latents = self.transformer.config.in_channels // 4 @@ -901,25 +885,9 @@ def __call__( )[0] comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - if cfg_type == "original": - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - elif cfg_type == "mambo_g": - mambo_g_alpha = cfg_kwargs.get("alpha", 8.0) - ratio = torch.norm(noise_pred - neg_noise_pred) / torch.norm(neg_noise_pred) - guidance_scale = 1 + (true_cfg_scale - 1.0) * torch.exp(- mambo_g_alpha * ratio) - comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) - apply_cfg_rescale = cfg_kwargs.get("cfg_rescale", True) - if apply_cfg_rescale: - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - else: - raise ValueError( - f"cfg_type given as {cfg_type} must be one of `original` or `mambo_g` for the QwenImagePipeline." 
- ) + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index 2bd2c6dadd66..ed37b238c8c9 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -569,8 +569,6 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - cfg_type: Optional[str] = "original", - cfg_kwargs: Optional[Dict[str, Any]] = {}, ): r""" Function invoked when calling the pipeline for generation. @@ -652,15 +650,6 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. - cfg_type (`str`, *optional*, defaults to `"original"`): - The specified classifier-free-guidance (CFG) type for inference. - `"original"` for the original cfg method in Classifier-Free Diffusion Guidance. - `"mambo_g"` for the optional cfg method in MAMBO-G: Magnitude-Aware Mitigation for Boosted Guidance. - cfg_kwargs (`dict`, *optional*): - A kwargs dictionary for additional cfg hyperparameters. - For `"mambo_g"` - - alpha : float, Decaying weight for MAMBO-G, defaults to `8.0` - - cfg_rescale : bool, Whether to use cfg rescaling, defaults to `True`. Examples: @@ -747,11 +736,6 @@ def __call__( max_sequence_length=max_sequence_length, ) - if cfg_type != "original" and not do_true_cfg: - logger.warning( - f"cfg_type is passed as {cfg_type}, but classifier-free guidance is not enabled since do_true_cfg is passed as False." - ) - # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 latents, image_latents = self.prepare_latents( @@ -857,26 +841,11 @@ def __call__( return_dict=False, )[0] neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - - if cfg_type == "original": - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - elif cfg_type == "mambo_g": - mambo_g_alpha = cfg_kwargs.get("alpha", 8.0) - ratio = torch.norm(noise_pred - neg_noise_pred) / torch.norm(neg_noise_pred) - guidance_scale = 1 + (true_cfg_scale - 1.0) * torch.exp(- mambo_g_alpha * ratio) - comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) - apply_cfg_rescale = cfg_kwargs.get("cfg_rescale", True) - if apply_cfg_rescale: - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - else: - raise ValueError( - f"cfg_type given as {cfg_type} must be one of `original` or `mambo_g` for the QwenImagePipeline." 
- ) + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index 425418fb480d..d54d1881fa4e 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -706,8 +706,6 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - cfg_type: Optional[str] = "original", - cfg_kwargs: Optional[Dict[str, Any]] = {}, ): r""" Function invoked when calling the pipeline for generation. @@ -812,15 +810,6 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. - cfg_type (`str`, *optional*, defaults to `"original"`): - The specified classifier-free-guidance (CFG) type for inference. - `"original"` for the original cfg method in Classifier-Free Diffusion Guidance. - `"mambo_g"` for the optional cfg method in MAMBO-G: Magnitude-Aware Mitigation for Boosted Guidance. - cfg_kwargs (`dict`, *optional*): - A kwargs dictionary for additional cfg hyperparameters. - For `"mambo_g"` - - alpha : float, Decaying weight for MAMBO-G, defaults to `8.0` - - cfg_rescale : bool, Whether to use cfg rescaling, defaults to `True`. Examples: @@ -928,11 +917,6 @@ def __call__( max_sequence_length=max_sequence_length, ) - if cfg_type != "original" and not do_true_cfg: - logger.warning( - f"cfg_type is passed as {cfg_type}, but classifier-free guidance is not enabled since do_true_cfg is passed as False." - ) - # 4. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) @@ -1071,26 +1055,11 @@ def __call__( return_dict=False, )[0] neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - - if cfg_type == "original": - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - elif cfg_type == "mambo_g": - mambo_g_alpha = cfg_kwargs.get("alpha", 8.0) - ratio = torch.norm(noise_pred - neg_noise_pred) / torch.norm(neg_noise_pred) - guidance_scale = 1 + (true_cfg_scale - 1.0) * torch.exp(- mambo_g_alpha * ratio) - comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) - apply_cfg_rescale = cfg_kwargs.get("cfg_rescale", True) - if apply_cfg_rescale: - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - else: - raise ValueError( - f"cfg_type given as {cfg_type} must be one of `original` or `mambo_g` for the QwenImagePipeline." 
-                        )
+                    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+                    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+                    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+                    noise_pred = comb_pred * (cond_norm / noise_norm)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
index 9970b72db16e..ec203edf166c 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
@@ -538,8 +538,6 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
-        cfg_type: Optional[str] = "original",
-        cfg_kwargs: Optional[Dict[str, Any]] = {},
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -621,15 +619,6 @@ def __call__(
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
-            cfg_type (`str`, *optional*, defaults to `"original"`):
-                The specified classifier-free-guidance (CFG) type for inference.
-                `"original"` for the original cfg method in Classifier-Free Diffusion Guidance.
-                `"mambo_g"` for the optional cfg method in MAMBO-G: Magnitude-Aware Mitigation for Boosted Guidance.
-            cfg_kwargs (`dict`, *optional*):
-                A kwargs dictionary for additional cfg hyperparameters.
-                For `"mambo_g"`
-                - alpha : float, Decaying weight for MAMBO-G, defaults to `8.0`
-                - cfg_rescale : bool, Whether to use cfg rescaling, defaults to `True`.
 
         Examples:
 
@@ -728,11 +717,6 @@ def __call__(
             max_sequence_length=max_sequence_length,
         )
 
-        if cfg_type != "original" and not do_true_cfg:
-            logger.warning(
-                f"cfg_type is passed as {cfg_type}, but classifier-free guidance is not enabled since do_true_cfg is passed as False."
-            )
-
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
         latents, image_latents = self.prepare_latents(
@@ -841,26 +825,11 @@ def __call__(
                             return_dict=False,
                         )[0]
                     neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
-
-                    if cfg_type == "original":
-                        comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
-                        cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
-                        noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
-                        noise_pred = comb_pred * (cond_norm / noise_norm)
-                    elif cfg_type == "mambo_g":
-                        mambo_g_alpha = cfg_kwargs.get("alpha", 8.0)
-                        ratio = torch.norm(noise_pred - neg_noise_pred) / torch.norm(neg_noise_pred)
-                        guidance_scale = 1 + (true_cfg_scale - 1.0) * torch.exp(- mambo_g_alpha * ratio)
-                        comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
-                        apply_cfg_rescale = cfg_kwargs.get("cfg_rescale", True)
-                        if apply_cfg_rescale:
-                            cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
-                            noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
-                            noise_pred = comb_pred * (cond_norm / noise_norm)
-                    else:
-                        raise ValueError(
-                            f"cfg_type given as {cfg_type} must be one of `original` or `mambo_g` for the QwenImagePipeline."
-                        )
+                    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+                    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+                    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+                    noise_pred = comb_pred * (cond_norm / noise_norm)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py
index db6528773469..cb4c5d8016bb 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py
@@ -549,8 +549,6 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
-        cfg_type: Optional[str] = "original",
-        cfg_kwargs: Optional[Dict[str, Any]] = {},
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -638,15 +636,6 @@ def __call__(
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
-            cfg_type (`str`, *optional*, defaults to `"original"`):
-                The specified classifier-free-guidance (CFG) type for inference.
-                `"original"` for the original cfg method in Classifier-Free Diffusion Guidance.
-                `"mambo_g"` for the optional cfg method in MAMBO-G: Magnitude-Aware Mitigation for Boosted Guidance.
-            cfg_kwargs (`dict`, *optional*):
-                A kwargs dictionary for additional cfg hyperparameters.
-                For `"mambo_g"`
-                - alpha : float, Decaying weight for MAMBO-G, defaults to `8.0`
-                - cfg_rescale : bool, Whether to use cfg rescaling, defaults to `True`.
 
         Examples:
 
@@ -725,11 +714,6 @@ def __call__(
             max_sequence_length=max_sequence_length,
         )
 
-        if cfg_type != "original" and not do_true_cfg:
-            logger.warning(
-                f"cfg_type is passed as {cfg_type}, but classifier-free guidance is not enabled since do_true_cfg is passed as False."
-            )
-
         # 4. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
         image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
@@ -831,26 +815,11 @@ def __call__(
                         attention_kwargs=self.attention_kwargs,
                         return_dict=False,
                     )[0]
-
-                    if cfg_type == "original":
-                        comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
-                        cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
-                        noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
-                        noise_pred = comb_pred * (cond_norm / noise_norm)
-                    elif cfg_type == "mambo_g":
-                        mambo_g_alpha = cfg_kwargs.get("alpha", 8.0)
-                        ratio = torch.norm(noise_pred - neg_noise_pred) / torch.norm(neg_noise_pred)
-                        guidance_scale = 1 + (true_cfg_scale - 1.0) * torch.exp(- mambo_g_alpha * ratio)
-                        comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
-                        apply_cfg_rescale = cfg_kwargs.get("cfg_rescale", True)
-                        if apply_cfg_rescale:
-                            cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
-                            noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
-                            noise_pred = comb_pred * (cond_norm / noise_norm)
-                    else:
-                        raise ValueError(
-                            f"cfg_type given as {cfg_type} must be one of `original` or `mambo_g` for the QwenImagePipeline."
-                        )
+                    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+                    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+                    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+                    noise_pred = comb_pred * (cond_norm / noise_norm)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py
index e55c15915fd5..1915c27eb2bb 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py
@@ -662,8 +662,6 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
-        cfg_type: Optional[str] = "original",
-        cfg_kwargs: Optional[Dict[str, Any]] = {},
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -768,15 +766,6 @@ def __call__(
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
-            cfg_type (`str`, *optional*, defaults to `"original"`):
-                The specified classifier-free-guidance (CFG) type for inference.
-                `"original"` for the original cfg method in Classifier-Free Diffusion Guidance.
-                `"mambo_g"` for the optional cfg method in MAMBO-G: Magnitude-Aware Mitigation for Boosted Guidance.
-            cfg_kwargs (`dict`, *optional*):
-                A kwargs dictionary for additional cfg hyperparameters.
-                For `"mambo_g"`
-                - alpha : float, Decaying weight for MAMBO-G, defaults to `8.0`
-                - cfg_rescale : bool, Whether to use cfg rescaling, defaults to `True`.
 
         Examples:
 
@@ -869,11 +858,6 @@ def __call__(
             max_sequence_length=max_sequence_length,
        )
 
-        if cfg_type != "original" and not do_true_cfg:
-            logger.warning(
-                f"cfg_type is passed as {cfg_type}, but classifier-free guidance is not enabled since do_true_cfg is passed as False."
-            )
-
         # 4. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
         image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
@@ -1000,26 +984,11 @@ def __call__(
                         attention_kwargs=self.attention_kwargs,
                         return_dict=False,
                     )[0]
-
-                    if cfg_type == "original":
-                        comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
-                        cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
-                        noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
-                        noise_pred = comb_pred * (cond_norm / noise_norm)
-                    elif cfg_type == "mambo_g":
-                        mambo_g_alpha = cfg_kwargs.get("alpha", 8.0)
-                        ratio = torch.norm(noise_pred - neg_noise_pred) / torch.norm(neg_noise_pred)
-                        guidance_scale = 1 + (true_cfg_scale - 1.0) * torch.exp(- mambo_g_alpha * ratio)
-                        comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
-                        apply_cfg_rescale = cfg_kwargs.get("cfg_rescale", True)
-                        if apply_cfg_rescale:
-                            cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
-                            noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
-                            noise_pred = comb_pred * (cond_norm / noise_norm)
-                    else:
-                        raise ValueError(
-                            f"cfg_type given as {cfg_type} must be one of `original` or `mambo_g` for the QwenImagePipeline."
-                        )
+                    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+                    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+                    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+                    noise_pred = comb_pred * (cond_norm / noise_norm)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype

From 9d2c091fd03a52bece2cf02f40a78a47ecb6388a Mon Sep 17 00:00:00 2001
From: Pscgylotti
Date: Fri, 19 Dec 2025 16:00:38 +0800
Subject: [PATCH 3/5] fix copied code residual

---
 src/diffusers/guiders/magnitude_aware_guidance.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/guiders/magnitude_aware_guidance.py b/src/diffusers/guiders/magnitude_aware_guidance.py
index 8b4864bf63c3..b01c153d14e6 100644
--- a/src/diffusers/guiders/magnitude_aware_guidance.py
+++ b/src/diffusers/guiders/magnitude_aware_guidance.py
@@ -69,7 +69,6 @@ def __init__(
         self.alpha = alpha
         self.guidance_rescale = guidance_rescale
         self.use_original_formulation = use_original_formulation
-        self.momentum_buffer = None
 
     def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
         tuple_indices = [0] if self.num_conditions == 1 else [0, 1]

From 3266584fd45c8df844fcdb075a47afc26c27604a Mon Sep 17 00:00:00 2001
From: YiYi Xu
Date: Fri, 19 Dec 2025 10:21:57 -1000
Subject: [PATCH 4/5] Update src/diffusers/guiders/magnitude_aware_guidance.py

---
 src/diffusers/guiders/magnitude_aware_guidance.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/guiders/magnitude_aware_guidance.py b/src/diffusers/guiders/magnitude_aware_guidance.py
index b01c153d14e6..cd2de2ac7034 100644
--- a/src/diffusers/guiders/magnitude_aware_guidance.py
+++ b/src/diffusers/guiders/magnitude_aware_guidance.py
@@ -144,7 +144,7 @@ def mambo_guidance(
     alpha: float = 8.0,
     use_original_formulation: bool = False,
 ):
-    dim = [i for i in range(1, len(pred_cond.shape))]
+    dim = list(range(1, len(pred_cond.shape)))
     diff = pred_cond - pred_uncond
     ratio = torch.norm(diff, dim=dim, keepdim=True) / torch.norm(pred_uncond, dim=dim, keepdim=True)
     guidance_scale_final = guidance_scale * torch.exp(- alpha * ratio) if use_original_formulation else 1.0 + (guidance_scale - 1.0) * torch.exp(- alpha * ratio)

From c3bbdc9ba7ef256939011397a1e169339a4a9f46 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Fri, 19 Dec 2025 20:23:07 +0000
Subject: [PATCH 5/5] Apply style fixes

---
 src/diffusers/guiders/__init__.py                 | 2 +-
 src/diffusers/guiders/magnitude_aware_guidance.py | 9 +++++++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py
index cdb80014194a..58ad0c211b64 100644
--- a/src/diffusers/guiders/__init__.py
+++ b/src/diffusers/guiders/__init__.py
@@ -25,8 +25,8 @@
     from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
     from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
     from .guider_utils import BaseGuidance
+    from .magnitude_aware_guidance import MagnitudeAwareGuidance
     from .perturbed_attention_guidance import PerturbedAttentionGuidance
     from .skip_layer_guidance import SkipLayerGuidance
     from .smoothed_energy_guidance import SmoothedEnergyGuidance
     from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
-    from .magnitude_aware_guidance import MagnitudeAwareGuidance
diff --git a/src/diffusers/guiders/magnitude_aware_guidance.py b/src/diffusers/guiders/magnitude_aware_guidance.py
index cd2de2ac7034..b81cf0d3a1f9 100644
--- a/src/diffusers/guiders/magnitude_aware_guidance.py
+++ b/src/diffusers/guiders/magnitude_aware_guidance.py
@@ -35,7 +35,8 @@ class MagnitudeAwareGuidance(BaseGuidance):
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        alpha (`float`, defaults to `8.0`):
-           The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive supression of guidance scale when the magnitude of the guidance update is large.
+           The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive supression of
+           guidance scale when the magnitude of the guidance update is large.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
@@ -147,7 +148,11 @@ def mambo_guidance(
     dim = list(range(1, len(pred_cond.shape)))
     diff = pred_cond - pred_uncond
     ratio = torch.norm(diff, dim=dim, keepdim=True) / torch.norm(pred_uncond, dim=dim, keepdim=True)
-    guidance_scale_final = guidance_scale * torch.exp(- alpha * ratio) if use_original_formulation else 1.0 + (guidance_scale - 1.0) * torch.exp(- alpha * ratio)
+    guidance_scale_final = (
+        guidance_scale * torch.exp(-alpha * ratio)
+        if use_original_formulation
+        else 1.0 + (guidance_scale - 1.0) * torch.exp(-alpha * ratio)
+    )
     pred = pred_cond if use_original_formulation else pred_uncond
     pred = pred + guidance_scale_final * diff