
Commit b49c946: Add FSDP option for Flux2
Parent: 55463f7

3 files changed: +197 additions, -17 deletions

examples/dreambooth/train_dreambooth_lora_flux2.py
Lines changed: 60 additions & 8 deletions
@@ -47,6 +47,7 @@
 
 import numpy as np
 import torch
+import torch.distributed as dist
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -80,8 +81,10 @@
     compute_loss_weighting_for_sd3,
     find_nearest_bucket,
     free_memory,
+    get_fsdp_kwargs_from_accelerator,
     offload_models,
     parse_buckets_string,
+    wrap_with_fsdp,
 )
 from diffusers.utils import (
     check_min_version,
@@ -722,6 +725,7 @@ def parse_args(input_args=None):
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
+    parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1219,7 +1223,11 @@ def main(args):
         if args.bnb_quantization_config_path is not None
         else {"device": accelerator.device, "dtype": weight_dtype}
     )
-    transformer.to(**transformer_to_kwargs)
+
+    is_fsdp = accelerator.state.fsdp_plugin is not None
+    if not is_fsdp:
+        transformer.to(**transformer_to_kwargs)
+
     if args.do_fp8_training:
         convert_to_float8_training(
             transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1507,6 +1515,21 @@ def _encode_single(prompt: str):
             args.validation_prompt, text_encoding_pipeline
         )
 
+    # Init FSDP for the text encoder via the new training_utils helper.
+    if args.fsdp_text_encoder:
+        fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
+        text_encoder_fsdp = wrap_with_fsdp(
+            model=text_encoding_pipeline.text_encoder,
+            device=accelerator.device,
+            offload=args.offload,
+            limit_all_gathers=True,
+            use_orig_params=True,
+            fsdp_kwargs=fsdp_kwargs,
+        )
+
+        text_encoding_pipeline.text_encoder = text_encoder_fsdp
+        dist.barrier()
+
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
@@ -1536,6 +1559,8 @@ def _encode_single(prompt: str):
                 if train_dataset.custom_instance_prompts:
                     if args.remote_text_encoder:
                         prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
+                    elif args.fsdp_text_encoder:
+                        prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
                     else:
                         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
                             prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1836,15 +1861,42 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
     # Save the lora layers
    accelerator.wait_for_everyone()
+    is_fsdp = accelerator.state.fsdp_plugin is not None
+
+    if is_fsdp:
+        transformer = unwrap_model(transformer)
+        state_dict = accelerator.get_state_dict(transformer)
    if accelerator.is_main_process:
        modules_to_save = {}
-        transformer = unwrap_model(transformer)
-        if args.bnb_quantization_config_path is None:
-            if args.upcast_before_saving:
-                transformer.to(torch.float32)
-            else:
-                transformer = transformer.to(weight_dtype)
-        transformer_lora_layers = get_peft_model_state_dict(transformer)
+        if is_fsdp:
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    state_dict = {
+                        k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+                    }
+                else:
+                    state_dict = {
+                        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+                    }
+
+            transformer_lora_layers = get_peft_model_state_dict(
+                transformer,
+                state_dict=state_dict,
+            )
+            transformer_lora_layers = {
+                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+                for k, v in transformer_lora_layers.items()
+            }
+
+        else:
+            transformer = unwrap_model(transformer)
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    transformer.to(torch.float32)
+                else:
+                    transformer = transformer.to(weight_dtype)
+            transformer_lora_layers = get_peft_model_state_dict(transformer)
+
        modules_to_save["transformer"] = transformer
 
        Flux2Pipeline.save_lora_weights(
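
Note: a minimal sketch of the FSDP-aware LoRA export path in the hunk above, assuming the script's `accelerator`, `transformer`, `unwrap_model`, and `weight_dtype`. The full-state-dict gather is a collective call and must run on every rank, while the LoRA filtering and CPU copy happen only on the main process:

import torch
from peft.utils import get_peft_model_state_dict

transformer = unwrap_model(transformer)
# Gather a full (unsharded) state dict; under FSDP every rank must reach this line.
state_dict = accelerator.get_state_dict(transformer)

if accelerator.is_main_process:
    state_dict = {
        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
    }
    # Keep only the LoRA adapter weights from the gathered state dict.
    lora_layers = get_peft_model_state_dict(transformer, state_dict=state_dict)
    lora_layers = {
        k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
        for k, v in lora_layers.items()
    }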

examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
Lines changed: 59 additions & 8 deletions
@@ -46,6 +46,7 @@
 
 import numpy as np
 import torch
+import torch.distributed as dist
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -79,8 +80,10 @@
     compute_loss_weighting_for_sd3,
     find_nearest_bucket,
     free_memory,
+    get_fsdp_kwargs_from_accelerator,
     offload_models,
     parse_buckets_string,
+    wrap_with_fsdp,
 )
 from diffusers.utils import (
     check_min_version,
@@ -691,6 +694,7 @@ def parse_args(input_args=None):
 
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
+    parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1156,7 +1160,11 @@ def main(args):
         if args.bnb_quantization_config_path is not None
         else {"device": accelerator.device, "dtype": weight_dtype}
     )
-    transformer.to(**transformer_to_kwargs)
+
+    is_fsdp = accelerator.state.fsdp_plugin is not None
+    if not is_fsdp:
+        transformer.to(**transformer_to_kwargs)
+
     if args.do_fp8_training:
         convert_to_float8_training(
             transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1430,6 +1438,21 @@ def _encode_single(prompt: str):
             args.validation_prompt, text_encoding_pipeline
         )
 
+    # Init FSDP for the text encoder via the new training_utils helper.
+    if args.fsdp_text_encoder:
+        fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
+        text_encoder_fsdp = wrap_with_fsdp(
+            model=text_encoding_pipeline.text_encoder,
+            device=accelerator.device,
+            offload=args.offload,
+            limit_all_gathers=True,
+            use_orig_params=True,
+            fsdp_kwargs=fsdp_kwargs,
+        )
+
+        text_encoding_pipeline.text_encoder = text_encoder_fsdp
+        dist.barrier()
+
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
@@ -1461,6 +1484,8 @@ def _encode_single(prompt: str):
                 if train_dataset.custom_instance_prompts:
                     if args.remote_text_encoder:
                         prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
+                    elif args.fsdp_text_encoder:
+                        prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
                     else:
                         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
                             prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1759,15 +1784,41 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
     # Save the lora layers
    accelerator.wait_for_everyone()
+
+    if is_fsdp:
+        transformer = unwrap_model(transformer)
+        state_dict = accelerator.get_state_dict(transformer)
    if accelerator.is_main_process:
        modules_to_save = {}
-        transformer = unwrap_model(transformer)
-        if args.bnb_quantization_config_path is None:
-            if args.upcast_before_saving:
-                transformer.to(torch.float32)
-            else:
-                transformer = transformer.to(weight_dtype)
-        transformer_lora_layers = get_peft_model_state_dict(transformer)
+        if is_fsdp:
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    state_dict = {
+                        k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+                    }
+                else:
+                    state_dict = {
+                        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+                    }
+
+            transformer_lora_layers = get_peft_model_state_dict(
+                transformer,
+                state_dict=state_dict,
+            )
+            transformer_lora_layers = {
+                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+                for k, v in transformer_lora_layers.items()
+            }
+
+        else:
+            transformer = unwrap_model(transformer)
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    transformer.to(torch.float32)
+                else:
+                    transformer = transformer.to(weight_dtype)
+            transformer_lora_layers = get_peft_model_state_dict(transformer)
+
        modules_to_save["transformer"] = transformer
 
        Flux2Pipeline.save_lora_weights(
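
Note: both training scripts gate the new behavior on `accelerator.state.fsdp_plugin` being set, which normally happens through `accelerate config` plus `accelerate launch`. Below is a hedged sketch of setting the plugin programmatically instead, assuming a distributed launch and the FSDP1 code path in current Accelerate; the plugin and Accelerator kwargs shown are standard Accelerate APIs, but exact accepted values may vary by version:

from accelerate import Accelerator, FullyShardedDataParallelPlugin
from torch.distributed.fsdp import ShardingStrategy

# Must run under a distributed launcher (e.g. `accelerate launch` or torchrun),
# otherwise there is no process group to shard across.
fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.FULL_SHARD)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

assert accelerator.state.fsdp_plugin is not None  # this is what the scripts check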

src/diffusers/training_utils.py
Lines changed: 78 additions & 1 deletion
@@ -6,10 +6,15 @@
 import re
 import warnings
 from contextlib import contextmanager
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from functools import partial
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
 
 import numpy as np
 import torch
+import torch.distributed as dist
+from torch.distributed.fsdp import CPUOffload, ShardingStrategy
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
 from .models import UNet2DConditionModel
 from .pipelines import DiffusionPipeline
@@ -394,6 +399,78 @@ def find_nearest_bucket(h, w, bucket_options):
     return best_bucket_idx
 
 
+def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
+    """
+    Extract and convert the FSDP config from an Accelerator into PyTorch FSDP kwargs.
+    """
+
+    kwargs = {}
+    fsdp_plugin = accelerator.state.fsdp_plugin
+
+    if fsdp_plugin is None:
+        # FSDP is not enabled in the Accelerator; fall back to full sharding.
+        kwargs["sharding_strategy"] = ShardingStrategy.FULL_SHARD
+    else:
+        # FSDP is enabled: use the plugin's strategy, or the default if it is None.
+        kwargs["sharding_strategy"] = (
+            fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD
+        )
+
+    return kwargs
+
+
+def wrap_with_fsdp(
+    model: torch.nn.Module,
+    device: Union[str, torch.device],
+    offload: bool = True,
+    use_orig_params: bool = True,
+    limit_all_gathers: bool = True,
+    fsdp_kwargs: Optional[Dict[str, Any]] = None,
+    transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None,
+) -> FSDP:
+    """
+    Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.
+
+    Args:
+        model: Model to wrap.
+        device: Target device (e.g. `accelerator.device`).
+        offload: Whether to enable CPU parameter offloading.
+        use_orig_params: Whether to use original parameters.
+        limit_all_gathers: Whether to limit all-gathers.
+        fsdp_kwargs: Extra FSDP arguments (sharding_strategy, etc.), usually from the Accelerate config.
+        transformer_layer_cls: Set of layer classes for auto-wrapping (if not using a policy from `fsdp_kwargs`).
+
+    Returns:
+        The FSDP-wrapped model.
+    """
+
+    if transformer_layer_cls is None:
+        # Default to the model's decoder-layer class when none is provided.
+        transformer_layer_cls = {type(model.model.language_model.layers[0])}
+
+    # Auto-wrap every submodule that matches the given transformer layer classes.
+    auto_wrap_policy = partial(
+        transformer_auto_wrap_policy,
+        transformer_layer_cls=transformer_layer_cls,
+    )
+
+    config = {
+        "device_id": device,
+        "cpu_offload": CPUOffload(offload_params=offload) if offload else None,
+        "use_orig_params": use_orig_params,
+        "limit_all_gathers": limit_all_gathers,
+        "auto_wrap_policy": auto_wrap_policy,
+    }
+
+    if fsdp_kwargs:
+        config.update(fsdp_kwargs)
+
+    fsdp_model = FSDP(model, **config)
+    if dist.is_initialized():
+        dist.barrier()
+    return fsdp_model
+
+
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
