Commit 343b12a
Add FSDP option for Flux2
1 parent 55463f7 commit 343b12a

2 files changed: 174 additions & 26 deletions

examples/dreambooth/train_dreambooth_lora_flux2.py

Lines changed: 86 additions & 13 deletions
@@ -47,6 +47,7 @@
 
 import numpy as np
 import torch
+import torch.distributed as dist
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -722,6 +723,7 @@ def parse_args(input_args=None):
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
+    parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1427,12 +1429,22 @@ def load_model_hook(models, input_dir):
     )
 
     def compute_text_embeddings(prompt, text_encoding_pipeline):
-        with torch.no_grad():
-            prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
-                prompt=prompt,
-                max_sequence_length=args.max_sequence_length,
-                text_encoder_out_layers=args.text_encoder_out_layers,
-            )
+        if args.fsdp_text_encoder:
+            text_encoding_pipeline.text_encoder.eval()
+            with torch.no_grad():
+                prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+                    prompt=prompt,
+                    max_sequence_length=args.max_sequence_length,
+                    text_encoder_out_layers=args.text_encoder_out_layers,
+                    device=accelerator.device,
+                )
+        else:
+            with torch.no_grad():
+                prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+                    prompt=prompt,
+                    max_sequence_length=args.max_sequence_length,
+                    text_encoder_out_layers=args.text_encoder_out_layers,
+                )
         return prompt_embeds, text_ids
 
     def compute_remote_text_embeddings(prompts):
@@ -1507,6 +1519,38 @@ def _encode_single(prompt: str):
             args.validation_prompt, text_encoding_pipeline
         )
 
+    # Init FSDP for text encoder
+    if args.fsdp_text_encoder:
+        from functools import partial
+
+        from torch.distributed.fsdp import (
+            BackwardPrefetch,
+            CPUOffload,
+            ShardingStrategy,
+        )
+        from torch.distributed.fsdp import (
+            FullyShardedDataParallel as FSDP,
+        )
+        from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+
+        original_text_encoder = text_encoding_pipeline.text_encoder
+        transformer_layer = type(original_text_encoder.model.language_model.layers[0])
+        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={transformer_layer})
+
+        text_encoder_fsdp = FSDP(
+            original_text_encoder,
+            device_id=accelerator.device,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+            cpu_offload=CPUOffload(offload_params=args.offload),
+            auto_wrap_policy=auto_wrap_policy,
+            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+            limit_all_gathers=True,
+            use_orig_params=True,
+        )
+
+        text_encoding_pipeline.text_encoder = text_encoder_fsdp
+        dist.barrier()
+
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
@@ -1536,6 +1580,8 @@ def _encode_single(prompt: str):
                 if train_dataset.custom_instance_prompts:
                     if args.remote_text_encoder:
                         prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
+                    elif args.fsdp_text_encoder:
+                        prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
                     else:
                         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
                             prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1836,15 +1882,42 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
     # Save the lora layers
    accelerator.wait_for_everyone()
+    is_fsdp = accelerator.state.fsdp_plugin is not None
+
+    if is_fsdp:
+        transformer = unwrap_model(transformer)
+        state_dict = accelerator.get_state_dict(transformer)
     if accelerator.is_main_process:
         modules_to_save = {}
-        transformer = unwrap_model(transformer)
-        if args.bnb_quantization_config_path is None:
-            if args.upcast_before_saving:
-                transformer.to(torch.float32)
-            else:
-                transformer = transformer.to(weight_dtype)
-        transformer_lora_layers = get_peft_model_state_dict(transformer)
+        if is_fsdp:
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    state_dict = {
+                        k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+                    }
+                else:
+                    state_dict = {
+                        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+                    }
+
+            transformer_lora_layers = get_peft_model_state_dict(
+                transformer,
+                state_dict=state_dict,
+            )
+            transformer_lora_layers = {
+                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+                for k, v in transformer_lora_layers.items()
+            }
+
+        else:
+            transformer = unwrap_model(transformer)
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    transformer.to(torch.float32)
+                else:
+                    transformer = transformer.to(weight_dtype)
+            transformer_lora_layers = get_peft_model_state_dict(transformer)
+
         modules_to_save["transformer"] = transformer
 
         Flux2Pipeline.save_lora_weights(
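
The FSDP block added above shards only the frozen text encoder, wrapping each decoder layer (taken from text_encoder.model.language_model.layers) as its own FSDP unit through transformer_auto_wrap_policy. Below is a minimal standalone sketch of that same wrapping pattern under simplified assumptions: the TinyBlock and TinyEncoder modules and the single-process "nccl" group are illustrative stand-ins, not part of this commit.

# Minimal sketch of the FSDP wrapping pattern used above, with a toy model.
# Assumes one CUDA device; TinyBlock/TinyEncoder are hypothetical stand-ins
# for the text encoder's decoder layers.
import os
from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class TinyBlock(nn.Module):  # stands in for one decoder layer
    def __init__(self, dim=64):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.ff(x)


class TinyEncoder(nn.Module):  # stands in for the text encoder
    def __init__(self, dim=64, depth=4):
        super().__init__()
        self.layers = nn.ModuleList(TinyBlock(dim) for _ in range(depth))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def main():
    # Single-process group just so FSDP is constructible; real runs use accelerate/torchrun.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=0, world_size=1)
    device = torch.device("cuda", 0)

    encoder = TinyEncoder().to(device)
    # Wrap every TinyBlock as its own FSDP unit, mirroring transformer_layer_cls above.
    wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={TinyBlock})

    encoder = FSDP(
        encoder,
        device_id=device,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        cpu_offload=CPUOffload(offload_params=False),
        auto_wrap_policy=wrap_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        limit_all_gathers=True,
        use_orig_params=True,
    )

    encoder.eval()
    with torch.no_grad():  # inference only, as in compute_text_embeddings
        out = encoder(torch.randn(2, 8, 64, device=device))
    print(out.shape)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

With FULL_SHARD, each rank keeps only its shard of every wrapped layer and all-gathers the full weights on demand during the forward pass, which is why the script can keep the text encoder resident while the Flux2 transformer trains.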

examples/dreambooth/train_dreambooth_lora_flux2_img2img.py

Lines changed: 88 additions & 13 deletions
@@ -46,6 +46,7 @@
 
 import numpy as np
 import torch
+import torch.distributed as dist
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -691,6 +692,7 @@ def parse_args(input_args=None):
 
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
+    parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1361,12 +1363,21 @@ def load_model_hook(models, input_dir):
     )
 
     def compute_text_embeddings(prompt, text_encoding_pipeline):
-        with torch.no_grad():
-            prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
-                prompt=prompt, max_sequence_length=args.max_sequence_length
-            )
-            # prompt_embeds = prompt_embeds.to(accelerator.device)
-            # text_ids = text_ids.to(accelerator.device)
+        if args.fsdp_text_encoder:
+            text_encoding_pipeline.text_encoder.eval()
+            with torch.no_grad():
+                prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+                    prompt=prompt,
+                    max_sequence_length=args.max_sequence_length,
+                    device=accelerator.device,
+                )
+        else:
+            with torch.no_grad():
+                prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+                    prompt=prompt, max_sequence_length=args.max_sequence_length
+                )
+                # prompt_embeds = prompt_embeds.to(accelerator.device)
+                # text_ids = text_ids.to(accelerator.device)
         return prompt_embeds, text_ids
 
     def compute_remote_text_embeddings(prompts: str | list[str]):
@@ -1430,6 +1441,41 @@ def _encode_single(prompt: str):
             args.validation_prompt, text_encoding_pipeline
         )
 
+    # Init FSDP for text encoder
+    if args.fsdp_text_encoder:
+        from functools import partial
+
+        from torch.distributed.fsdp import (
+            BackwardPrefetch,
+            CPUOffload,
+            ShardingStrategy,
+        )
+        from torch.distributed.fsdp import (
+            FullyShardedDataParallel as FSDP,
+        )
+        from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+
+        original_text_encoder = text_encoding_pipeline.text_encoder
+        transformer_layer = type(original_text_encoder.model.language_model.layers[0])
+        auto_wrap_policy = partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={transformer_layer}
+        )
+
+        text_encoder_fsdp = FSDP(
+            original_text_encoder,
+            device_id=accelerator.device,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+            cpu_offload=CPUOffload(offload_params=args.offload),
+            auto_wrap_policy=auto_wrap_policy,
+            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+            limit_all_gathers=True,
+            use_orig_params=True,
+        )
+
+        text_encoding_pipeline.text_encoder = text_encoder_fsdp
+        dist.barrier()
+
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
@@ -1461,6 +1507,8 @@ def _encode_single(prompt: str):
                 if train_dataset.custom_instance_prompts:
                     if args.remote_text_encoder:
                         prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
+                    elif args.fsdp_text_encoder:
+                        prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
                     else:
                         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
                             prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1759,15 +1807,42 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
     # Save the lora layers
     accelerator.wait_for_everyone()
+    is_fsdp = accelerator.state.fsdp_plugin is not None
+
+    if is_fsdp:
+        transformer = unwrap_model(transformer)
+        state_dict = accelerator.get_state_dict(transformer)
     if accelerator.is_main_process:
         modules_to_save = {}
-        transformer = unwrap_model(transformer)
-        if args.bnb_quantization_config_path is None:
-            if args.upcast_before_saving:
-                transformer.to(torch.float32)
-            else:
-                transformer = transformer.to(weight_dtype)
-        transformer_lora_layers = get_peft_model_state_dict(transformer)
+        if is_fsdp:
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    state_dict = {
+                        k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+                    }
+                else:
+                    state_dict = {
+                        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+                    }
+
+            transformer_lora_layers = get_peft_model_state_dict(
+                transformer,
+                state_dict=state_dict,
+            )
+            transformer_lora_layers = {
+                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+                for k, v in transformer_lora_layers.items()
+            }
+
+        else:
+            transformer = unwrap_model(transformer)
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    transformer.to(torch.float32)
+                else:
+                    transformer = transformer.to(weight_dtype)
+            transformer_lora_layers = get_peft_model_state_dict(transformer)
+
         modules_to_save["transformer"] = transformer
 
         Flux2Pipeline.save_lora_weights(
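
In the FSDP save path above, accelerator.get_state_dict gathers the full transformer state dict on every rank, and the main process then casts the tensors, filters out the LoRA keys with get_peft_model_state_dict(transformer, state_dict=state_dict), and moves them to contiguous CPU tensors before saving. The sketch below reproduces that consolidate, cast, filter flow on a single process; TinyTransformer, its LoRA config, and the local state_dict() call (standing in for the distributed gather) are illustrative assumptions, not code from this commit.

# Single-process sketch of the consolidate -> cast -> filter LoRA save flow.
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from peft.utils import get_peft_model_state_dict


class TinyTransformer(nn.Module):  # hypothetical stand-in for the Flux2 transformer
    def __init__(self, dim=32):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)

    def forward(self, x):
        return self.to_q(x) + self.to_k(x)


model = get_peft_model(TinyTransformer(), LoraConfig(r=4, target_modules=["to_q", "to_k"]))

# Stand-in for accelerator.get_state_dict(transformer): under FSDP this returns the
# full, unsharded parameters; here it is simply the local state dict.
full_state = model.state_dict()

# Cast tensors before saving, mirroring the upcast_before_saving / weight_dtype branches.
save_dtype = torch.float32
full_state = {k: v.to(save_dtype) if isinstance(v, torch.Tensor) else v for k, v in full_state.items()}

# Filter down to the LoRA weights from the pre-gathered state dict, then detach to
# contiguous CPU tensors so the main process can serialize them safely.
lora_state = get_peft_model_state_dict(model, state_dict=full_state)
lora_state = {k: v.detach().cpu().contiguous() for k, v in lora_state.items()}
print(sorted(lora_state))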
