 
 import numpy as np
 import torch
+import torch.distributed as dist
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -691,6 +692,7 @@ def parse_args(input_args=None):
 
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
+    parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1361,12 +1363,21 @@ def load_model_hook(models, input_dir): |
1361 | 1363 | ) |
1362 | 1364 |
|
1363 | 1365 | def compute_text_embeddings(prompt, text_encoding_pipeline): |
1364 | | - with torch.no_grad(): |
1365 | | - prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt( |
1366 | | - prompt=prompt, max_sequence_length=args.max_sequence_length |
1367 | | - ) |
1368 | | - # prompt_embeds = prompt_embeds.to(accelerator.device) |
1369 | | - # text_ids = text_ids.to(accelerator.device) |
| 1366 | + if args.fsdp_text_encoder: |
| 1367 | + text_encoding_pipeline.text_encoder.eval() |
| 1368 | + with torch.no_grad(): |
| 1369 | + prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt( |
| 1370 | + prompt=prompt, |
| 1371 | + max_sequence_length=args.max_sequence_length, |
| 1372 | + device=accelerator.device, |
| 1373 | + ) |
| 1374 | + else: |
| 1375 | + with torch.no_grad(): |
| 1376 | + prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt( |
| 1377 | + prompt=prompt, max_sequence_length=args.max_sequence_length |
| 1378 | + ) |
| 1379 | + # prompt_embeds = prompt_embeds.to(accelerator.device) |
| 1380 | + # text_ids = text_ids.to(accelerator.device) |
1370 | 1381 | return prompt_embeds, text_ids |
1371 | 1382 |
|
1372 | 1383 | def compute_remote_text_embeddings(prompts: str | list[str]): |
@@ -1430,6 +1441,41 @@ def _encode_single(prompt: str):
                 args.validation_prompt, text_encoding_pipeline
             )
 
+    # Init FSDP for text encoder
+    if args.fsdp_text_encoder:
+        from functools import partial
+
+        from torch.distributed.fsdp import (
+            BackwardPrefetch,
+            CPUOffload,
+            ShardingStrategy,
+        )
+        from torch.distributed.fsdp import (
+            FullyShardedDataParallel as FSDP,
+        )
+        from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+
+        original_text_encoder = text_encoding_pipeline.text_encoder
+        transformer_layer = type(original_text_encoder.model.language_model.layers[0])
+        auto_wrap_policy = partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={transformer_layer}
+        )
+
+        text_encoder_fsdp = FSDP(
+            original_text_encoder,
+            device_id=accelerator.device,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+            cpu_offload=CPUOffload(offload_params=args.offload),
+            auto_wrap_policy=auto_wrap_policy,
+            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+            limit_all_gathers=True,
+            use_orig_params=True,
+        )
+
+        text_encoding_pipeline.text_encoder = text_encoder_fsdp
+        dist.barrier()
+
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
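
For readers unfamiliar with the FSDP API used above: the new block shards each transformer layer of the text encoder across ranks (FULL_SHARD), optionally pushes parameters to CPU between uses, and keeps the original parameter names so existing code paths still resolve. The sketch below shows the same wrapping recipe in isolation on a small Llama-style model; the model config, the wrap_text_encoder helper, and the process-group setup are illustrative assumptions, not part of this PR.

# Minimal FSDP-wrapping sketch (illustrative; assumes launch via torchrun with
# one CUDA device per rank -- not taken from this PR).
import os
from functools import partial

import torch
import torch.distributed as dist
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import LlamaConfig, LlamaForCausalLM


def wrap_text_encoder(model, offload_params=False):
    # Shard at the granularity of one decoder block, mirroring the PR's use of
    # type(text_encoder.model.language_model.layers[0]); the attribute path to
    # the layer list is model-specific.
    block_cls = type(model.model.layers[0])
    policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={block_cls})
    return FSDP(
        model,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        cpu_offload=CPUOffload(offload_params=offload_params),
        auto_wrap_policy=policy,
        limit_all_gathers=True,
        use_orig_params=True,
    )


if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # Tiny randomly initialized stand-in for the real text encoder.
    config = LlamaConfig(
        hidden_size=256, intermediate_size=512, num_hidden_layers=4,
        num_attention_heads=4, num_key_value_heads=4, vocab_size=1024,
    )
    encoder = wrap_text_encoder(LlamaForCausalLM(config)).eval()
    dist.barrier()
    if dist.get_rank() == 0:
        print(encoder)  # each decoder layer now appears wrapped in FullyShardedDataParallel

One consequence worth noting: under FULL_SHARD a forward pass all-gathers parameters across the process group, so every rank has to run the prompt encoding. That is presumably why the next hunk routes the --fsdp_text_encoder case straight to compute_text_embeddings on all ranks instead of going through offload_models.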
@@ -1461,6 +1507,8 @@ def _encode_single(prompt: str):
                 if train_dataset.custom_instance_prompts:
                     if args.remote_text_encoder:
                         prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
+                    elif args.fsdp_text_encoder:
+                        prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
                     else:
                         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
                             prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1759,15 +1807,42 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
     # Save the lora layers
     accelerator.wait_for_everyone()
+    is_fsdp = accelerator.state.fsdp_plugin is not None
+
+    if is_fsdp:
+        transformer = unwrap_model(transformer)
+        state_dict = accelerator.get_state_dict(transformer)
     if accelerator.is_main_process:
         modules_to_save = {}
-        transformer = unwrap_model(transformer)
-        if args.bnb_quantization_config_path is None:
-            if args.upcast_before_saving:
-                transformer.to(torch.float32)
-            else:
-                transformer = transformer.to(weight_dtype)
-        transformer_lora_layers = get_peft_model_state_dict(transformer)
+        if is_fsdp:
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    state_dict = {
+                        k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+                    }
+                else:
+                    state_dict = {
+                        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+                    }
+
+            transformer_lora_layers = get_peft_model_state_dict(
+                transformer,
+                state_dict=state_dict,
+            )
+            transformer_lora_layers = {
+                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+                for k, v in transformer_lora_layers.items()
+            }
+
+        else:
+            transformer = unwrap_model(transformer)
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    transformer.to(torch.float32)
+                else:
+                    transformer = transformer.to(weight_dtype)
+            transformer_lora_layers = get_peft_model_state_dict(transformer)
+
         modules_to_save["transformer"] = transformer
 
     Flux2Pipeline.save_lora_weights(
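
The saving changes follow from the same sharding: once the transformer is trained under FSDP, no single rank holds the full weights, so accelerator.get_state_dict(transformer) is called on every rank (it is a collective) before the is_main_process guard, and the gathered dict is handed to get_peft_model_state_dict(..., state_dict=...) to pull out just the LoRA weights. The sketch below outlines a roughly equivalent flow written against the raw FSDP state-dict API; the helper name and config values are assumptions for illustration, and the script itself goes through Accelerate's wrapper instead.

# Rough equivalent of the FSDP saving path above, using the raw
# torch.distributed.fsdp API (illustrative; not what the script literally does).
import torch
import torch.distributed as dist
from peft.utils import get_peft_model_state_dict
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)


def gather_lora_state_dict(fsdp_transformer, weight_dtype=torch.bfloat16):
    # Every rank must enter this block: assembling the full state dict is a
    # collective all-gather of the shards. With rank0_only=True, only rank 0
    # ends up holding the tensors (already offloaded to CPU).
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_transformer, StateDictType.FULL_STATE_DICT, cfg):
        full_state_dict = fsdp_transformer.state_dict()

    if dist.get_rank() != 0:
        return None

    # Cast only the gathered copy (the live sharded parameters are untouched),
    # then let PEFT filter it down to the LoRA adapter weights. The inner
    # .module carries the adapter config, playing the role of unwrap_model here.
    full_state_dict = {
        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v
        for k, v in full_state_dict.items()
    }
    lora_state_dict = get_peft_model_state_dict(fsdp_transformer.module, state_dict=full_state_dict)
    return {k: v.detach().cpu().contiguous() for k, v in lora_state_dict.items()}

The final .detach().cpu().contiguous() pass mirrors the PR and ensures that what gets serialized are plain CPU tensors rather than views of the gathered buffers.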