
Commit b530968

Cosmos Predict2.5 Base: inference pipeline, scheduler & chkpt conversion (#12852)
* cosmos predict2.5 base: convert chkpt & pipeline
  - New scheduler: scheduling_flow_unipc_multistep.py
  - Changes to TransformerCosmos for text embeddings via crossattn_proj
* scheduler cleanup
* simplify inference pipeline
* cleanup scheduler + tests
* Basic tests for flow unipc
* working b2b inference
* Rename everything
* Tests for pipeline present, but not working (predict2 also not working)
* docstring update
* wrapper pipelines + make style
* remove unnecessary files
* UniPCMultistep: support use_karras_sigmas=True and use_flow_sigmas=True
* use UniPCMultistepScheduler + fix tests for pipeline
* Remove FlowUniPCMultistepScheduler
* UniPCMultistepScheduler for use_flow_sigmas=True & use_karras_sigmas=True
* num_inference_steps=36 due to bug in scheduler used by predict2.5
* Address comments
* make style + make fix-copies
* fix tests + remove references to old pipelines
* address comments
* add revision in from_pretrained call
* fix tests
1 parent 55463f7 commit b530968

File tree: 12 files changed (+1398, -14 lines changed)

docs/source/en/api/pipelines/cosmos.md

Lines changed: 6 additions & 0 deletions
@@ -70,6 +70,12 @@ output.save("output.png")
   - all
   - __call__
 
+## Cosmos2_5_PredictBasePipeline
+
+[[autodoc]] Cosmos2_5_PredictBasePipeline
+  - all
+  - __call__
+
 ## CosmosPipelineOutput
 
 [[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
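
For orientation, a minimal usage sketch of the newly documented pipeline. It loads the pipeline folder produced by the conversion script below; the prompt, call arguments, and video export are assumptions based on standard diffusers conventions, not the documented API.

```python
import torch
from diffusers import Cosmos2_5_PredictBasePipeline
from diffusers.utils import export_to_video

# Load the pipeline folder written by scripts/convert_cosmos_to_diffusers.py
# (see the conversion commands further down). Path, prompt, and output handling
# are illustrative assumptions.
pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
    "converted/cosmos-p2.5-base-2b", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# The commit message notes num_inference_steps=36 works around a scheduler issue.
output = pipe(prompt="A robot arm assembles a wooden toy.", num_inference_steps=36)
export_to_video(output.frames[0], "output.mp4", fps=16)
```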

scripts/convert_cosmos_to_diffusers.py

Lines changed: 127 additions & 8 deletions
@@ -1,11 +1,55 @@
+"""
+# Cosmos 2 Predict
+
+Download checkpoint
+```bash
+hf download nvidia/Cosmos-Predict2-2B-Text2Image
+```
+
+convert checkpoint
+```bash
+transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt
+
+python scripts/convert_cosmos_to_diffusers.py \
+  --transformer_ckpt_path $transformer_ckpt_path \
+  --transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \
+  --text_encoder_path google-t5/t5-11b \
+  --tokenizer_path google-t5/t5-11b \
+  --vae_type wan2.1 \
+  --output_path converted/cosmos-p2-t2i-2b \
+  --save_pipeline
+```
+
+# Cosmos 2.5 Predict
+
+Download checkpoint
+```bash
+hf download nvidia/Cosmos-Predict2.5-2B
+```
+
+Convert checkpoint
+```bash
+transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt
+
+python scripts/convert_cosmos_to_diffusers.py \
+  --transformer_type Cosmos-2.5-Predict-Base-2B \
+  --transformer_ckpt_path $transformer_ckpt_path \
+  --vae_type wan2.1 \
+  --output_path converted/cosmos-p2.5-base-2b \
+  --save_pipeline
+```
+
+"""
+
 import argparse
 import pathlib
+import sys
 from typing import Any, Dict
 
 import torch
 from accelerate import init_empty_weights
 from huggingface_hub import snapshot_download
-from transformers import T5EncoderModel, T5TokenizerFast
+from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast
 
 from diffusers import (
     AutoencoderKLCosmos,
@@ -17,7 +61,9 @@
     CosmosVideoToWorldPipeline,
     EDMEulerScheduler,
     FlowMatchEulerDiscreteScheduler,
+    UniPCMultistepScheduler,
 )
+from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline
 
 
 def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -233,6 +279,25 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
         "concat_padding_mask": True,
         "extra_pos_embed_type": None,
     },
+    "Cosmos-2.5-Predict-Base-2B": {
+        "in_channels": 16 + 1,
+        "out_channels": 16,
+        "num_attention_heads": 16,
+        "attention_head_dim": 128,
+        "num_layers": 28,
+        "mlp_ratio": 4.0,
+        "text_embed_dim": 1024,
+        "adaln_lora_dim": 256,
+        "max_size": (128, 240, 240),
+        "patch_size": (1, 2, 2),
+        "rope_scale": (1.0, 3.0, 3.0),
+        "concat_padding_mask": True,
+        # NOTE: source config has pos_emb_learnable: 'True' - but params are missing
+        "extra_pos_embed_type": None,
+        "use_crossattn_projection": True,
+        "crossattn_proj_in_channels": 100352,
+        "encoder_hidden_states_channels": 1024,
+    },
 }
 
 VAE_KEYS_RENAME_DICT = {
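
For context, the new "Cosmos-2.5-Predict-Base-2B" entry above is presumably unpacked into the transformer constructor when the empty model is built; a hedged sketch of that step (the exact construction code is not part of this hunk, and the keyword arguments rely on the `transformer_cosmos.py` changes in this same commit):

```python
from accelerate import init_empty_weights
from diffusers import CosmosTransformer3DModel

# Assumption: the config dict is passed as keyword arguments, since its keys
# match CosmosTransformer3DModel.__init__ after this PR.
config = {
    "in_channels": 16 + 1,
    "out_channels": 16,
    "num_attention_heads": 16,
    "attention_head_dim": 128,
    "num_layers": 28,
    "mlp_ratio": 4.0,
    "text_embed_dim": 1024,
    "adaln_lora_dim": 256,
    "max_size": (128, 240, 240),
    "patch_size": (1, 2, 2),
    "rope_scale": (1.0, 3.0, 3.0),
    "concat_padding_mask": True,
    "extra_pos_embed_type": None,
    "use_crossattn_projection": True,
    "crossattn_proj_in_channels": 100352,
    "encoder_hidden_states_channels": 1024,
}

with init_empty_weights():
    # Weights are materialized later, when the converted checkpoint is loaded.
    transformer = CosmosTransformer3DModel(**config)
```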
@@ -334,6 +399,9 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool
     elif "Cosmos-2.0" in transformer_type:
         TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
         TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
+    elif "Cosmos-2.5" in transformer_type:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
     else:
         assert False
 
@@ -347,6 +415,7 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool
         new_key = new_key.removeprefix(PREFIX_KEY)
         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
+        print(key, "->", new_key, flush=True)
         update_state_dict_(original_state_dict, key, new_key)
 
     for key in list(original_state_dict.keys()):
@@ -355,6 +424,21 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool
                 continue
             handler_fn_inplace(key, original_state_dict)
 
+    expected_keys = set(transformer.state_dict().keys())
+    mapped_keys = set(original_state_dict.keys())
+    missing_keys = expected_keys - mapped_keys
+    unexpected_keys = mapped_keys - expected_keys
+    if missing_keys:
+        print(f"ERROR: missing keys ({len(missing_keys)} from state_dict:", flush=True, file=sys.stderr)
+        for k in missing_keys:
+            print(k)
+        sys.exit(1)
+    if unexpected_keys:
+        print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr)
+        for k in unexpected_keys:
+            print(k)
+        sys.exit(2)
+
     transformer.load_state_dict(original_state_dict, strict=True, assign=True)
     return transformer
 
@@ -444,17 +528,45 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
     pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
 
 
+def save_pipeline_cosmos2_5(args, transformer, vae):
+    text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
+    tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
+
+    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+        text_encoder_path, torch_dtype="auto", device_map="cpu"
+    )
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+
+    scheduler = UniPCMultistepScheduler(
+        use_karras_sigmas=True,
+        use_flow_sigmas=True,
+        prediction_type="flow_prediction",
+        sigma_max=200.0,
+        sigma_min=0.01,
+    )
+
+    pipe = Cosmos2_5_PredictBasePipeline(
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+        vae=vae,
+        scheduler=scheduler,
+        safety_checker=lambda *args, **kwargs: None,
+    )
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
     parser.add_argument(
         "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
     )
     parser.add_argument(
-        "--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
+        "--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE"
     )
-    parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
-    parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
+    parser.add_argument("--text_encoder_path", type=str, default=None)
+    parser.add_argument("--tokenizer_path", type=str, default=None)
    parser.add_argument("--save_pipeline", action="store_true")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
     parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
@@ -477,8 +589,6 @@ def get_args():
     if args.save_pipeline:
         assert args.transformer_ckpt_path is not None
         assert args.vae_type is not None
-        assert args.text_encoder_path is not None
-        assert args.tokenizer_path is not None
 
     if args.transformer_ckpt_path is not None:
         weights_only = "Cosmos-1.0" in args.transformer_type
@@ -490,17 +600,26 @@
     if args.vae_type is not None:
         if "Cosmos-1.0" in args.transformer_type:
             vae = convert_vae(args.vae_type)
-        else:
+        elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type:
             vae = AutoencoderKLWan.from_pretrained(
                 "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
             )
+        else:
+            raise AssertionError(f"{args.transformer_type} not supported")
+
         if not args.save_pipeline:
             vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
 
     if args.save_pipeline:
         if "Cosmos-1.0" in args.transformer_type:
+            assert args.text_encoder_path is not None
+            assert args.tokenizer_path is not None
             save_pipeline_cosmos_1_0(args, transformer, vae)
         elif "Cosmos-2.0" in args.transformer_type:
+            assert args.text_encoder_path is not None
+            assert args.tokenizer_path is not None
             save_pipeline_cosmos_2_0(args, transformer, vae)
+        elif "Cosmos-2.5" in args.transformer_type:
+            save_pipeline_cosmos2_5(args, transformer, vae)
         else:
-            assert False
+            raise AssertionError(f"{args.transformer_type} not supported")

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -463,6 +463,7 @@
         "CogView4ControlPipeline",
         "CogView4Pipeline",
         "ConsisIDPipeline",
+        "Cosmos2_5_PredictBasePipeline",
         "Cosmos2TextToImagePipeline",
         "Cosmos2VideoToWorldPipeline",
         "CosmosTextToWorldPipeline",
@@ -1175,6 +1176,7 @@
         CogView4ControlPipeline,
         CogView4Pipeline,
         ConsisIDPipeline,
+        Cosmos2_5_PredictBasePipeline,
         Cosmos2TextToImagePipeline,
         Cosmos2VideoToWorldPipeline,
         CosmosTextToWorldPipeline,

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 13 additions & 0 deletions
@@ -439,6 +439,9 @@ def __init__(
         rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
         concat_padding_mask: bool = True,
         extra_pos_embed_type: Optional[str] = "learnable",
+        use_crossattn_projection: bool = False,
+        crossattn_proj_in_channels: int = 1024,
+        encoder_hidden_states_channels: int = 1024,
     ) -> None:
         super().__init__()
         hidden_size = num_attention_heads * attention_head_dim
@@ -485,6 +488,12 @@
             hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False
         )
 
+        if self.config.use_crossattn_projection:
+            self.crossattn_proj = nn.Sequential(
+                nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
+                nn.GELU(),
+            )
+
         self.gradient_checkpointing = False
 
     def forward(
@@ -524,6 +533,7 @@ def forward(
         post_patch_num_frames = num_frames // p_t
         post_patch_height = height // p_h
         post_patch_width = width // p_w
+
         hidden_states = self.patch_embed(hidden_states)
         hidden_states = hidden_states.flatten(1, 3)  # [B, T, H, W, C] -> [B, THW, C]
 
@@ -546,6 +556,9 @@
         else:
             assert False
 
+        if self.config.use_crossattn_projection:
+            encoder_hidden_states = self.crossattn_proj(encoder_hidden_states)
+
         # 5. Transformer blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
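
As a quick illustration of the new crossattn_proj path, a self-contained sketch of the projection applied to the text-encoder features. The batch and sequence sizes are made up; the channel sizes (100352 → 1024) come from the Cosmos-2.5-Predict-Base-2B config earlier in this commit.

```python
import torch
import torch.nn as nn

# Projection used when use_crossattn_projection=True: Linear + GELU that maps the
# high-dimensional text-encoder features down to the transformer's cross-attention width.
crossattn_proj = nn.Sequential(
    nn.Linear(100352, 1024, bias=True),  # crossattn_proj_in_channels -> encoder_hidden_states_channels
    nn.GELU(),
)

# Illustrative shapes only: batch of 1, 512 text tokens.
encoder_hidden_states = torch.randn(1, 512, 100352)
projected = crossattn_proj(encoder_hidden_states)
print(projected.shape)  # torch.Size([1, 512, 1024])
```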

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -165,6 +165,7 @@
     _import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
     _import_structure["consisid"] = ["ConsisIDPipeline"]
     _import_structure["cosmos"] = [
+        "Cosmos2_5_PredictBasePipeline",
         "Cosmos2TextToImagePipeline",
         "CosmosTextToWorldPipeline",
         "CosmosVideoToWorldPipeline",
@@ -622,6 +623,7 @@
         StableDiffusionXLControlNetXSPipeline,
     )
     from .cosmos import (
+        Cosmos2_5_PredictBasePipeline,
         Cosmos2TextToImagePipeline,
         Cosmos2VideoToWorldPipeline,
         CosmosTextToWorldPipeline,

src/diffusers/pipelines/cosmos/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -22,6 +22,9 @@
 
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
+    _import_structure["pipeline_cosmos2_5_predict"] = [
+        "Cosmos2_5_PredictBasePipeline",
+    ]
     _import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"]
     _import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
     _import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
@@ -35,6 +38,9 @@
 except OptionalDependencyNotAvailable:
     from ...utils.dummy_torch_and_transformers_objects import *
 else:
+    from .pipeline_cosmos2_5_predict import (
+        Cosmos2_5_PredictBasePipeline,
+    )
     from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline
     from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
     from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
