Skip to content

Commit bae477a

Browse files
wrapper pipelines + make style
1 parent 824fffa commit bae477a

File tree

8 files changed

+557
-42
lines changed

8 files changed

+557
-42
lines changed

scripts/convert_cosmos_to_diffusers.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,7 @@
6262
EDMEulerScheduler,
6363
FlowMatchEulerDiscreteScheduler,
6464
)
65-
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos_2_5_PredictBase
65+
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBase
6666

6767

6868
def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -527,7 +527,7 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
527527
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
528528

529529

530-
def save_pipeline_cosmos_2_5(args, transformer, vae):
530+
def save_pipeline_cosmos2_5(args, transformer, vae):
531531
text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
532532
tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
533533

@@ -538,7 +538,7 @@ def save_pipeline_cosmos_2_5(args, transformer, vae):
538538

539539
scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
540540

541-
pipe = Cosmos_2_5_PredictBase(
541+
pipe = Cosmos2_5_PredictBase(
542542
text_encoder=text_encoder,
543543
tokenizer=tokenizer,
544544
transformer=transformer,
@@ -613,6 +613,6 @@ def get_args():
613613
assert args.tokenizer_path is not None
614614
save_pipeline_cosmos_2_0(args, transformer, vae)
615615
elif "Cosmos-2.5" in args.transformer_type:
616-
save_pipeline_cosmos_2_5(args, transformer, vae)
616+
save_pipeline_cosmos2_5(args, transformer, vae)
617617
else:
618618
raise AssertionError(f"{args.transformer_type} not supported")

src/diffusers/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -464,7 +464,10 @@
464464
"CogView4ControlPipeline",
465465
"CogView4Pipeline",
466466
"ConsisIDPipeline",
467-
"Cosmos_2_5_PredictBase",
467+
"Cosmos2_5_PredictBase",
468+
"Cosmos2_5_PredictImage2World",
469+
"Cosmos2_5_PredictText2World",
470+
"Cosmos2_5_PredictVideo2World",
468471
"Cosmos2TextToImagePipeline",
469472
"Cosmos2VideoToWorldPipeline",
470473
"CosmosTextToWorldPipeline",
@@ -1178,9 +1181,12 @@
11781181
CogView4ControlPipeline,
11791182
CogView4Pipeline,
11801183
ConsisIDPipeline,
1184+
Cosmos2_5_PredictBase,
1185+
Cosmos2_5_PredictImage2World,
1186+
Cosmos2_5_PredictText2World,
1187+
Cosmos2_5_PredictVideo2World,
11811188
Cosmos2TextToImagePipeline,
11821189
Cosmos2VideoToWorldPipeline,
1183-
Cosmos_2_5_PredictBase,
11841190
CosmosTextToWorldPipeline,
11851191
CosmosVideoToWorldPipeline,
11861192
CycleDiffusionPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -165,7 +165,10 @@
165165
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
166166
_import_structure["consisid"] = ["ConsisIDPipeline"]
167167
_import_structure["cosmos"] = [
168-
"Cosmos_2_5_PredictBase",
168+
"Cosmos2_5_PredictBase",
169+
"Cosmos2_5_PredictImage2World",
170+
"Cosmos2_5_PredictText2World",
171+
"Cosmos2_5_PredictVideo2World",
169172
"Cosmos2TextToImagePipeline",
170173
"CosmosTextToWorldPipeline",
171174
"CosmosVideoToWorldPipeline",
@@ -623,9 +626,12 @@
623626
StableDiffusionXLControlNetXSPipeline,
624627
)
625628
from .cosmos import (
629+
Cosmos2_5_PredictBase,
630+
Cosmos2_5_PredictImage2World,
631+
Cosmos2_5_PredictText2World,
632+
Cosmos2_5_PredictVideo2World,
626633
Cosmos2TextToImagePipeline,
627634
Cosmos2VideoToWorldPipeline,
628-
Cosmos_2_5_PredictBase,
629635
CosmosTextToWorldPipeline,
630636
CosmosVideoToWorldPipeline,
631637
)

src/diffusers/pipelines/cosmos/__init__.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,12 @@
2222

2323
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
2424
else:
25-
_import_structure["pipeline_cosmos2_5_predict"] = ["Cosmos_2_5_PredictBase", "retrieve_latents"]
25+
_import_structure["pipeline_cosmos2_5_predict"] = [
26+
"Cosmos2_5_PredictBase",
27+
"Cosmos2_5_PredictImage2World",
28+
"Cosmos2_5_PredictText2World",
29+
"Cosmos2_5_PredictVideo2World",
30+
]
2631
_import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"]
2732
_import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
2833
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
@@ -36,7 +41,12 @@
3641
except OptionalDependencyNotAvailable:
3742
from ...utils.dummy_torch_and_transformers_objects import *
3843
else:
39-
from .pipeline_cosmos2_5_predict import Cosmos_2_5_PredictBase, retrieve_latents
44+
from .pipeline_cosmos2_5_predict import (
45+
Cosmos2_5_PredictBase,
46+
Cosmos2_5_PredictImage2World,
47+
Cosmos2_5_PredictText2World,
48+
Cosmos2_5_PredictVideo2World,
49+
)
4050
from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline
4151
from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
4252
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline

0 commit comments

Comments
 (0)