Commit 4f389e3

add autopipeline for text2video task

1 parent: a748a83

4 files changed: +260 -1 lines changed
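
In short, this commit adds an `AutoPipelineForText2Video` class to diffusers, mirroring the existing `AutoPipelineForText2Image`, `AutoPipelineForImage2Image`, and `AutoPipelineForInpainting` auto-classes, and extends `AUTO_TEXT2VIDEO_PIPELINES_MAPPING` so that Wan and Stable-Diffusion-style text-to-video checkpoints resolve to the right pipeline. A minimal usage sketch of the new API (the checkpoint name comes from the new docstrings; `torch_dtype` and the `"cuda"` device are illustrative assumptions):

```py
# Sketch of the API added in this commit; generation arguments vary with the
# pipeline that gets resolved, so treat the call below as illustrative.
import torch

from diffusers import AutoPipelineForText2Video

pipeline = AutoPipelineForText2Video.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16
)
pipeline.to("cuda")

# _class_name in this checkpoint's model_index.json is "TextToVideoSDPipeline",
# which the new "stable-diffusion" mapping entry resolves to that class.
video_frames = pipeline("a panda playing guitar", num_frames=16).frames[0]
```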

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -306,6 +306,7 @@
             "AutoPipelineForImage2Image",
             "AutoPipelineForInpainting",
             "AutoPipelineForText2Image",
+            "AutoPipelineForText2Video",
             "ConsistencyModelPipeline",
             "DanceDiffusionPipeline",
             "DDIMPipeline",

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -46,6 +46,7 @@
         "AutoPipelineForImage2Image",
         "AutoPipelineForInpainting",
         "AutoPipelineForText2Image",
+        "AutoPipelineForText2Video",
     ]
     _import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
     _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 248 additions & 1 deletion

@@ -117,7 +117,8 @@
     StableDiffusionXLInpaintPipeline,
     StableDiffusionXLPipeline,
 )
-from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
+from .text_to_video_synthesis import TextToVideoSDPipeline
+from .wan import WanAnimatePipeline, WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline
 from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
 from .z_image import ZImageImg2ImgPipeline, ZImagePipeline

@@ -221,6 +222,9 @@
221222
AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(
222223
[
223224
("wan", WanPipeline),
225+
("wan-animate", WanAnimatePipeline),
226+
("wan-vace", WanVACEPipeline),
227+
("stable-diffusion", TextToVideoSDPipeline),
224228
]
225229
)
226230
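Both `from_pretrained` and `from_pipe` below resolve the target class through the module's existing `_get_task_class` helper, which this commit does not touch. Roughly, it recovers the model family key by scanning the registered task mappings for a pipeline class with the given name, then looks that key up in the requested task's mapping. A condensed sketch (paraphrased from the helper, not its verbatim code):

```py
# Condensed sketch of the lookup; SUPPORTED_TASKS_MAPPINGS is the module-level
# list of AUTO_*_PIPELINES_MAPPING dicts that _get_task_class consults.
def _get_task_class_sketch(mapping, pipeline_class_name):
    # 1) find the model family key ("wan", "stable-diffusion", ...) whose entry
    #    in any registered task mapping carries this pipeline class name
    model_name = None
    for task_mapping in SUPPORTED_TASKS_MAPPINGS:
        for name, pipeline_cls in task_mapping.items():
            if pipeline_cls.__name__ == pipeline_class_name:
                model_name = name

    # 2) map that family key to the class registered for the target task
    if model_name is not None and model_name in mapping:
        return mapping[model_name]
    raise ValueError(f"AutoPipeline can't find a pipeline linked to {pipeline_class_name}")
```

So, for example, a `WanImageToVideoPipeline` handed to `AutoPipelineForText2Video.from_pipe` would resolve through the `"wan"` family key to `WanPipeline`.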

@@ -1206,3 +1210,246 @@ def from_pipe(cls, pipeline, **kwargs):
         model.register_to_config(**unused_original_config)
 
         return model
+
+
+class AutoPipelineForText2Video(ConfigMixin):
+    r"""
+
+    [`AutoPipelineForText2Video`] is a generic pipeline class that instantiates a text-to-video pipeline class. The
+    specific underlying pipeline class is automatically selected from either the
+    [`~AutoPipelineForText2Video.from_pretrained`] or [`~AutoPipelineForText2Video.from_pipe`] methods.
+
+    This class cannot be instantiated using `__init__()` (throws an error).
+
+    Class attributes:
+
+        - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
+          diffusion pipeline's components.
+
+    """
+
+    config_name = "model_index.json"
+
+    def __init__(self, *args, **kwargs):
+        raise EnvironmentError(
+            f"{self.__class__.__name__} is designed to be instantiated "
+            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
+            f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
+        )
+
+    @classmethod
+    @validate_hf_hub_args
+    def from_pretrained(cls, pretrained_model_or_path, **kwargs):
+        r"""
+        Instantiates a text-to-video PyTorch diffusion pipeline from pretrained pipeline weights.
+
+        The from_pretrained() method takes care of returning the correct pipeline class instance by:
+
+        1. Detecting the pipeline class of `pretrained_model_or_path` based on the `_class_name` property of its
+           config object.
+        2. Finding the text-to-video pipeline linked to that pipeline class using pattern matching on the pipeline
+           class name.
+
+        The pipeline is set in evaluation mode (`model.eval()`) by default.
+
+        If you get the error message below, you need to finetune the weights for your downstream task:
+
+        ```
+        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
+        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+        ```
+
+        Parameters:
+            pretrained_model_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+                      hosted on the Hub.
+                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
+                      saved using [`~DiffusionPipeline.save_pretrained`].
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model with another dtype.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            custom_revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
+                `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
+                custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
+            mirror (`str`, *optional*):
+                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+                information.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be defined for each
+                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory
+                available for each GPU and the available CPU RAM if unset.
+            offload_folder (`str` or `os.PathLike`, *optional*):
+                The path to offload weights if `device_map` contains the value `"disk"`.
+            offload_state_dict (`bool`, *optional*):
+                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+                when there is some disk offload.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+                weights. If set to `False`, safetensors weights are not loaded.
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+                class). The overwritten components are passed directly to the pipeline's `__init__` method. See example
+                below for more information.
+            variant (`str`, *optional*):
+                Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
+                loading `from_flax`.
+
+        > [!TIP]
+        > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log in with
+        > `hf auth login`.
+
+        Examples:
+
+        ```py
+        >>> from diffusers import AutoPipelineForText2Video
+
+        >>> pipeline = AutoPipelineForText2Video.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
+        >>> video_frames = pipeline("an astronaut riding a horse", num_frames=32).frames[0]
+        ```
+        """
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        token = kwargs.pop("token", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)
+
+        load_config_kwargs = {
+            "cache_dir": cache_dir,
+            "force_download": force_download,
+            "proxies": proxies,
+            "token": token,
+            "local_files_only": local_files_only,
+            "revision": revision,
+        }
+
+        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
+        orig_class_name = config["_class_name"]
+        text_to_video_cls = _get_task_class(AUTO_TEXT2VIDEO_PIPELINES_MAPPING, orig_class_name)
+
+        kwargs = {**load_config_kwargs, **kwargs}
+        return text_to_video_cls.from_pretrained(pretrained_model_or_path, **kwargs)
+
+    @classmethod
+    def from_pipe(cls, pipeline, **kwargs):
+        r"""
+        Instantiates a text-to-video PyTorch diffusion pipeline from another instantiated diffusion pipeline class.
+
+        The from_pipe() method takes care of returning the correct pipeline class instance by finding the
+        text-to-video pipeline linked to the pipeline class using pattern matching on the pipeline class name.
+
+        All the modules the original pipeline contains will be used to initialize the new pipeline without
+        reallocating additional memory.
+
+        The pipeline is set in evaluation mode (`model.eval()`) by default.
+
+        Parameters:
+            pipeline (`DiffusionPipeline`):
+                an instantiated `DiffusionPipeline` object
+
+        Examples:
+
+        ```py
+        >>> from diffusers import AutoPipelineForText2Video, TextToVideoSDPipeline
+
+        >>> pipeline_sd = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
+        >>> pipeline = AutoPipelineForText2Video.from_pipe(pipeline_sd)
+        >>> video_frames = pipeline("an astronaut riding a horse", num_frames=16).frames[0]
+        ```
+        """
+        original_config = dict(pipeline.config)
+        original_cls_name = pipeline.__class__.__name__
+
+        # derive the pipeline class to instantiate
+        text_to_video_cls = _get_task_class(AUTO_TEXT2VIDEO_PIPELINES_MAPPING, original_cls_name)
+
+        # define expected module and optional kwargs given the pipeline signature
+        expected_modules, optional_kwargs = text_to_video_cls._get_signature_keys(text_to_video_cls)
+
+        pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
+
+        # allow users to pass modules in `kwargs` to override the original pipeline's components
+        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+        original_class_obj = {
+            k: pipeline.components[k]
+            for k, v in pipeline.components.items()
+            if k in expected_modules and k not in passed_class_obj
+        }
+
+        # allow users to pass optional kwargs to override the original pipeline's config attributes
+        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
+        original_pipe_kwargs = {
+            k: original_config[k]
+            for k, v in original_config.items()
+            if k in optional_kwargs and k not in passed_pipe_kwargs
+        }
+
+        # config attributes that were not expected by the original pipeline are stored as private
+        # attributes; we will pass them as optional arguments if they can be accepted by the pipeline
+        additional_pipe_kwargs = [
+            k[1:]
+            for k in original_config.keys()
+            if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
+        ]
+        for k in additional_pipe_kwargs:
+            original_pipe_kwargs[k] = original_config.pop(f"_{k}")
+
+        text_to_video_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs}
+
+        # store unused config as private attributes
+        unused_original_config = {
+            f"{'' if k.startswith('_') else '_'}{k}": original_config[k]
+            for k, v in original_config.items()
+            if k not in text_to_video_kwargs
+        }
+
+        missing_modules = (
+            set(expected_modules) - set(text_to_video_cls._optional_components) - set(text_to_video_kwargs.keys())
+        )
+
+        if len(missing_modules) > 0:
+            raise ValueError(
+                f"Pipeline {text_to_video_cls} expected {expected_modules}, but only "
+                f"{set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed"
+            )
+
+        model = text_to_video_cls(**text_to_video_kwargs)
+        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+        model.register_to_config(**unused_original_config)
+
+        return model
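
Worth noting: `from_pipe` assembles the new pipeline from the original's already-instantiated components, so converting between pipeline views of the same checkpoint allocates no extra model memory. A minimal sketch using the tiny checkpoint from the new test below:

```py
from diffusers import AutoPipelineForText2Video

# "StableDiffusionPipeline" in the checkpoint config matches the new
# "stable-diffusion" mapping entry, resolving to TextToVideoSDPipeline.
pipe = AutoPipelineForText2Video.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")

# Rebuild a text-to-video pipeline from the existing one; modules are shared,
# not copied, so both objects point at the same underlying weights.
pipe2 = AutoPipelineForText2Video.from_pipe(pipe)
assert pipe2.unet is pipe.unet
```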

tests/pipelines/test_pipelines_auto.py

Lines changed: 10 additions & 0 deletions

@@ -27,6 +27,7 @@
     AutoPipelineForImage2Image,
     AutoPipelineForInpainting,
     AutoPipelineForText2Image,
+    AutoPipelineForText2Video,
     ControlNetModel,
     DiffusionPipeline,
 )

@@ -455,6 +456,15 @@ def test_from_pipe_optional_components(self):
         pipe = AutoPipelineForText2Image.from_pipe(pipe, image_encoder=None)
         assert pipe.image_encoder is None
 
+    def test_from_pretrained_text_to_video(self):
+        repo = "hf-internal-testing/tiny-stable-diffusion-pipe"
+
+        pipe = AutoPipelineForText2Video.from_pretrained(repo)
+        assert pipe.__class__.__name__ == "TextToVideoSDPipeline"
+
+        pipe = AutoPipelineForText2Video.from_pipe(pipe)
+        assert pipe.__class__.__name__ == "TextToVideoSDPipeline"
+
 
 @slow
 class AutoPipelineIntegrationTest(unittest.TestCase):
