Skip to content

Commit 4133c68

Browse files
Address comments
1 parent 9010c93 commit 4133c68

File tree

3 files changed

+10
-18
lines changed

3 files changed

+10
-18
lines changed

docs/source/en/api/pipelines/cosmos.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ output.save("output.png")
7070
- all
7171
- __call__
7272

73+
## Cosmos2_5_PredictBasePipeline
74+
75+
[[autodoc]] Cosmos2_5_PredictBasePipeline
76+
- all
77+
- __call__
78+
7379
## CosmosPipelineOutput
7480

7581
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput

scripts/convert_cosmos_to_diffusers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
FlowMatchEulerDiscreteScheduler,
6464
UniPCMultistepScheduler,
6565
)
66-
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBase
66+
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline
6767

6868

6969
def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -545,7 +545,7 @@ def save_pipeline_cosmos2_5(args, transformer, vae):
545545
sigma_min=0.01,
546546
)
547547

548-
pipe = Cosmos2_5_PredictBase(
548+
pipe = Cosmos2_5_PredictBasePipeline(
549549
text_encoder=text_encoder,
550550
tokenizer=tokenizer,
551551
transformer=transformer,

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -233,20 +233,6 @@ def __init__(
233233
if self.latents_mean is None or self.latents_std is None:
234234
raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.")
235235

236-
237-
@property
238-
def _execution_device(self):
239-
device = super()._execution_device
240-
if isinstance(device, torch.device) and device.type == "cpu":
241-
for module_name in ("transformer", "text_encoder", "vae"):
242-
module = getattr(self, module_name, None)
243-
if module is None or not isinstance(module, torch.nn.Module):
244-
continue
245-
module_device = getattr(module, "device", None)
246-
if isinstance(module_device, torch.device) and module_device.type != "cpu":
247-
return module_device
248-
return device
249-
250236
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_prompt_embeds
251237
def _get_prompt_embeds(
252238
self,
@@ -398,6 +384,8 @@ def encode_prompt(
398384

399385
return prompt_embeds, negative_prompt_embeds
400386

387+
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2VideoToWorldPipeline.prepare_latents and
388+
# diffusers.pipelines.cosmos.pipeline_cosmos2_text2image.Cosmos2TextToImagePipeline.prepare_latents
401389
def prepare_latents(
402390
self,
403391
video: Optional[torch.Tensor],
@@ -458,8 +446,6 @@ def prepare_latents(
458446

459447
cond_latents = torch.cat(cond_latents, dim=0).to(dtype)
460448

461-
if self.latents_mean is None or self.latents_std is None:
462-
raise ValueError("VAE configuration must define `latents_mean` and `latents_std`.")
463449
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
464450
latents_std = self.latents_std.to(device=device, dtype=dtype)
465451
cond_latents = (cond_latents - latents_mean) / latents_std

0 commit comments

Comments (0)