Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
83d2b81
cosmos predict2.5 base: convert chkpt & pipeline
miguelmartin75 Dec 4, 2025
4395869
scheduler cleanup
miguelmartin75 Dec 11, 2025
e6e278e
simplify inference pipeline
miguelmartin75 Dec 13, 2025
dd6f540
cleanup scheduler + tests
miguelmartin75 Dec 15, 2025
828788e
Basic tests for flow unipc
miguelmartin75 Dec 16, 2025
899be86
working b2b inference
miguelmartin75 Dec 16, 2025
2cc2b56
Rename everything
miguelmartin75 Dec 16, 2025
04f23e8
Tests for pipeline present, but not working (predict2 also not working)
miguelmartin75 Dec 16, 2025
824fffa
docstring update
miguelmartin75 Dec 17, 2025
bae477a
wrapper pipelines + make style
miguelmartin75 Dec 17, 2025
232a816
remove unnecessary files
miguelmartin75 Dec 17, 2025
1a132f2
UniPCMultistep: support use_karras_sigmas=True and use_flow_sigmas=True
miguelmartin75 Dec 17, 2025
9808220
use UniPCMultistepScheduler + fix tests for pipeline
miguelmartin75 Dec 18, 2025
abba01c
Remove FlowUniPCMultistepScheduler
miguelmartin75 Dec 18, 2025
b9a35f5
UniPCMultistepScheduler for use_flow_sigmas=True & use_karras_sigmas=…
miguelmartin75 Dec 18, 2025
dd429ef
num_inference_steps=36 due to bug in scheduler used by predict2.5
miguelmartin75 Dec 18, 2025
b76f9f2
Address comments
miguelmartin75 Dec 18, 2025
735fb0e
make style + make fix-copies
miguelmartin75 Dec 18, 2025
46f7916
fix tests + remove references to old pipelines
miguelmartin75 Dec 18, 2025
5c28b08
address comments
miguelmartin75 Dec 18, 2025
8081797
add revision in from_pretrained call
miguelmartin75 Dec 18, 2025
d1dab59
fix tests
miguelmartin75 Dec 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/en/api/pipelines/cosmos.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ output.save("output.png")
- all
- __call__

## Cosmos2_5_PredictBasePipeline

[[autodoc]] Cosmos2_5_PredictBasePipeline
- all
- __call__

## CosmosPipelineOutput

[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
Expand Down
135 changes: 127 additions & 8 deletions scripts/convert_cosmos_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,55 @@
"""
# Cosmos 2 Predict

Download checkpoint
```bash
hf download nvidia/Cosmos-Predict2-2B-Text2Image
```

convert checkpoint
```bash
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt

python scripts/convert_cosmos_to_diffusers.py \
--transformer_ckpt_path $transformer_ckpt_path \
--transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \
--text_encoder_path google-t5/t5-11b \
--tokenizer_path google-t5/t5-11b \
--vae_type wan2.1 \
--output_path converted/cosmos-p2-t2i-2b \
--save_pipeline
```

# Cosmos 2.5 Predict

Download checkpoint
```bash
hf download nvidia/Cosmos-Predict2.5-2B
```

Convert checkpoint
```bash
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt

python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Predict-Base-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/cosmos-p2.5-base-2b \
--save_pipeline
```

"""

import argparse
import pathlib
import sys
from typing import Any, Dict

import torch
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from transformers import T5EncoderModel, T5TokenizerFast
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast

from diffusers import (
AutoencoderKLCosmos,
Expand All @@ -17,7 +61,9 @@
CosmosVideoToWorldPipeline,
EDMEulerScheduler,
FlowMatchEulerDiscreteScheduler,
UniPCMultistepScheduler,
)
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline


def remove_keys_(key: str, state_dict: Dict[str, Any]):
Expand Down Expand Up @@ -233,6 +279,25 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
"concat_padding_mask": True,
"extra_pos_embed_type": None,
},
"Cosmos-2.5-Predict-Base-2B": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 16,
"attention_head_dim": 128,
"num_layers": 28,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (1.0, 3.0, 3.0),
"concat_padding_mask": True,
# NOTE: source config has pos_emb_learnable: 'True' - but params are missing
"extra_pos_embed_type": None,
"use_crossattn_projection": True,
"crossattn_proj_in_channels": 100352,
"encoder_hidden_states_channels": 1024,
},
}

VAE_KEYS_RENAME_DICT = {
Expand Down Expand Up @@ -334,6 +399,9 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
elif "Cosmos-2.0" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
elif "Cosmos-2.5" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
else:
assert False

Expand All @@ -347,6 +415,7 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
new_key = new_key.removeprefix(PREFIX_KEY)
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
print(key, "->", new_key, flush=True)
update_state_dict_(original_state_dict, key, new_key)

for key in list(original_state_dict.keys()):
Expand All @@ -355,6 +424,21 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
continue
handler_fn_inplace(key, original_state_dict)

expected_keys = set(transformer.state_dict().keys())
mapped_keys = set(original_state_dict.keys())
missing_keys = expected_keys - mapped_keys
unexpected_keys = mapped_keys - expected_keys
if missing_keys:
print(f"ERROR: missing keys ({len(missing_keys)} from state_dict:", flush=True, file=sys.stderr)
for k in missing_keys:
print(k)
sys.exit(1)
if unexpected_keys:
print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr)
for k in unexpected_keys:
print(k)
sys.exit(2)

transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer

Expand Down Expand Up @@ -444,17 +528,45 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")


def save_pipeline_cosmos2_5(args, transformer, vae):
text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"

text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
text_encoder_path, torch_dtype="auto", device_map="cpu"
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

scheduler = UniPCMultistepScheduler(
use_karras_sigmas=True,
use_flow_sigmas=True,
prediction_type="flow_prediction",
sigma_max=200.0,
sigma_min=0.01,
)

pipe = Cosmos2_5_PredictBasePipeline(
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
vae=vae,
scheduler=scheduler,
safety_checker=lambda *args, **kwargs: None,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument(
"--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
"--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE"
)
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--text_encoder_path", type=str, default=None)
parser.add_argument("--tokenizer_path", type=str, default=None)
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
Expand All @@ -477,8 +589,6 @@ def get_args():
if args.save_pipeline:
assert args.transformer_ckpt_path is not None
assert args.vae_type is not None
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None

if args.transformer_ckpt_path is not None:
weights_only = "Cosmos-1.0" in args.transformer_type
Expand All @@ -490,17 +600,26 @@ def get_args():
if args.vae_type is not None:
if "Cosmos-1.0" in args.transformer_type:
vae = convert_vae(args.vae_type)
else:
elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type:
vae = AutoencoderKLWan.from_pretrained(
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
else:
raise AssertionError(f"{args.transformer_type} not supported")

if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

if args.save_pipeline:
if "Cosmos-1.0" in args.transformer_type:
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
save_pipeline_cosmos_1_0(args, transformer, vae)
elif "Cosmos-2.0" in args.transformer_type:
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
save_pipeline_cosmos_2_0(args, transformer, vae)
elif "Cosmos-2.5" in args.transformer_type:
save_pipeline_cosmos2_5(args, transformer, vae)
else:
assert False
raise AssertionError(f"{args.transformer_type} not supported")
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@
"CogView4ControlPipeline",
"CogView4Pipeline",
"ConsisIDPipeline",
"Cosmos2_5_PredictBasePipeline",
"Cosmos2TextToImagePipeline",
"Cosmos2VideoToWorldPipeline",
"CosmosTextToWorldPipeline",
Expand Down Expand Up @@ -1175,6 +1176,7 @@
CogView4ControlPipeline,
CogView4Pipeline,
ConsisIDPipeline,
Cosmos2_5_PredictBasePipeline,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,
Expand Down
13 changes: 13 additions & 0 deletions src/diffusers/models/transformers/transformer_cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,9 @@ def __init__(
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
concat_padding_mask: bool = True,
extra_pos_embed_type: Optional[str] = "learnable",
use_crossattn_projection: bool = False,
crossattn_proj_in_channels: int = 1024,
encoder_hidden_states_channels: int = 1024,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
Expand Down Expand Up @@ -485,6 +488,12 @@ def __init__(
hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False
)

if self.config.use_crossattn_projection:
self.crossattn_proj = nn.Sequential(
nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
nn.GELU(),
)

self.gradient_checkpointing = False

def forward(
Expand Down Expand Up @@ -524,6 +533,7 @@ def forward(
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w

hidden_states = self.patch_embed(hidden_states)
hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C]

Expand All @@ -546,6 +556,9 @@ def forward(
else:
assert False

if self.config.use_crossattn_projection:
encoder_hidden_states = self.crossattn_proj(encoder_hidden_states)

# 5. Transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
_import_structure["consisid"] = ["ConsisIDPipeline"]
_import_structure["cosmos"] = [
"Cosmos2_5_PredictBasePipeline",
"Cosmos2TextToImagePipeline",
"CosmosTextToWorldPipeline",
"CosmosVideoToWorldPipeline",
Expand Down Expand Up @@ -622,6 +623,7 @@
StableDiffusionXLControlNetXSPipeline,
)
from .cosmos import (
Cosmos2_5_PredictBasePipeline,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,
Expand Down
6 changes: 6 additions & 0 deletions src/diffusers/pipelines/cosmos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cosmos2_5_predict"] = [
"Cosmos2_5_PredictBasePipeline",
]
_import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"]
_import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
Expand All @@ -35,6 +38,9 @@
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_cosmos2_5_predict import (
Cosmos2_5_PredictBasePipeline,
)
from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline
from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
Expand Down
Loading
Loading