Commit 2a75ff3

add autopipeline for text2video task

1 parent a748a83 · commit 2a75ff3

File tree: 5 files changed, +285 −1 lines changed


src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -306,6 +306,7 @@
     "AutoPipelineForImage2Image",
     "AutoPipelineForInpainting",
     "AutoPipelineForText2Image",
+    "AutoPipelineForText2Video",
     "ConsistencyModelPipeline",
     "DanceDiffusionPipeline",
     "DDIMPipeline",

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@
     "AutoPipelineForImage2Image",
     "AutoPipelineForInpainting",
     "AutoPipelineForText2Image",
+    "AutoPipelineForText2Video",
 ]
 _import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
 _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 252 additions & 1 deletion
@@ -117,7 +117,8 @@
     StableDiffusionXLInpaintPipeline,
     StableDiffusionXLPipeline,
 )
-from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
+
+from .wan import WanAnimatePipeline, WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline
 from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
 from .z_image import ZImageImg2ImgPipeline, ZImagePipeline
@@ -221,6 +222,8 @@
@@ -221,6 +222,8 @@
 AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(
     [
         ("wan", WanPipeline),
+        ("wan-animate", WanAnimatePipeline),
+        ("wan-vace", WanVACEPipeline),
     ]
 )
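These mapping keys are what `_get_task_class` matches against: it first finds the model key registered for a given pipeline class name, then looks that key up in the target task mapping. A rough standalone sketch of that resolution (illustrative only, not the library code):

```py
# illustrative sketch of the name -> task-class resolution; the real
# _get_task_class in auto_pipeline.py searches all registered task mappings
from collections import OrderedDict

TEXT2VIDEO_MAPPING = OrderedDict(
    [("wan", "WanPipeline"), ("wan-animate", "WanAnimatePipeline"), ("wan-vace", "WanVACEPipeline")]
)


def resolve(mapping, class_name):
    # find the model key whose registered pipeline bears this class name ...
    for model_name, pipeline_name in mapping.items():
        if pipeline_name == class_name:
            # ... then return the pipeline registered for that key
            return mapping[model_name]
    raise ValueError(f"AutoPipeline can't find a pipeline linked to {class_name}")


print(resolve(TEXT2VIDEO_MAPPING, "WanVACEPipeline"))  # WanVACEPipeline
```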

@@ -1206,3 +1209,251 @@ def from_pipe(cls, pipeline, **kwargs):
         model.register_to_config(**unused_original_config)
 
         return model
+
+
+class AutoPipelineForText2Video(ConfigMixin):
+    r"""
+    [`AutoPipelineForText2Video`] is a generic pipeline class that instantiates a text-to-video pipeline class. The
+    specific underlying pipeline class is automatically selected from either the
+    [`~AutoPipelineForText2Video.from_pretrained`] or [`~AutoPipelineForText2Video.from_pipe`] methods.
+
+    This class cannot be instantiated using `__init__()` (throws an error).
+
+    Class attributes:
+
+        - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
+          diffusion pipeline's components.
+    """
+
+    config_name = "model_index.json"
+
+    def __init__(self, *args, **kwargs):
+        raise EnvironmentError(
+            f"{self.__class__.__name__} is designed to be instantiated "
+            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
+            f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
+        )
+
+    @classmethod
+    @validate_hf_hub_args
+    def from_pretrained(cls, pretrained_model_or_path, **kwargs):
+        r"""
+        Instantiates a text-to-video PyTorch diffusion pipeline from pretrained pipeline weights.
+
+        The `from_pretrained()` method returns the correct pipeline class instance by:
+
+        1. Detecting the pipeline class of `pretrained_model_or_path` based on the `_class_name` property of its
+           config object.
+        2. Finding the text-to-video pipeline linked to that pipeline class using pattern matching on the pipeline
+           class name.
+
+        The pipeline is set in evaluation mode (`model.eval()`) by default.
+
+        If you get the error message below, you need to finetune the weights for your downstream task:
+
+        ```
+        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
+        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+        ```
+
+        Parameters:
+            pretrained_model_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+                      hosted on the Hub.
+                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
+                      saved using [`~DiffusionPipeline.save_pretrained`].
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model with another dtype.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            custom_revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
+                `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
+                custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
+            mirror (`str`, *optional*):
+                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+                information.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be defined for each
+                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
+                each GPU and the available CPU RAM if unset.
+            offload_folder (`str` or `os.PathLike`, *optional*):
+                The path to offload weights if `device_map` contains the value `"disk"`.
+            offload_state_dict (`bool`, *optional*):
+                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+                when there is some disk offload.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+                weights. If set to `False`, safetensors weights are not loaded.
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+                class). The overwritten components are passed directly to the pipeline's `__init__` method. See example
+                below for more information.
+            variant (`str`, *optional*):
+                Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
+                loading `from_flax`.
+
+        > [!TIP]
+        > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log in with
+        > `hf auth login`.
+
+        Examples:
+
+        ```py
+        >>> from diffusers import AutoPipelineForText2Video
+
+        >>> pipeline = AutoPipelineForText2Video.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers")
+        >>> video = pipeline(prompt="A cat walks on the grass, realistic").frames[0]
+        ```
+        """
+
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        token = kwargs.pop("token", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)
+
+        load_config_kwargs = {
+            "cache_dir": cache_dir,
+            "force_download": force_download,
+            "proxies": proxies,
+            "token": token,
+            "local_files_only": local_files_only,
+            "revision": revision,
+        }
+
+        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
+        orig_class_name = config["_class_name"]
+        text_to_video_cls = _get_task_class(AUTO_TEXT2VIDEO_PIPELINES_MAPPING, orig_class_name)
+        kwargs = {**load_config_kwargs, **kwargs}
+        return text_to_video_cls.from_pretrained(pretrained_model_or_path, **kwargs)
+
+    @classmethod
+    def from_pipe(cls, pipeline, **kwargs):
+        r"""
+        Instantiates a text-to-video PyTorch diffusion pipeline from another instantiated diffusion pipeline class.
+
+        The `from_pipe()` method returns the correct pipeline class instance by finding the text-to-video pipeline
+        linked to the pipeline class using pattern matching on the pipeline class name.
+
+        All the modules the pipeline class contains are used to initialize the new pipeline without reallocating
+        additional memory.
+
+        The pipeline is set in evaluation mode (`model.eval()`) by default.
+
+        Parameters:
+            pipeline (`DiffusionPipeline`):
+                an instantiated `DiffusionPipeline` object
+
+        Examples:
+
+        ```py
+        >>> from diffusers import AutoPipelineForText2Video, DiffusionPipeline
+
+        >>> pipe = DiffusionPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers")
+        >>> pipe_text2video = AutoPipelineForText2Video.from_pipe(pipe)
+        >>> output = pipe_text2video(
+        ...     prompt="A cat walks on the grass, realistic", height=256, width=256, num_frames=10
+        ... ).frames[0]
+        ```
+        """
+        original_config = dict(pipeline.config)
+        original_cls_name = pipeline.__class__.__name__
+
+        # derive the pipeline class to instantiate
+        text_to_video_cls = _get_task_class(AUTO_TEXT2VIDEO_PIPELINES_MAPPING, original_cls_name)
+
+        # define expected module and optional kwargs given the pipeline signature
+        expected_modules, optional_kwargs = text_to_video_cls._get_signature_keys(text_to_video_cls)
+
+        pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
+
+        # allow users to pass modules in `kwargs` to override the original pipeline's components
+        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+        original_class_obj = {
+            k: pipeline.components[k]
+            for k, v in pipeline.components.items()
+            if k in expected_modules and k not in passed_class_obj
+        }
+
+        # allow users to pass optional kwargs to override the original pipeline's config attributes
+        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
+        original_pipe_kwargs = {
+            k: original_config[k]
+            for k, v in original_config.items()
+            if k in optional_kwargs and k not in passed_pipe_kwargs
+        }
+
+        # config attributes that were not expected by the original pipeline are stored as private
+        # attributes; we pass them as optional arguments if the new pipeline accepts them
+        additional_pipe_kwargs = [
+            k[1:]
+            for k in original_config.keys()
+            if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
+        ]
+        for k in additional_pipe_kwargs:
+            original_pipe_kwargs[k] = original_config.pop(f"_{k}")
+
+        text_to_video_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs}
+
+        # store unused config values as private attributes
+        unused_original_config = {
+            f"{'' if k.startswith('_') else '_'}{k}": original_config[k]
+            for k, v in original_config.items()
+            if k not in text_to_video_kwargs
+        }
+
+        missing_modules = (
+            set(expected_modules) - set(text_to_video_cls._optional_components) - set(text_to_video_kwargs.keys())
+        )
+
+        if len(missing_modules) > 0:
+            raise ValueError(
+                f"Pipeline {text_to_video_cls} expected {expected_modules}, but only"
+                f" {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed"
+            )
+
+        model = text_to_video_cls(**text_to_video_kwargs)
+        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+        model.register_to_config(**unused_original_config)
+
+        return model
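The `from_pretrained` path above boils down to three steps: load the checkpoint's `model_index.json`, read its `_class_name`, and delegate to the class resolved from `AUTO_TEXT2VIDEO_PIPELINES_MAPPING`. A hedged sketch of that dispatch, using the Wan checkpoint named in the test script below:

```py
# sketch of the dispatch AutoPipelineForText2Video.from_pretrained performs;
# the checkpoint id is taken from this commit's test script
from diffusers import DiffusionPipeline
from diffusers.pipelines.auto_pipeline import AUTO_TEXT2VIDEO_PIPELINES_MAPPING, _get_task_class

repo_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
config = DiffusionPipeline.load_config(repo_id)  # reads model_index.json
cls = _get_task_class(AUTO_TEXT2VIDEO_PIPELINES_MAPPING, config["_class_name"])
print(cls.__name__)  # WanPipeline
pipe = cls.from_pretrained(repo_id)  # what the auto class ultimately returns
```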

test_autopipeline.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
+import torch
+from diffusers import AutoPipelineForText2Video
+from diffusers.utils import export_to_video
+
+wan_list = [
+    "Wan-AI/Wan2.1-T2V-14B-Diffusers",
+    "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers",
+    "Wan-AI/Wan2.1-VACE-1.3B-diffusers",
+]
+
+pipe = AutoPipelineForText2Video.from_pretrained(
+    wan_list[0],
+    torch_dtype=torch.float16,
+)
+
+print("Successfully loaded pipeline\n")
+
+prompt = "A cat walks on the grass, realistic"
+negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+
+output = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    height=256,
+    width=256,
+    num_frames=10,
+    guidance_scale=5.0,
+).frames[0]
+export_to_video(output, "output.mp4", fps=15)
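The script loads the 14B checkpoint in fp16, which is still large for a single consumer GPU. If memory is tight, standard pipeline offloading applies unchanged; an optional variant of the load above (same checkpoint, assumed to fit with offload enabled):

```py
# optional memory-saving variant of the load above; enable_model_cpu_offload
# is a standard DiffusionPipeline method inherited by WanPipeline
import torch
from diffusers import AutoPipelineForText2Video

pipe = AutoPipelineForText2Video.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # move submodules to GPU only when needed
```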

tests/pipelines/test_pipelines_auto.py

Lines changed: 2 additions & 0 deletions
@@ -27,13 +27,15 @@
     AutoPipelineForImage2Image,
     AutoPipelineForInpainting,
     AutoPipelineForText2Image,
+    AutoPipelineForText2Video,
     ControlNetModel,
     DiffusionPipeline,
 )
 from diffusers.pipelines.auto_pipeline import (
     AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
     AUTO_INPAINT_PIPELINES_MAPPING,
     AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
+    AUTO_TEXT2VIDEO_PIPELINES_MAPPING,
 )
 
 from ..testing_utils import slow
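The commit only wires the new symbols into this test module's imports; no assertions are added yet. A hypothetical test in the style of this file (the test name below is illustrative, not part of the commit):

```py
# hypothetical test sketch, mirroring the existing auto-pipeline mapping tests
from diffusers import WanPipeline
from diffusers.pipelines.auto_pipeline import AUTO_TEXT2VIDEO_PIPELINES_MAPPING


def test_text2video_mapping_registers_wan():
    # the "wan" key added in auto_pipeline.py should resolve to WanPipeline
    assert AUTO_TEXT2VIDEO_PIPELINES_MAPPING["wan"] is WanPipeline
```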
