 
 import numpy as np
 import torch
+import torch.distributed as dist
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -722,6 +723,7 @@ def parse_args(input_args=None):
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
+    parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1427,12 +1429,22 @@ def load_model_hook(models, input_dir):
     )
 
     def compute_text_embeddings(prompt, text_encoding_pipeline):
-        with torch.no_grad():
-            prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
-                prompt=prompt,
-                max_sequence_length=args.max_sequence_length,
-                text_encoder_out_layers=args.text_encoder_out_layers,
-            )
+        if args.fsdp_text_encoder:
+            text_encoding_pipeline.text_encoder.eval()
+            with torch.no_grad():
+                prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+                    prompt=prompt,
+                    max_sequence_length=args.max_sequence_length,
+                    text_encoder_out_layers=args.text_encoder_out_layers,
+                    device=accelerator.device,
+                )
+        else:
+            with torch.no_grad():
+                prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+                    prompt=prompt,
+                    max_sequence_length=args.max_sequence_length,
+                    text_encoder_out_layers=args.text_encoder_out_layers,
+                )
         return prompt_embeds, text_ids
 
     def compute_remote_text_embeddings(prompts):
@@ -1507,6 +1519,39 @@ def _encode_single(prompt: str):
             args.validation_prompt, text_encoding_pipeline
         )
 
+    # Init FSDP for text encoder
+    if args.fsdp_text_encoder:
+        import torch.distributed as dist
+        from torch.distributed.fsdp import (
+            FullyShardedDataParallel as FSDP,
+            CPUOffload,
+            ShardingStrategy,
+            BackwardPrefetch,
+        )
+        from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+        from functools import partial
+
+        original_text_encoder = text_encoding_pipeline.text_encoder
+        transformer_layer = type(original_text_encoder.model.language_model.layers[0])
+        auto_wrap_policy = partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={transformer_layer}
+        )
+
+        text_encoder_fsdp = FSDP(
+            original_text_encoder,
+            device_id=accelerator.device,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+            cpu_offload=CPUOffload(offload_params=args.offload),
+            auto_wrap_policy=auto_wrap_policy,
+            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+            limit_all_gathers=True,
+            use_orig_params=True,
+        )
+
+        text_encoding_pipeline.text_encoder = text_encoder_fsdp
+        dist.barrier()
+
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
    # have to pass them to the dataloader.
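For reference, the wrapping pattern introduced in the hunk above can be sketched in isolation. The helper below is hypothetical (it is not part of this patch): it assumes a process group has already been initialized via `torch.distributed.init_process_group`, takes the block class to shard on as an argument, and otherwise mirrors the `FSDP(...)` call added above.

```python
from functools import partial

import torch
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def wrap_with_fsdp(module: torch.nn.Module, block_cls: type, device, offload_params: bool = False):
    """Hypothetical helper mirroring the patch: shard `module` at the granularity of `block_cls`.

    Requires torch.distributed.init_process_group(...) to have been called already.
    """
    # Wrap each instance of `block_cls` (e.g. a text-encoder transformer block) in its own FSDP unit.
    auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={block_cls})
    return FSDP(
        module,
        device_id=device,
        sharding_strategy=ShardingStrategy.FULL_SHARD,    # shard parameters, gradients, and optimizer state across ranks
        cpu_offload=CPUOffload(offload_params=offload_params),
        auto_wrap_policy=auto_wrap_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # prefetch the next unit's parameters during backward
        limit_all_gathers=True,
        use_orig_params=True,                             # keep the original parameter objects visible to outside code
    )
```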
@@ -1536,6 +1581,8 @@ def _encode_single(prompt: str):
                 if train_dataset.custom_instance_prompts:
                     if args.remote_text_encoder:
                         prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
+                    elif args.fsdp_text_encoder:
+                        prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
                     else:
                         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
                             prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1836,15 +1883,40 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
     # Save the lora layers
     accelerator.wait_for_everyone()
+    is_fsdp = accelerator.state.fsdp_plugin is not None
+
+    if is_fsdp:
+        transformer = unwrap_model(transformer)
+        state_dict = accelerator.get_state_dict(transformer)
     if accelerator.is_main_process:
         modules_to_save = {}
-        transformer = unwrap_model(transformer)
-        if args.bnb_quantization_config_path is None:
-            if args.upcast_before_saving:
-                transformer.to(torch.float32)
-            else:
-                transformer = transformer.to(weight_dtype)
-        transformer_lora_layers = get_peft_model_state_dict(transformer)
+        if is_fsdp:
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    state_dict = {k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v
+                                  for k, v in state_dict.items()}
+                else:
+                    state_dict = {k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v
+                                  for k, v in state_dict.items()}
+
+            transformer_lora_layers = get_peft_model_state_dict(
+                transformer,
+                state_dict=state_dict,
+            )
+            transformer_lora_layers = {
+                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+                for k, v in transformer_lora_layers.items()
+            }
+
+        else:
+            transformer = unwrap_model(transformer)
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    transformer.to(torch.float32)
+                else:
+                    transformer = transformer.to(weight_dtype)
+            transformer_lora_layers = get_peft_model_state_dict(transformer)
+
         modules_to_save["transformer"] = transformer
 
         Flux2Pipeline.save_lora_weights(
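The saving hunk relies on two library calls that do the heavy lifting when FSDP is active: `accelerator.get_state_dict(...)`, which gathers a full, unsharded state dict for the main process, and PEFT's `get_peft_model_state_dict(..., state_dict=...)`, which filters the LoRA weights out of an already-materialized state dict instead of reading the (still sharded) module parameters. A minimal, hypothetical sketch of that flow (the `extract_lora_state_dict` helper and its arguments are stand-ins, not part of the patch):

```python
import torch
from peft.utils import get_peft_model_state_dict


def extract_lora_state_dict(accelerator, transformer, weight_dtype=torch.bfloat16):
    """Hypothetical sketch of the FSDP-aware LoRA extraction added above."""
    # Gather the full state dict; under FSDP this all-gathers the shards.
    state_dict = accelerator.get_state_dict(transformer)
    if not accelerator.is_main_process:
        return None
    # Cast the gathered tensors to the dtype used for saving.
    state_dict = {
        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v
        for k, v in state_dict.items()
    }
    # Keep only the LoRA adapter weights, reading from the gathered dict
    # rather than from the sharded module parameters.
    lora_layers = get_peft_model_state_dict(transformer, state_dict=state_dict)
    return {
        k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
        for k, v in lora_layers.items()
    }
```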