1+ """
2+ # Cosmos 2 Predict
3+
4+ Download checkpoint
5+ ```bash
6+ hf download nvidia/Cosmos-Predict2-2B-Text2Image
7+ ```
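+
+Instead of hard-coding the snapshot hash, the local snapshot directory can also be resolved programmatically. A minimal sketch using `huggingface_hub` (the repo id is the one downloaded above):
+```python
+from huggingface_hub import snapshot_download
+
+# Downloads (or reuses) the cached snapshot and returns its local directory, e.g.
+# ~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/<hash>
+local_dir = snapshot_download("nvidia/Cosmos-Predict2-2B-Text2Image")
+print(local_dir)
+```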
+
+Convert the checkpoint:
+```bash
+transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt
+
+python scripts/convert_cosmos_to_diffusers.py \
+    --transformer_ckpt_path $transformer_ckpt_path \
+    --transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \
+    --text_encoder_path google-t5/t5-11b \
+    --tokenizer_path google-t5/t5-11b \
+    --vae_type wan2.1 \
+    --output_path converted/cosmos-p2-t2i-2b \
+    --save_pipeline
+```
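+
+Once saved, the pipeline can be loaded straight from the output directory. A minimal, illustrative sketch (prompt, dtype, and device are assumptions; the Cosmos safety checker may pull in extra dependencies):
+```python
+import torch
+from diffusers import Cosmos2TextToImagePipeline
+
+# Load the locally converted pipeline produced by --save_pipeline above.
+pipe = Cosmos2TextToImagePipeline.from_pretrained("converted/cosmos-p2-t2i-2b", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+image = pipe(prompt="A robot arm soldering a circuit board on a factory line").images[0]
+image.save("cosmos_t2i.png")
+```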
+
+# Cosmos 2.5 Predict
+
+Download the checkpoint:
+```bash
+hf download nvidia/Cosmos-Predict2.5-2B
+```
+
+Convert the checkpoint:
+```bash
+transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt
+
+python scripts/convert_cosmos_to_diffusers.py \
+    --transformer_type Cosmos-2.5-Predict-Base-2B \
+    --transformer_ckpt_path $transformer_ckpt_path \
+    --vae_type wan2.1 \
+    --output_path converted/cosmos-p2.5-base-2b \
+    --save_pipeline
+```
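+
+The converted 2.5 base pipeline can likewise be loaded back from the output directory. A minimal loading sketch (`Cosmos2_5_PredictBasePipeline` is the class added in this change; generation arguments are omitted since they depend on its API):
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+# DiffusionPipeline resolves the pipeline class recorded in the saved model_index.json.
+pipe = DiffusionPipeline.from_pretrained("converted/cosmos-p2.5-base-2b", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+```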
+
+"""
+
import argparse
import pathlib
+import sys
from typing import Any, Dict

import torch
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
-from transformers import T5EncoderModel, T5TokenizerFast
+from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast

from diffusers import (
    AutoencoderKLCosmos,
    CosmosVideoToWorldPipeline,
    EDMEulerScheduler,
    FlowMatchEulerDiscreteScheduler,
+    UniPCMultistepScheduler,
)
+from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline


def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -233,6 +279,25 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
233279 "concat_padding_mask" : True ,
234280 "extra_pos_embed_type" : None ,
235281 },
282+ "Cosmos-2.5-Predict-Base-2B" : {
283+ "in_channels" : 16 + 1 ,
284+ "out_channels" : 16 ,
285+ "num_attention_heads" : 16 ,
286+ "attention_head_dim" : 128 ,
287+ "num_layers" : 28 ,
288+ "mlp_ratio" : 4.0 ,
289+ "text_embed_dim" : 1024 ,
290+ "adaln_lora_dim" : 256 ,
291+ "max_size" : (128 , 240 , 240 ),
292+ "patch_size" : (1 , 2 , 2 ),
293+ "rope_scale" : (1.0 , 3.0 , 3.0 ),
294+ "concat_padding_mask" : True ,
295+ # NOTE: source config has pos_emb_learnable: 'True' - but params are missing
296+ "extra_pos_embed_type" : None ,
297+ "use_crossattn_projection" : True ,
298+ "crossattn_proj_in_channels" : 100352 ,
299+ "encoder_hidden_states_channels" : 1024 ,
300+ },
236301}
237302
238303VAE_KEYS_RENAME_DICT = {
@@ -334,6 +399,9 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
    elif "Cosmos-2.0" in transformer_type:
        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
+    elif "Cosmos-2.5" in transformer_type:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
    else:
        assert False

@@ -347,6 +415,7 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
        new_key = new_key.removeprefix(PREFIX_KEY)
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
+        print(key, "->", new_key, flush=True)
        update_state_dict_(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
@@ -355,6 +424,21 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
                continue
            handler_fn_inplace(key, original_state_dict)

+    expected_keys = set(transformer.state_dict().keys())
+    mapped_keys = set(original_state_dict.keys())
+    missing_keys = expected_keys - mapped_keys
+    unexpected_keys = mapped_keys - expected_keys
+    if missing_keys:
+        print(f"ERROR: missing keys ({len(missing_keys)}) from state_dict:", flush=True, file=sys.stderr)
+        for k in missing_keys:
+            print(k)
+        sys.exit(1)
+    if unexpected_keys:
+        print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr)
+        for k in unexpected_keys:
+            print(k)
+        sys.exit(2)
+
    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer

@@ -444,17 +528,45 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")


+def save_pipeline_cosmos2_5(args, transformer, vae):
+    text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
+    tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
+
+    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+        text_encoder_path, torch_dtype="auto", device_map="cpu"
+    )
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+
+    scheduler = UniPCMultistepScheduler(
+        use_karras_sigmas=True,
+        use_flow_sigmas=True,
+        prediction_type="flow_prediction",
+        sigma_max=200.0,
+        sigma_min=0.01,
+    )
+
+    pipe = Cosmos2_5_PredictBasePipeline(
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+        vae=vae,
+        scheduler=scheduler,
+        safety_checker=lambda *args, **kwargs: None,
+    )
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument(
-        "--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
+        "--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE"
    )
-    parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
-    parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
+    parser.add_argument("--text_encoder_path", type=str, default=None)
+    parser.add_argument("--tokenizer_path", type=str, default=None)
    parser.add_argument("--save_pipeline", action="store_true")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
@@ -477,8 +589,6 @@ def get_args():
    if args.save_pipeline:
        assert args.transformer_ckpt_path is not None
        assert args.vae_type is not None
-        assert args.text_encoder_path is not None
-        assert args.tokenizer_path is not None

    if args.transformer_ckpt_path is not None:
        weights_only = "Cosmos-1.0" in args.transformer_type
@@ -490,17 +600,26 @@ def get_args():
    if args.vae_type is not None:
        if "Cosmos-1.0" in args.transformer_type:
            vae = convert_vae(args.vae_type)
-        else:
+        elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type:
            vae = AutoencoderKLWan.from_pretrained(
                "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
            )
+        else:
+            raise AssertionError(f"{args.transformer_type} not supported")
+
        if not args.save_pipeline:
            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

    if args.save_pipeline:
        if "Cosmos-1.0" in args.transformer_type:
+            assert args.text_encoder_path is not None
+            assert args.tokenizer_path is not None
            save_pipeline_cosmos_1_0(args, transformer, vae)
        elif "Cosmos-2.0" in args.transformer_type:
+            assert args.text_encoder_path is not None
+            assert args.tokenizer_path is not None
            save_pipeline_cosmos_2_0(args, transformer, vae)
+        elif "Cosmos-2.5" in args.transformer_type:
+            save_pipeline_cosmos2_5(args, transformer, vae)
        else:
-            assert False
+            raise AssertionError(f"{args.transformer_type} not supported")