
Commit 262ce19

MatrixTeam-AI, Pscgylotti, yiyixuxu, and github-actions[bot] authored
Feature: Add Mambo-G Guidance as Guider (#12862)
* Feature: Add Mambo-G Guidance to Qwen-Image Pipeline
* change to guider implementation
* fix copied code residual
* Update src/diffusers/guiders/magnitude_aware_guidance.py
* Apply style fixes

---------

Co-authored-by: Pscgylotti <pscgylotti@github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent f7753b1 commit 262ce19
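
In brief, as implemented in the file added below: MAMBO-G damps the classifier-free guidance scale when the guidance update is large. With the diffusers-native formulation, the effective scale is

    w_eff = 1 + (w - 1) * exp(-alpha * ||pred_cond - pred_uncond|| / ||pred_uncond||)

so the larger the relative disagreement between the conditional and unconditional predictions, the closer w_eff falls to 1, mitigating the saturation that large fixed scales can cause.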

File tree: 2 files changed (+160, -0 lines)

src/diffusers/guiders/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
 from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
 from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
 from .guider_utils import BaseGuidance
+from .magnitude_aware_guidance import MagnitudeAwareGuidance
 from .perturbed_attention_guidance import PerturbedAttentionGuidance
 from .skip_layer_guidance import SkipLayerGuidance
 from .smoothed_energy_guidance import SmoothedEnergyGuidance
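
With this export in place, the guider can be imported from the guiders subpackage. A minimal instantiation sketch (assumes a diffusers build that includes this commit; the keyword values shown are just the documented defaults):

from diffusers.guiders import MagnitudeAwareGuidance

# Defaults mirror the docstring in the new file below.
guider = MagnitudeAwareGuidance(guidance_scale=10.0, alpha=8.0, guidance_rescale=0.0)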
src/diffusers/guiders/magnitude_aware_guidance.py

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState


class MagnitudeAwareGuidance(BaseGuidance):
    """
    Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G): https://huggingface.co/papers/2508.03442

    Args:
        guidance_scale (`float`, defaults to `10.0`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
            saturation and deterioration of image quality.
        alpha (`float`, defaults to `8.0`):
            The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive suppression of
            the guidance scale when the magnitude of the guidance update is large.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """
    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 10.0,
        alpha: float = 8.0,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
        enabled: bool = True,
    ):
        super().__init__(start, stop, enabled)

        self.guidance_scale = guidance_scale
        self.alpha = alpha
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
            data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
            data_batches.append(data_batch)
        return data_batches
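
    # Note: prepare_inputs yields one batch per active prediction — only
    # ["pred_cond"] when num_conditions == 1, and both ["pred_cond",
    # "pred_uncond"] when MAMBO-G is active (see num_conditions below).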
    def prepare_inputs_from_block_state(
        self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
    ) -> List["BlockState"]:
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
            data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        pred = None

        if not self._is_mambo_g_enabled():
            pred = pred_cond
        else:
            pred = mambo_guidance(
                pred_cond,
                pred_uncond,
                self.guidance_scale,
                self.alpha,
                self.use_original_formulation,
            )

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_mambo_g_enabled():
            num_conditions += 1
        return num_conditions

    def _is_mambo_g_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
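
    # Illustrative windowing arithmetic (hypothetical values): with
    # _num_inference_steps = 50, start = 0.1 and stop = 0.9 give
    # skip_start_step = 5 and skip_stop_step = 45, so MAMBO-G is applied
    # on steps 5..44 and skipped elsewhere.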

def mambo_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    alpha: float = 8.0,
    use_original_formulation: bool = False,
):
    dim = list(range(1, len(pred_cond.shape)))
    diff = pred_cond - pred_uncond
    ratio = torch.norm(diff, dim=dim, keepdim=True) / torch.norm(pred_uncond, dim=dim, keepdim=True)
    guidance_scale_final = (
        guidance_scale * torch.exp(-alpha * ratio)
        if use_original_formulation
        else 1.0 + (guidance_scale - 1.0) * torch.exp(-alpha * ratio)
    )
    pred = pred_cond if use_original_formulation else pred_uncond
    pred = pred + guidance_scale_final * diff

    return pred
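
To make the damping concrete, here is a small standalone sketch (plain Python with hypothetical ratio values; it simply re-evaluates the default-formulation effective scale from mambo_guidance above):

import math

# Defaults from MagnitudeAwareGuidance; the ratios are hypothetical examples.
guidance_scale, alpha = 10.0, 8.0
for ratio in (0.05, 0.2, 0.5):
    w_eff = 1.0 + (guidance_scale - 1.0) * math.exp(-alpha * ratio)
    print(f"ratio={ratio:.2f} -> effective scale {w_eff:.2f}")
# ratio=0.05 -> effective scale 7.03
# ratio=0.20 -> effective scale 2.82
# ratio=0.50 -> effective scale 1.16

Even at the default guidance_scale of 10.0, a moderately large update magnitude pulls the effective scale most of the way back toward 1.0, the no-guidance point in the diffusers-native formulation.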

0 commit comments