diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 4e53c373c4f4..58ad0c211b64 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -25,6 +25,7 @@ from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .frequency_decoupled_guidance import FrequencyDecoupledGuidance from .guider_utils import BaseGuidance + from .magnitude_aware_guidance import MagnitudeAwareGuidance from .perturbed_attention_guidance import PerturbedAttentionGuidance from .skip_layer_guidance import SkipLayerGuidance from .smoothed_energy_guidance import SmoothedEnergyGuidance diff --git a/src/diffusers/guiders/magnitude_aware_guidance.py b/src/diffusers/guiders/magnitude_aware_guidance.py new file mode 100644 index 000000000000..b81cf0d3a1f9 --- /dev/null +++ b/src/diffusers/guiders/magnitude_aware_guidance.py @@ -0,0 +1,159 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch + +from ..configuration_utils import register_to_config +from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg + + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class MagnitudeAwareGuidance(BaseGuidance): + """ + Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G): https://huggingface.co/papers/2508.03442 + + Args: + guidance_scale (`float`, defaults to `10.0`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + alpha (`float`, defaults to `8.0`): + The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive supression of + guidance scale when the magnitude of the guidance update is large. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + @register_to_config + def __init__( + self, + guidance_scale: float = 10.0, + alpha: float = 8.0, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + enabled: bool = True, + ): + super().__init__(start, stop, enabled) + + self.guidance_scale = guidance_scale + self.alpha = alpha + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch(data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: + pred = None + + if not self._is_mambo_g_enabled(): + pred = pred_cond + else: + pred = mambo_guidance( + pred_cond, + pred_uncond, + self.guidance_scale, + self.alpha, + self.use_original_formulation, + ) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond) + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_mambo_g_enabled(): + num_conditions += 1 + return num_conditions + + def _is_mambo_g_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +def mambo_guidance( + pred_cond: torch.Tensor, + pred_uncond: torch.Tensor, + guidance_scale: float, + alpha: float = 8.0, + use_original_formulation: bool = False, +): + dim = list(range(1, len(pred_cond.shape))) + diff = pred_cond - pred_uncond + ratio = torch.norm(diff, dim=dim, keepdim=True) / torch.norm(pred_uncond, dim=dim, keepdim=True) + guidance_scale_final = ( + guidance_scale * torch.exp(-alpha * ratio) + if use_original_formulation + else 1.0 + (guidance_scale - 1.0) * torch.exp(-alpha * ratio) + ) + pred = pred_cond if use_original_formulation else pred_uncond + pred = pred + guidance_scale_final * diff + + return pred