Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/diffusers/guiders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
from .guider_utils import BaseGuidance
from .magnitude_aware_guidance import MagnitudeAwareGuidance
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance
Expand Down
159 changes: 159 additions & 0 deletions src/diffusers/guiders/magnitude_aware_guidance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg


if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState


class MagnitudeAwareGuidance(BaseGuidance):
"""
Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G): https://huggingface.co/papers/2508.03442

Args:
guidance_scale (`float`, defaults to `10.0`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
alpha (`float`, defaults to `8.0`):
The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive supression of
guidance scale when the magnitude of the guidance update is large.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""

_input_predictions = ["pred_cond", "pred_uncond"]

@register_to_config
def __init__(
self,
guidance_scale: float = 10.0,
alpha: float = 8.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)

self.guidance_scale = guidance_scale
self.alpha = alpha
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation

def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

if not self._is_mambo_g_enabled():
pred = pred_cond
else:
pred = mambo_guidance(
pred_cond,
pred_uncond,
self.guidance_scale,
self.alpha,
self.use_original_formulation,
)

if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

@property
def is_conditional(self) -> bool:
return self._count_prepared == 1

@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_mambo_g_enabled():
num_conditions += 1
return num_conditions

def _is_mambo_g_enabled(self) -> bool:
if not self._enabled:
return False

is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step

is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)

return is_within_range and not is_close


def mambo_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
alpha: float = 8.0,
use_original_formulation: bool = False,
):
dim = list(range(1, len(pred_cond.shape)))
diff = pred_cond - pred_uncond
ratio = torch.norm(diff, dim=dim, keepdim=True) / torch.norm(pred_uncond, dim=dim, keepdim=True)
guidance_scale_final = (
guidance_scale * torch.exp(-alpha * ratio)
if use_original_formulation
else 1.0 + (guidance_scale - 1.0) * torch.exp(-alpha * ratio)
)
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale_final * diff

return pred
Loading