68 changes: 60 additions & 8 deletions examples/dreambooth/train_dreambooth_lora_flux2.py
@@ -47,6 +47,7 @@

import numpy as np
import torch
import torch.distributed as dist
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -80,8 +81,10 @@
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
get_fsdp_kwargs_from_accelerator,
offload_models,
parse_buckets_string,
wrap_with_fsdp,
)
from diffusers.utils import (
check_min_version,
@@ -722,6 +725,7 @@ def parse_args(input_args=None):
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")

if input_args is not None:
args = parser.parse_args(input_args)
@@ -1219,7 +1223,11 @@ def main(args):
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
transformer.to(**transformer_to_kwargs)

is_fsdp = accelerator.state.fsdp_plugin is not None
if not is_fsdp:
transformer.to(**transformer_to_kwargs)

if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1507,6 +1515,21 @@ def _encode_single(prompt: str):
args.validation_prompt, text_encoding_pipeline
)

# Init FSDP for text encoder
if args.fsdp_text_encoder:
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
text_encoder_fsdp = wrap_with_fsdp(
model=text_encoding_pipeline.text_encoder,
device=accelerator.device,
offload=args.offload,
limit_all_gathers=True,
use_orig_params=True,
fsdp_kwargs=fsdp_kwargs,
)

text_encoding_pipeline.text_encoder = text_encoder_fsdp
dist.barrier()

# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
@@ -1536,6 +1559,8 @@ def _encode_single(prompt: str):
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
elif args.fsdp_text_encoder:
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1836,15 +1861,42 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

# Save the lora layers
accelerator.wait_for_everyone()
is_fsdp = accelerator.state.fsdp_plugin is not None

if is_fsdp:
transformer = unwrap_model(transformer)
state_dict = accelerator.get_state_dict(transformer)
if accelerator.is_main_process:
modules_to_save = {}
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
if is_fsdp:
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
state_dict = {
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
else:
state_dict = {
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}

transformer_lora_layers = get_peft_model_state_dict(
transformer,
state_dict=state_dict,
)
transformer_lora_layers = {
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
for k, v in transformer_lora_layers.items()
}

else:
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)

modules_to_save["transformer"] = transformer

Flux2Pipeline.save_lora_weights(
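The FSDP branch above works because each rank only holds a shard of the transformer, so `accelerator.get_state_dict(...)` has to run on every rank to all-gather the full parameters before the main process filters out the LoRA keys. A minimal sketch of that pattern (not part of this PR; it assumes `accelerator`, an FSDP-wrapped PEFT `transformer`, and `weight_dtype` as defined in the script above, and `export_lora_under_fsdp` is a hypothetical helper name):

import torch
from peft.utils import get_peft_model_state_dict


def export_lora_under_fsdp(accelerator, transformer, weight_dtype):
    # All ranks must participate: get_state_dict() all-gathers the sharded weights.
    full_state_dict = accelerator.get_state_dict(transformer)
    if not accelerator.is_main_process:
        return None
    # Cast the gathered tensors, then let PEFT select only the LoRA keys.
    full_state_dict = {
        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v
        for k, v in full_state_dict.items()
    }
    lora_layers = get_peft_model_state_dict(transformer, state_dict=full_state_dict)
    # Detach to CPU so saving does not keep the gathered weights on the GPU.
    return {k: v.detach().cpu().contiguous() for k, v in lora_layers.items()}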
67 changes: 59 additions & 8 deletions examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
@@ -46,6 +46,7 @@

import numpy as np
import torch
import torch.distributed as dist
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -79,8 +80,10 @@
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
get_fsdp_kwargs_from_accelerator,
offload_models,
parse_buckets_string,
wrap_with_fsdp,
)
from diffusers.utils import (
check_min_version,
@@ -691,6 +694,7 @@ def parse_args(input_args=None):

parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")

if input_args is not None:
args = parser.parse_args(input_args)
@@ -1156,7 +1160,11 @@ def main(args):
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
transformer.to(**transformer_to_kwargs)

is_fsdp = accelerator.state.fsdp_plugin is not None
if not is_fsdp:
transformer.to(**transformer_to_kwargs)

if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1430,6 +1438,21 @@ def _encode_single(prompt: str):
args.validation_prompt, text_encoding_pipeline
)

# Init FSDP for text encoder
if args.fsdp_text_encoder:
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
text_encoder_fsdp = wrap_with_fsdp(
model=text_encoding_pipeline.text_encoder,
device=accelerator.device,
offload=args.offload,
limit_all_gathers=True,
use_orig_params=True,
fsdp_kwargs=fsdp_kwargs,
)

text_encoding_pipeline.text_encoder = text_encoder_fsdp
dist.barrier()

# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
@@ -1461,6 +1484,8 @@ def _encode_single(prompt: str):
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
elif args.fsdp_text_encoder:
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1759,15 +1784,41 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

# Save the lora layers
accelerator.wait_for_everyone()

if is_fsdp:
transformer = unwrap_model(transformer)
state_dict = accelerator.get_state_dict(transformer)
if accelerator.is_main_process:
modules_to_save = {}
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
if is_fsdp:
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
state_dict = {
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
else:
state_dict = {
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}

transformer_lora_layers = get_peft_model_state_dict(
transformer,
state_dict=state_dict,
)
transformer_lora_layers = {
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
for k, v in transformer_lora_layers.items()
}

else:
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)

modules_to_save["transformer"] = transformer

Flux2Pipeline.save_lora_weights(
79 changes: 78 additions & 1 deletion src/diffusers/training_utils.py
@@ -6,10 +6,15 @@
import re
import warnings
from contextlib import contextmanager
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed.fsdp import CPUOffload, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from functools import partial

from .models import UNet2DConditionModel
from .pipelines import DiffusionPipeline
@@ -394,6 +399,78 @@ def find_nearest_bucket(h, w, bucket_options):
return best_bucket_idx


def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
"""
Extract and convert FSDP config from Accelerator into PyTorch FSDP kwargs.
"""

kwargs = {}
fsdp_plugin = accelerator.state.fsdp_plugin

if fsdp_plugin is None:
# FSDP is not enabled in Accelerate; fall back to full sharding.
kwargs["sharding_strategy"] = ShardingStrategy.FULL_SHARD
else:
# FSDP is enabled: use the plugin's strategy, or the default if it is unset.
kwargs["sharding_strategy"] = (
fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD
)

return kwargs

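As a quick illustration (a sketch, not part of the PR), the helper simply mirrors the Accelerate configuration, falling back to full sharding when no FSDP plugin is present:

from accelerate import Accelerator

accelerator = Accelerator()
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
# Without an FSDP plugin: {"sharding_strategy": ShardingStrategy.FULL_SHARD}
# With an Accelerate FSDP config: the strategy configured there (or FULL_SHARD if unset)
print(fsdp_kwargs)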

def wrap_with_fsdp(
model: torch.nn.Module,
device: Union[str, torch.device],
offload: bool = True,
use_orig_params: bool = True,
limit_all_gathers: bool = True,
fsdp_kwargs: Optional[Dict[str, Any]] = None,
transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None,
) -> FSDP:
"""
Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.

Args:
model: Model to wrap
device: Target device (e.g., accelerator.device)
offload: Whether to enable CPU parameter offloading
use_orig_params: Whether to use original parameters
limit_all_gathers: Whether to limit all gathers
fsdp_kwargs: FSDP arguments (sharding_strategy, etc.) — usually from Accelerate config
transformer_layer_cls: Classes for auto-wrapping (if not using policy from fsdp_kwargs)

Returns:
FSDP-wrapped model
"""

if transformer_layer_cls is None:
# Infer the transformer block class from the wrapped model
# (assumes a `model.model.language_model.layers` structure).
transformer_layer_cls = {type(model.model.language_model.layers[0])}

# Auto-wrap each transformer block as its own FSDP unit.
auto_wrap_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls=transformer_layer_cls,
)

config = {
"device_id": device,
"cpu_offload": CPUOffload(offload_params=offload) if offload else None,
"use_orig_params": use_orig_params,
"limit_all_gathers": limit_all_gathers,
"auto_wrap_policy": auto_wrap_policy
}

if fsdp_kwargs:
config.update(fsdp_kwargs)

fsdp_model = FSDP(model, **config)
if dist.is_initialized():
dist.barrier()
return fsdp_model

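A minimal usage sketch for the two helpers (not part of the PR; `load_text_encoder` and the dummy `input_ids` are placeholders, and the automatic layer-class inference only applies to models that expose `model.language_model.layers`, as the Flux2 text encoder does in the scripts above):

import torch
from accelerate import Accelerator

accelerator = Accelerator()
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)

text_encoder = load_text_encoder()  # placeholder: e.g. the pipeline's text encoder module
text_encoder = wrap_with_fsdp(
    model=text_encoder,
    device=accelerator.device,
    offload=True,             # shard parameters to CPU between forward passes
    fsdp_kwargs=fsdp_kwargs,  # reuse the sharding strategy chosen by Accelerate
)

input_ids = torch.ones((1, 8), dtype=torch.long, device=accelerator.device)  # dummy tokens
with torch.no_grad():
    outputs = text_encoder(input_ids)  # FSDP gathers each wrapped block's shards on demand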

# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
"""