Commit 559a7a3

Add FSDP option for Flux2
1 parent 55463f7 commit 559a7a3

2 files changed: +168 -26 lines changed

examples/dreambooth/train_dreambooth_lora_flux2.py

Lines changed: 85 additions & 13 deletions
@@ -47,6 +47,7 @@
 
 import numpy as np
 import torch
+import torch.distributed as dist
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -722,6 +723,7 @@ def parse_args(input_args=None):
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
+    parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1427,12 +1429,22 @@ def load_model_hook(models, input_dir):
     )
 
     def compute_text_embeddings(prompt, text_encoding_pipeline):
-        with torch.no_grad():
-            prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
-                prompt=prompt,
-                max_sequence_length=args.max_sequence_length,
-                text_encoder_out_layers=args.text_encoder_out_layers,
-            )
+        if args.fsdp_text_encoder:
+            text_encoding_pipeline.text_encoder.eval()
+            with torch.no_grad():
+                prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+                    prompt=prompt,
+                    max_sequence_length=args.max_sequence_length,
+                    text_encoder_out_layers=args.text_encoder_out_layers,
+                    device=accelerator.device,
+                )
+        else:
+            with torch.no_grad():
+                prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+                    prompt=prompt,
+                    max_sequence_length=args.max_sequence_length,
+                    text_encoder_out_layers=args.text_encoder_out_layers,
+                )
         return prompt_embeds, text_ids
 
     def compute_remote_text_embeddings(prompts):
@@ -1507,6 +1519,39 @@ def _encode_single(prompt: str):
             args.validation_prompt, text_encoding_pipeline
         )
 
+    # Init FSDP for text encoder
+    if args.fsdp_text_encoder:
+        import torch.distributed as dist
+        from torch.distributed.fsdp import (
+            FullyShardedDataParallel as FSDP,
+            CPUOffload,
+            ShardingStrategy,
+            BackwardPrefetch,
+        )
+        from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+        from functools import partial
+
+        original_text_encoder = text_encoding_pipeline.text_encoder
+        transformer_layer = type(original_text_encoder.model.language_model.layers[0])
+        auto_wrap_policy = partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={transformer_layer}
+        )
+
+        text_encoder_fsdp = FSDP(
+            original_text_encoder,
+            device_id=accelerator.device,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+            cpu_offload=CPUOffload(offload_params=args.offload),
+            auto_wrap_policy=auto_wrap_policy,
+            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+            limit_all_gathers=True,
+            use_orig_params=True,
+        )
+
+        text_encoding_pipeline.text_encoder = text_encoder_fsdp
+        dist.barrier()
+
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
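
Note: the hunk above wraps only the text-encoding pipeline's text encoder in FSDP, sharding at the level of its decoder-layer class via transformer_auto_wrap_policy. For reference, a minimal, self-contained sketch of that wrapping pattern on a toy module follows; ToyLayer, ToyEncoder, and the torchrun launch line are illustrative stand-ins, not part of this commit, and the NCCL backend assumes one GPU per rank.

# Minimal FSDP wrapping sketch mirroring the hunk above; ToyLayer / ToyEncoder are
# illustrative stand-ins for the text encoder and its decoder-layer class.
# Launch under a distributed runner, e.g.: torchrun --nproc_per_node=2 fsdp_sketch.py
from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ToyLayer(nn.Module):
    """Stand-in for one decoder layer of the text encoder's language model."""

    def __init__(self, dim=256):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.ff(x)


class ToyEncoder(nn.Module):
    """Stand-in for the text encoder's stack of decoder layers."""

    def __init__(self, dim=256, depth=4):
        super().__init__()
        self.layers = nn.ModuleList(ToyLayer(dim) for _ in range(depth))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def main():
    dist.init_process_group("nccl")  # assumes a multi-GPU host launched via torchrun
    device = torch.device("cuda", dist.get_rank() % torch.cuda.device_count())
    torch.cuda.set_device(device)

    encoder = ToyEncoder()
    # Shard at the decoder-layer class, as the hunk does with
    # type(original_text_encoder.model.language_model.layers[0]).
    auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={ToyLayer})

    encoder = FSDP(
        encoder,
        device_id=device,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        cpu_offload=CPUOffload(offload_params=False),
        auto_wrap_policy=auto_wrap_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        limit_all_gathers=True,
        use_orig_params=True,
    )

    # Inference-only use, matching the eval() + no_grad() path in compute_text_embeddings.
    encoder.eval()
    with torch.no_grad():
        out = encoder(torch.randn(2, 16, 256, device=device))
    if dist.get_rank() == 0:
        print(out.shape)

    dist.barrier()  # keep ranks in sync after wrapping, as the hunk does
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
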
@@ -1536,6 +1581,8 @@ def _encode_single(prompt: str):
                 if train_dataset.custom_instance_prompts:
                     if args.remote_text_encoder:
                         prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
+                    elif args.fsdp_text_encoder:
+                        prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
                     else:
                         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
                             prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1836,15 +1883,40 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
     # Save the lora layers
     accelerator.wait_for_everyone()
+    is_fsdp = accelerator.state.fsdp_plugin is not None
+
+    if is_fsdp:
+        transformer = unwrap_model(transformer)
+        state_dict = accelerator.get_state_dict(transformer)
     if accelerator.is_main_process:
         modules_to_save = {}
-        transformer = unwrap_model(transformer)
-        if args.bnb_quantization_config_path is None:
-            if args.upcast_before_saving:
-                transformer.to(torch.float32)
-            else:
-                transformer = transformer.to(weight_dtype)
-        transformer_lora_layers = get_peft_model_state_dict(transformer)
+        if is_fsdp:
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    state_dict = {k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v
+                                  for k, v in state_dict.items()}
+                else:
+                    state_dict = {k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v
+                                  for k, v in state_dict.items()}
+
+            transformer_lora_layers = get_peft_model_state_dict(
+                transformer,
+                state_dict=state_dict,
+            )
+            transformer_lora_layers = {
+                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+                for k, v in transformer_lora_layers.items()
+            }
+
+        else:
+            transformer = unwrap_model(transformer)
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    transformer.to(torch.float32)
+                else:
+                    transformer = transformer.to(weight_dtype)
+            transformer_lora_layers = get_peft_model_state_dict(transformer)
+
         modules_to_save["transformer"] = transformer
 
         Flux2Pipeline.save_lora_weights(
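
Note: with FSDP the LoRA weights cannot simply be read off the sharded module, so the save path above first gathers a full state dict with accelerator.get_state_dict and then filters it through get_peft_model_state_dict(..., state_dict=...). Below is a minimal sketch of that extraction on a tiny stand-in model; TinyBlock and its to_q / to_k layers are hypothetical, and in the script the gathered dict comes from the FSDP-wrapped Flux2 transformer before the result is handed to Flux2Pipeline.save_lora_weights.

# Sketch of the LoRA extraction used in the FSDP save branch above, on a tiny
# stand-in model (TinyBlock is hypothetical).
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from peft.utils import get_peft_model_state_dict


class TinyBlock(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)

    def forward(self, x):
        return self.to_q(x) + self.to_k(x)


model = get_peft_model(TinyBlock(), LoraConfig(r=4, lora_alpha=8, target_modules=["to_q", "to_k"]))

# Stand-in for the full (unsharded) state dict gathered by accelerator.get_state_dict().
full_state_dict = model.state_dict()

# Optional dtype cast of the gathered tensors, mirroring the --upcast_before_saving branch.
full_state_dict = {
    k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in full_state_dict.items()
}

# Filter just the LoRA adapter weights out of the full dict, then detach them to
# contiguous CPU tensors so the main process can serialize them.
transformer_lora_layers = get_peft_model_state_dict(model, state_dict=full_state_dict)
transformer_lora_layers = {
    k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
    for k, v in transformer_lora_layers.items()
}
print(sorted(transformer_lora_layers))  # only lora_A / lora_B parameters remain
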

examples/dreambooth/train_dreambooth_lora_flux2_img2img.py

Lines changed: 83 additions & 13 deletions
@@ -46,6 +46,7 @@
 
 import numpy as np
 import torch
+import torch.distributed as dist
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -691,6 +692,7 @@ def parse_args(input_args=None):
 
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
+    parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1361,12 +1363,21 @@ def load_model_hook(models, input_dir):
     )
 
     def compute_text_embeddings(prompt, text_encoding_pipeline):
-        with torch.no_grad():
-            prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
-                prompt=prompt, max_sequence_length=args.max_sequence_length
-            )
-        # prompt_embeds = prompt_embeds.to(accelerator.device)
-        # text_ids = text_ids.to(accelerator.device)
+        if args.fsdp_text_encoder:
+            text_encoding_pipeline.text_encoder.eval()
+            with torch.no_grad():
+                prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+                    prompt=prompt,
+                    max_sequence_length=args.max_sequence_length,
+                    device=accelerator.device,
+                )
+        else:
+            with torch.no_grad():
+                prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+                    prompt=prompt, max_sequence_length=args.max_sequence_length
+                )
+            # prompt_embeds = prompt_embeds.to(accelerator.device)
+            # text_ids = text_ids.to(accelerator.device)
         return prompt_embeds, text_ids
 
     def compute_remote_text_embeddings(prompts: str | list[str]):
@@ -1430,6 +1441,38 @@ def _encode_single(prompt: str):
             args.validation_prompt, text_encoding_pipeline
         )
 
+    # Init FSDP for text encoder
+    if args.fsdp_text_encoder:
+        from torch.distributed.fsdp import (
+            FullyShardedDataParallel as FSDP,
+            CPUOffload,
+            ShardingStrategy,
+            BackwardPrefetch,
+        )
+        from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+        from functools import partial
+
+        original_text_encoder = text_encoding_pipeline.text_encoder
+        transformer_layer = type(original_text_encoder.model.language_model.layers[0])
+        auto_wrap_policy = partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={transformer_layer}
+        )
+
+        text_encoder_fsdp = FSDP(
+            original_text_encoder,
+            device_id=accelerator.device,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+            cpu_offload=CPUOffload(offload_params=args.offload),
+            auto_wrap_policy=auto_wrap_policy,
+            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+            limit_all_gathers=True,
+            use_orig_params=True,
+        )
+
+        text_encoding_pipeline.text_encoder = text_encoder_fsdp
+        dist.barrier()
+
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
@@ -1461,6 +1504,8 @@ def _encode_single(prompt: str):
                 if train_dataset.custom_instance_prompts:
                     if args.remote_text_encoder:
                         prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
+                    elif args.fsdp_text_encoder:
+                        prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
                     else:
                         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
                             prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1759,15 +1804,40 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
     # Save the lora layers
     accelerator.wait_for_everyone()
+    is_fsdp = accelerator.state.fsdp_plugin is not None
+
+    if is_fsdp:
+        transformer = unwrap_model(transformer)
+        state_dict = accelerator.get_state_dict(transformer)
     if accelerator.is_main_process:
         modules_to_save = {}
-        transformer = unwrap_model(transformer)
-        if args.bnb_quantization_config_path is None:
-            if args.upcast_before_saving:
-                transformer.to(torch.float32)
-            else:
-                transformer = transformer.to(weight_dtype)
-        transformer_lora_layers = get_peft_model_state_dict(transformer)
+        if is_fsdp:
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    state_dict = {k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v
+                                  for k, v in state_dict.items()}
+                else:
+                    state_dict = {k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v
+                                  for k, v in state_dict.items()}
+
+            transformer_lora_layers = get_peft_model_state_dict(
+                transformer,
+                state_dict=state_dict,
+            )
+            transformer_lora_layers = {
+                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+                for k, v in transformer_lora_layers.items()
+            }
+
+        else:
+            transformer = unwrap_model(transformer)
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    transformer.to(torch.float32)
+                else:
+                    transformer = transformer.to(weight_dtype)
+            transformer_lora_layers = get_peft_model_state_dict(transformer)
+
         modules_to_save["transformer"] = transformer
 
         Flux2Pipeline.save_lora_weights(
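
Note: both scripts detect FSDP the same way, checking accelerator.state.fsdp_plugin once after wait_for_everyone and gathering with accelerator.get_state_dict(transformer). A tiny runnable sketch of just that step on a throwaway nn.Linear follows; the getattr guard is an addition here so the snippet also runs under a plain, non-FSDP Accelerator, whereas the scripts read the attribute directly.

# Sketch of the FSDP detection + state-dict gathering step shared by both save paths.
# Runs as-is under `python` or `accelerate launch`; with an FSDP config,
# get_state_dict() gathers the sharded parameters into full tensors.
import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(nn.Linear(8, 8))

# The scripts use `accelerator.state.fsdp_plugin is not None`; getattr keeps this
# sketch runnable on setups where the attribute is only set under FSDP.
is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None
state_dict = accelerator.get_state_dict(model)

if accelerator.is_main_process:
    print(f"fsdp={is_fsdp}, keys={sorted(state_dict)}")
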
