Compare commits

...

6 Commits

Author SHA1 Message Date
zmey42
764fecbca3
Merge fe59c68fb4 into e5a741309d 2025-05-21 21:25:53 +07:00
Shiwei Zhang
e5a741309d
Update README.md (#406) 2025-05-17 10:57:06 +08:00
Ang Wang
76e9427657
Format the code (#402)
* isort the code

* format the code

* Add yapf config file

* Remove torch cuda memory profiler
2025-05-16 12:35:38 +08:00
Zhen Han
c709fcf0e7
fix vace size (#397) 2025-05-14 22:01:45 +08:00
Ang Wang
18d53feb7a
[feature] Add VACE (#389)
* Add VACE

* Support training with multiple gpus

* Update default args for vace task

* vace block update

* Add vace exmaple jpg

* Fix dist vace fwd hook error

* Update vace exmample

* Update vace args

* Update pipeline name for vace

* vace gradio and Readme

* Update vace snake png

---------

Co-authored-by: hanzhn <han.feng.jason@gmail.com>
2025-05-14 20:44:25 +08:00
zmey42
fe59c68fb4
Create Тг подарки 2025-03-14 20:51:20 +03:00
30 changed files with 2566 additions and 144 deletions

393
.style.yapf Normal file
View File

@ -0,0 +1,393 @@
[style]
# Align closing bracket with visual indentation.
align_closing_bracket_with_visual_indent=False
# Allow dictionary keys to exist on multiple lines. For example:
#
# x = {
# ('this is the first element of a tuple',
# 'this is the second element of a tuple'):
# value,
# }
allow_multiline_dictionary_keys=False
# Allow lambdas to be formatted on more than one line.
allow_multiline_lambdas=False
# Allow splitting before a default / named assignment in an argument list.
allow_split_before_default_or_named_assigns=False
# Allow splits before the dictionary value.
allow_split_before_dict_value=True
# Let spacing indicate operator precedence. For example:
#
# a = 1 * 2 + 3 / 4
# b = 1 / 2 - 3 * 4
# c = (1 + 2) * (3 - 4)
# d = (1 - 2) / (3 + 4)
# e = 1 * 2 - 3
# f = 1 + 2 + 3 + 4
#
# will be formatted as follows to indicate precedence:
#
# a = 1*2 + 3/4
# b = 1/2 - 3*4
# c = (1+2) * (3-4)
# d = (1-2) / (3+4)
# e = 1*2 - 3
# f = 1 + 2 + 3 + 4
#
arithmetic_precedence_indication=False
# Number of blank lines surrounding top-level function and class
# definitions.
blank_lines_around_top_level_definition=2
# Insert a blank line before a class-level docstring.
blank_line_before_class_docstring=False
# Insert a blank line before a module docstring.
blank_line_before_module_docstring=False
# Insert a blank line before a 'def' or 'class' immediately nested
# within another 'def' or 'class'. For example:
#
# class Foo:
# # <------ this blank line
# def method():
# ...
blank_line_before_nested_class_or_def=True
# Do not split consecutive brackets. Only relevant when
# dedent_closing_brackets is set. For example:
#
# call_func_that_takes_a_dict(
# {
# 'key1': 'value1',
# 'key2': 'value2',
# }
# )
#
# would reformat to:
#
# call_func_that_takes_a_dict({
# 'key1': 'value1',
# 'key2': 'value2',
# })
coalesce_brackets=False
# The column limit.
column_limit=80
# The style for continuation alignment. Possible values are:
#
# - SPACE: Use spaces for continuation alignment. This is default behavior.
# - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns
# (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs or
# CONTINUATION_INDENT_WIDTH spaces) for continuation alignment.
# - VALIGN-RIGHT: Vertically align continuation lines to multiple of
# INDENT_WIDTH columns. Slightly right (one tab or a few spaces) if
# cannot vertically align continuation lines with indent characters.
continuation_align_style=SPACE
# Indent width used for line continuations.
continuation_indent_width=4
# Put closing brackets on a separate line, dedented, if the bracketed
# expression can't fit in a single line. Applies to all kinds of brackets,
# including function definitions and calls. For example:
#
# config = {
# 'key1': 'value1',
# 'key2': 'value2',
# } # <--- this bracket is dedented and on a separate line
#
# time_series = self.remote_client.query_entity_counters(
# entity='dev3246.region1',
# key='dns.query_latency_tcp',
# transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
# start_ts=now()-timedelta(days=3),
# end_ts=now(),
# ) # <--- this bracket is dedented and on a separate line
dedent_closing_brackets=False
# Disable the heuristic which places each list element on a separate line
# if the list is comma-terminated.
disable_ending_comma_heuristic=False
# Place each dictionary entry onto its own line.
each_dict_entry_on_separate_line=True
# Require multiline dictionary even if it would normally fit on one line.
# For example:
#
# config = {
# 'key1': 'value1'
# }
force_multiline_dict=False
# The regex for an i18n comment. The presence of this comment stops
# reformatting of that line, because the comments are required to be
# next to the string they translate.
i18n_comment=#\..*
# The i18n function call names. The presence of this function stops
# reformattting on that line, because the string it has cannot be moved
# away from the i18n comment.
i18n_function_call=N_, _
# Indent blank lines.
indent_blank_lines=False
# Put closing brackets on a separate line, indented, if the bracketed
# expression can't fit in a single line. Applies to all kinds of brackets,
# including function definitions and calls. For example:
#
# config = {
# 'key1': 'value1',
# 'key2': 'value2',
# } # <--- this bracket is indented and on a separate line
#
# time_series = self.remote_client.query_entity_counters(
# entity='dev3246.region1',
# key='dns.query_latency_tcp',
# transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
# start_ts=now()-timedelta(days=3),
# end_ts=now(),
# ) # <--- this bracket is indented and on a separate line
indent_closing_brackets=False
# Indent the dictionary value if it cannot fit on the same line as the
# dictionary key. For example:
#
# config = {
# 'key1':
# 'value1',
# 'key2': value1 +
# value2,
# }
indent_dictionary_value=True
# The number of columns to use for indentation.
indent_width=4
# Join short lines into one line. E.g., single line 'if' statements.
join_multiple_lines=False
# Do not include spaces around selected binary operators. For example:
#
# 1 + 2 * 3 - 4 / 5
#
# will be formatted as follows when configured with "*,/":
#
# 1 + 2*3 - 4/5
no_spaces_around_selected_binary_operators=
# Use spaces around default or named assigns.
spaces_around_default_or_named_assign=False
# Adds a space after the opening '{' and before the ending '}' dict delimiters.
#
# {1: 2}
#
# will be formatted as:
#
# { 1: 2 }
spaces_around_dict_delimiters=False
# Adds a space after the opening '[' and before the ending ']' list delimiters.
#
# [1, 2]
#
# will be formatted as:
#
# [ 1, 2 ]
spaces_around_list_delimiters=False
# Use spaces around the power operator.
spaces_around_power_operator=False
# Use spaces around the subscript / slice operator. For example:
#
# my_list[1 : 10 : 2]
spaces_around_subscript_colon=False
# Adds a space after the opening '(' and before the ending ')' tuple delimiters.
#
# (1, 2, 3)
#
# will be formatted as:
#
# ( 1, 2, 3 )
spaces_around_tuple_delimiters=False
# The number of spaces required before a trailing comment.
# This can be a single value (representing the number of spaces
# before each trailing comment) or list of values (representing
# alignment column values; trailing comments within a block will
# be aligned to the first column value that is greater than the maximum
# line length within the block). For example:
#
# With spaces_before_comment=5:
#
# 1 + 1 # Adding values
#
# will be formatted as:
#
# 1 + 1 # Adding values <-- 5 spaces between the end of the statement and comment
#
# With spaces_before_comment=15, 20:
#
# 1 + 1 # Adding values
# two + two # More adding
#
# longer_statement # This is a longer statement
# short # This is a shorter statement
#
# a_very_long_statement_that_extends_beyond_the_final_column # Comment
# short # This is a shorter statement
#
# will be formatted as:
#
# 1 + 1 # Adding values <-- end of line comments in block aligned to col 15
# two + two # More adding
#
# longer_statement # This is a longer statement <-- end of line comments in block aligned to col 20
# short # This is a shorter statement
#
# a_very_long_statement_that_extends_beyond_the_final_column # Comment <-- the end of line comments are aligned based on the line length
# short # This is a shorter statement
#
spaces_before_comment=2
# Insert a space between the ending comma and closing bracket of a list,
# etc.
space_between_ending_comma_and_closing_bracket=False
# Use spaces inside brackets, braces, and parentheses. For example:
#
# method_call( 1 )
# my_dict[ 3 ][ 1 ][ get_index( *args, **kwargs ) ]
# my_set = { 1, 2, 3 }
space_inside_brackets=False
# Split before arguments
split_all_comma_separated_values=False
# Split before arguments, but do not split all subexpressions recursively
# (unless needed).
split_all_top_level_comma_separated_values=False
# Split before arguments if the argument list is terminated by a
# comma.
split_arguments_when_comma_terminated=False
# Set to True to prefer splitting before '+', '-', '*', '/', '//', or '@'
# rather than after.
split_before_arithmetic_operator=False
# Set to True to prefer splitting before '&', '|' or '^' rather than
# after.
split_before_bitwise_operator=False
# Split before the closing bracket if a list or dict literal doesn't fit on
# a single line.
split_before_closing_bracket=True
# Split before a dictionary or set generator (comp_for). For example, note
# the split before the 'for':
#
# foo = {
# variable: 'Hello world, have a nice day!'
# for variable in bar if variable != 42
# }
split_before_dict_set_generator=False
# Split before the '.' if we need to split a longer expression:
#
# foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d))
#
# would reformat to something like:
#
# foo = ('This is a really long string: {}, {}, {}, {}'
# .format(a, b, c, d))
split_before_dot=False
# Split after the opening paren which surrounds an expression if it doesn't
# fit on a single line.
split_before_expression_after_opening_paren=True
# If an argument / parameter list is going to be split, then split before
# the first argument.
split_before_first_argument=False
# Set to True to prefer splitting before 'and' or 'or' rather than
# after.
split_before_logical_operator=False
# Split named assignments onto individual lines.
split_before_named_assigns=True
# Set to True to split list comprehensions and generators that have
# non-trivial expressions and multiple clauses before each of these
# clauses. For example:
#
# result = [
# a_long_var + 100 for a_long_var in xrange(1000)
# if a_long_var % 10]
#
# would reformat to something like:
#
# result = [
# a_long_var + 100
# for a_long_var in xrange(1000)
# if a_long_var % 10]
split_complex_comprehension=True
# The penalty for splitting right after the opening bracket.
split_penalty_after_opening_bracket=300
# The penalty for splitting the line after a unary operator.
split_penalty_after_unary_operator=10000
# The penalty of splitting the line around the '+', '-', '*', '/', '//',
# ``%``, and '@' operators.
split_penalty_arithmetic_operator=300
# The penalty for splitting right before an if expression.
split_penalty_before_if_expr=0
# The penalty of splitting the line around the '&', '|', and '^'
# operators.
split_penalty_bitwise_operator=300
# The penalty for splitting a list comprehension or generator
# expression.
split_penalty_comprehension=2100
# The penalty for characters over the column limit.
split_penalty_excess_character=7000
# The penalty incurred by adding a line split to the unwrapped line. The
# more line splits added the higher the penalty.
split_penalty_for_added_line_split=30
# The penalty of splitting a list of "import as" names. For example:
#
# from a_very_long_or_indented_module_name_yada_yad import (long_argument_1,
# long_argument_2,
# long_argument_3)
#
# would reformat to something like:
#
# from a_very_long_or_indented_module_name_yada_yad import (
# long_argument_1, long_argument_2, long_argument_3)
split_penalty_import_names=0
# The penalty of splitting the line around the 'and' and 'or'
# operators.
split_penalty_logical_operator=300
# Use the Tab character for indentation.
use_tabs=False

5
Makefile Normal file
View File

@ -0,0 +1,5 @@
.PHONY: format
format:
isort generate.py gradio wan
yapf -i -r *.py generate.py gradio wan

View File

@ -27,6 +27,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!!
* May 14, 2025: 👋 We introduce **Wan2.1** [VACE](https://github.com/ali-vilab/VACE), an all-in-one model for video creation and editing, along with its [inference code](#run-vace), [weights](#model-download), and [technical report](https://arxiv.org/abs/2503.07598)!
* Apr 17, 2025: 👋 We introduce **Wan2.1** [FLF2V](#run-first-last-frame-to-video-generation) with its inference code and weights!
* Mar 21, 2025: 👋 We are excited to announce the release of the **Wan2.1** [technical report](https://files.alicdn.com/tpsservice/5c9de1c74de03972b7aa657e5a54756b.pdf). We welcome discussions and feedback!
* Mar 3, 2025: 👋 **Wan2.1**'s T2V and I2V have been integrated into Diffusers ([T2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanPipeline) | [I2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanImageToVideoPipeline)). Feel free to give it a try!
@ -64,7 +65,13 @@ If your work has improved **Wan2.1** and you would like more people to see it, p
- [ ] ComfyUI integration
- [ ] Diffusers integration
- [ ] Diffusers + Multi-GPU Inference
- Wan2.1 VACE
- [x] Multi-GPU Inference code of the 14B and 1.3B models
- [x] Checkpoints of the 14B and 1.3B models
- [x] Gradio demo
- [x] ComfyUI integration
- [ ] Diffusers integration
- [ ] Diffusers + Multi-GPU Inference
## Quickstart
@ -85,12 +92,14 @@ pip install -r requirements.txt
#### Model Download
| Models | Download Link | Notes |
|--------------|-----------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------|
|--------------|---------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------|
| T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P
| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P
| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P
| T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P
| FLF2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | Supports 720P
| VACE-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | Supports 480P
| VACE-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | Supports both 480P and 720P
> 💡Note:
> * The 1.3B model is capable of generating videos at 720P resolution. However, due to limited training at this resolution, the results are generally less stable compared to 480P. For optimal performance, we recommend using 480P resolution.
@ -448,6 +457,73 @@ DASH_API_KEY=your_key python flf2v_14B_singleGPU.py --prompt_extend_method 'dash
```
#### Run VACE
[VACE](https://github.com/ali-vilab/VACE) now supports two models (1.3B and 14B) and two main resolutions (480P and 720P).
The input supports any resolution, but to achieve optimal results, the video size should fall within a specific range.
The parameters and configurations for these models are as follows:
<table>
<thead>
<tr>
<th rowspan="2">Task</th>
<th colspan="2">Resolution</th>
<th rowspan="2">Model</th>
</tr>
<tr>
<th>480P(~81x480x832)</th>
<th>720P(~81x720x1280)</th>
</tr>
</thead>
<tbody>
<tr>
<td>VACE</td>
<td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
<td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
<td>Wan2.1-VACE-14B</td>
</tr>
<tr>
<td>VACE</td>
<td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
<td style="color: red; text-align: center; vertical-align: middle;"></td>
<td>Wan2.1-VACE-1.3B</td>
</tr>
</tbody>
</table>
In VACE, users can input text prompt and optional video, mask, and image for video generation or editing. Detailed instructions for using VACE can be found in the [User Guide](https://github.com/ali-vilab/VACE/blob/main/UserGuide.md).
The execution process is as follows:
##### (1) Preprocessing
User-collected materials needs to be preprocessed into VACE-recognizable inputs, including `src_video`, `src_mask`, `src_ref_images`, and `prompt`.
For R2V (Reference-to-Video Generation), you may skip this preprocessing, but for V2V (Video-to-Video Editing) and MV2V (Masked Video-to-Video Editing) tasks, additional preprocessing is required to obtain video with conditions such as depth, pose or masked regions.
For more details, please refer to [vace_preproccess](https://github.com/ali-vilab/VACE/blob/main/vace/vace_preproccess.py).
##### (2) cli inference
- Single-GPU inference
```sh
python generate.py --task vace-1.3B --size 832*480 --ckpt_dir ./Wan2.1-VACE-1.3B --src_ref_images examples/girl.png,examples/snake.png --prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
```
- Multi-GPU inference using FSDP + xDiT USP
```sh
torchrun --nproc_per_node=8 generate.py --task vace-14B --size 1280*720 --ckpt_dir ./Wan2.1-VACE-14B --dit_fsdp --t5_fsdp --ulysses_size 8 --src_ref_images examples/girl.png,examples/snake.png --prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
```
##### (3) Running local gradio
- Single-GPU inference
```sh
python gradio/vace.py --ckpt_dir ./Wan2.1-VACE-1.3B
```
- Multi-GPU inference using FSDP + xDiT USP
```sh
python gradio/vace.py --mp --ulysses_size 8 --ckpt_dir ./Wan2.1-VACE-14B/
```
#### Run Text-to-Image Generation
Wan2.1 is a unified model for both image and video generation. Since it was trained on both types of data, it can also generate images. The command for generating images is similar to video generation, as follows:
@ -567,7 +643,7 @@ If you find our work helpful, please cite us.
```
@article{wan2025,
title={Wan: Open and Advanced Large-Scale Video Generative Models},
author={Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
author={Team Wan and Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
journal = {arXiv preprint arXiv:2503.20314},
year={2025}
}

BIN
examples/girl.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 817 KiB

BIN
examples/snake.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 435 KiB

View File

@ -1,28 +1,33 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
from datetime import datetime
import logging
import os
import sys
import warnings
from datetime import datetime
warnings.filterwarnings('ignore')
import torch, random
import random
import torch
import torch.distributed as dist
from PIL import Image
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video, cache_image, str2bool
from wan.utils.utils import cache_image, cache_video, str2bool
EXAMPLE_PROMPT = {
"t2v-1.3B": {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"prompt":
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"t2v-14B": {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"prompt":
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"t2i-14B": {
"prompt": "一个朴素端庄的美人",
@ -41,6 +46,18 @@ EXAMPLE_PROMPT = {
"last_frame":
"examples/flf2v_input_last_frame.png",
},
"vace-1.3B": {
"src_ref_images":
'examples/girl.png,examples/snake.png',
"prompt":
"在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
},
"vace-14B": {
"src_ref_images":
'examples/girl.png,examples/snake.png',
"prompt":
"在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
}
}
@ -52,13 +69,15 @@ def _validate_args(args):
# The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
if args.sample_steps is None:
args.sample_steps = 40 if "i2v" in args.task else 50
args.sample_steps = 50
if "i2v" in args.task:
args.sample_steps = 40
if args.sample_shift is None:
args.sample_shift = 5.0
if "i2v" in args.task and args.size in ["832*480", "480*832"]:
args.sample_shift = 3.0
if "flf2v" in args.task:
elif "flf2v" in args.task or "vace" in args.task:
args.sample_shift = 16
# The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
@ -141,6 +160,22 @@ def _parse_args():
type=str,
default=None,
help="The file to save the generated image or video to.")
parser.add_argument(
"--src_video",
type=str,
default=None,
help="The file of the source video. Default None.")
parser.add_argument(
"--src_mask",
type=str,
default=None,
help="The file of the source mask. Default None.")
parser.add_argument(
"--src_ref_images",
type=str,
default=None,
help="The file list of the source reference images. Separated by ','. Default None."
)
parser.add_argument(
"--prompt",
type=str,
@ -182,12 +217,14 @@ def _parse_args():
"--first_frame",
type=str,
default=None,
help="[first-last frame to video] The image (first frame) to generate the video from.")
help="[first-last frame to video] The image (first frame) to generate the video from."
)
parser.add_argument(
"--last_frame",
type=str,
default=None,
help="[first-last frame to video] The image (last frame) to generate the video from.")
help="[first-last frame to video] The image (last frame) to generate the video from."
)
parser.add_argument(
"--sample_solver",
type=str,
@ -254,8 +291,10 @@ def generate(args):
if args.ulysses_size > 1 or args.ring_size > 1:
assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size."
from xfuser.core.distributed import (initialize_model_parallel,
init_distributed_environment)
from xfuser.core.distributed import (
init_distributed_environment,
initialize_model_parallel,
)
init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size())
@ -268,7 +307,8 @@ def generate(args):
if args.use_prompt_extend:
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl="i2v" in args.task or "flf2v" in args.task)
model_name=args.prompt_extend_model,
is_vl="i2v" in args.task or "flf2v" in args.task)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model,
@ -397,7 +437,7 @@ def generate(args):
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
else:
elif "flf2v" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.first_frame is None or args.last_frame is None:
@ -455,9 +495,65 @@ def generate(args):
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model
offload_model=args.offload_model)
elif "vace" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None)
args.src_mask = EXAMPLE_PROMPT[args.task].get("src_mask", None)
args.src_ref_images = EXAMPLE_PROMPT[args.task].get(
"src_ref_images", None)
logging.info(f"Input prompt: {args.prompt}")
if args.use_prompt_extend and args.use_prompt_extend != 'plain':
logging.info("Extending prompt ...")
if rank == 0:
prompt = prompt_expander.forward(args.prompt)
logging.info(
f"Prompt extended from '{args.prompt}' to '{prompt}'")
input_prompt = [prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating VACE pipeline.")
wan_vace = wan.WanVace(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
src_video, src_mask, src_ref_images = wan_vace.prepare_source(
[args.src_video], [args.src_mask], [
None if args.src_ref_images is None else
args.src_ref_images.split(',')
], args.frame_num, SIZE_CONFIGS[args.size], device)
logging.info(f"Generating video...")
video = wan_vace.generate(
args.prompt,
src_video,
src_mask,
src_ref_images,
size=SIZE_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
else:
raise ValueError(f"Unkown task type: {args.task}")
if rank == 0:
if args.save_file is None:
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")

View File

@ -1,8 +1,8 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import gc
import os.path as osp
import os
import os.path as osp
import sys
import warnings
@ -11,7 +11,8 @@ import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
sys.path.insert(
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
@ -69,13 +70,13 @@ def prompt_enc(prompt, img_first, img_last, tar_lang):
return prompt_output.prompt
def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps,
guide_scale, shift_scale, seed, n_prompt):
def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
resolution, sd_steps, guide_scale, shift_scale, seed,
n_prompt):
if resolution == '------':
print(
'Please specify the resolution ckpt dir or specify the resolution'
)
'Please specify the resolution ckpt dir or specify the resolution')
return None
else:
@ -94,9 +95,7 @@ def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, re
offload_model=True)
pass
else:
print(
'Sorry, currently only 720P is supported.'
)
print('Sorry, currently only 720P is supported.')
return None
cache_video(
@ -191,14 +190,17 @@ def gradio_interface():
run_p_button.click(
fn=prompt_enc,
inputs=[flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, tar_lang],
inputs=[
flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
tar_lang
],
outputs=[flf2vid_prompt])
run_flf2v_button.click(
fn=flf2v_generation,
inputs=[
flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps,
guide_scale, shift_scale, seed, n_prompt
flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
resolution, sd_steps, guide_scale, shift_scale, seed, n_prompt
],
outputs=[result_gallery],
)

View File

@ -1,8 +1,8 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import gc
import os.path as osp
import os
import os.path as osp
import sys
import warnings
@ -11,7 +11,8 @@ import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
sys.path.insert(
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander

View File

@ -1,7 +1,7 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import os.path as osp
import sys
import warnings
@ -10,7 +10,8 @@ import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
sys.path.insert(
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander

View File

@ -1,7 +1,7 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import os.path as osp
import sys
import warnings
@ -10,7 +10,8 @@ import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
sys.path.insert(
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander

View File

@ -1,7 +1,7 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import os.path as osp
import sys
import warnings
@ -10,7 +10,8 @@ import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
sys.path.insert(
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander

349
gradio/vace.py Normal file
View File

@ -0,0 +1,349 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import datetime
import os
import sys
import imageio
import numpy as np
import torch
import gradio as gr
sys.path.insert(
0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan import WanVace, WanVaceMP
from wan.configs import SIZE_CONFIGS, WAN_CONFIGS
class FixedSizeQueue:
def __init__(self, max_size):
self.max_size = max_size
self.queue = []
def add(self, item):
self.queue.insert(0, item)
if len(self.queue) > self.max_size:
self.queue.pop()
def get(self):
return self.queue
def __repr__(self):
return str(self.queue)
class VACEInference:
def __init__(self,
cfg,
skip_load=False,
gallery_share=True,
gallery_share_limit=5):
self.cfg = cfg
self.save_dir = cfg.save_dir
self.gallery_share = gallery_share
self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
if not skip_load:
if not args.mp:
self.pipe = WanVace(
config=WAN_CONFIGS[cfg.model_name],
checkpoint_dir=cfg.ckpt_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
)
else:
self.pipe = WanVaceMP(
config=WAN_CONFIGS[cfg.model_name],
checkpoint_dir=cfg.ckpt_dir,
use_usp=True,
ulysses_size=cfg.ulysses_size,
ring_size=cfg.ring_size)
def create_ui(self, *args, **kwargs):
gr.Markdown("""
<div style="text-align: center; font-size: 24px; font-weight: bold; margin-bottom: 15px;">
<a href="https://ali-vilab.github.io/VACE-Page/" style="text-decoration: none; color: inherit;">VACE-WAN Demo</a>
</div>
""")
with gr.Row(variant='panel', equal_height=True):
with gr.Column(scale=1, min_width=0):
self.src_video = gr.Video(
label="src_video",
sources=['upload'],
value=None,
interactive=True)
with gr.Column(scale=1, min_width=0):
self.src_mask = gr.Video(
label="src_mask",
sources=['upload'],
value=None,
interactive=True)
#
with gr.Row(variant='panel', equal_height=True):
with gr.Column(scale=1, min_width=0):
with gr.Row(equal_height=True):
self.src_ref_image_1 = gr.Image(
label='src_ref_image_1',
height=200,
interactive=True,
type='filepath',
image_mode='RGB',
sources=['upload'],
elem_id="src_ref_image_1",
format='png')
self.src_ref_image_2 = gr.Image(
label='src_ref_image_2',
height=200,
interactive=True,
type='filepath',
image_mode='RGB',
sources=['upload'],
elem_id="src_ref_image_2",
format='png')
self.src_ref_image_3 = gr.Image(
label='src_ref_image_3',
height=200,
interactive=True,
type='filepath',
image_mode='RGB',
sources=['upload'],
elem_id="src_ref_image_3",
format='png')
with gr.Row(variant='panel', equal_height=True):
with gr.Column(scale=1):
self.prompt = gr.Textbox(
show_label=False,
placeholder="positive_prompt_input",
elem_id='positive_prompt',
container=True,
autofocus=True,
elem_classes='type_row',
visible=True,
lines=2)
self.negative_prompt = gr.Textbox(
show_label=False,
value=self.pipe.config.sample_neg_prompt,
placeholder="negative_prompt_input",
elem_id='negative_prompt',
container=True,
autofocus=False,
elem_classes='type_row',
visible=True,
interactive=True,
lines=1)
#
with gr.Row(variant='panel', equal_height=True):
with gr.Column(scale=1, min_width=0):
with gr.Row(equal_height=True):
self.shift_scale = gr.Slider(
label='shift_scale',
minimum=0.0,
maximum=100.0,
step=1.0,
value=16.0,
interactive=True)
self.sample_steps = gr.Slider(
label='sample_steps',
minimum=1,
maximum=100,
step=1,
value=25,
interactive=True)
self.context_scale = gr.Slider(
label='context_scale',
minimum=0.0,
maximum=2.0,
step=0.1,
value=1.0,
interactive=True)
self.guide_scale = gr.Slider(
label='guide_scale',
minimum=1,
maximum=10,
step=0.5,
value=5.0,
interactive=True)
self.infer_seed = gr.Slider(
minimum=-1, maximum=10000000, value=2025, label="Seed")
#
with gr.Accordion(label="Usable without source video", open=False):
with gr.Row(equal_height=True):
self.output_height = gr.Textbox(
label='resolutions_height',
# value=480,
value=720,
interactive=True)
self.output_width = gr.Textbox(
label='resolutions_width',
# value=832,
value=1280,
interactive=True)
self.frame_rate = gr.Textbox(
label='frame_rate', value=16, interactive=True)
self.num_frames = gr.Textbox(
label='num_frames', value=81, interactive=True)
#
with gr.Row(equal_height=True):
with gr.Column(scale=5):
self.generate_button = gr.Button(
value='Run',
elem_classes='type_row',
elem_id='generate_button',
visible=True)
with gr.Column(scale=1):
self.refresh_button = gr.Button(value='\U0001f504') # 🔄
#
self.output_gallery = gr.Gallery(
label="output_gallery",
value=[],
interactive=False,
allow_preview=True,
preview=True)
def generate(self, output_gallery, src_video, src_mask, src_ref_image_1,
src_ref_image_2, src_ref_image_3, prompt, negative_prompt,
shift_scale, sample_steps, context_scale, guide_scale,
infer_seed, output_height, output_width, frame_rate,
num_frames):
output_height, output_width, frame_rate, num_frames = int(
output_height), int(output_width), int(frame_rate), int(num_frames)
src_ref_images = [
x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3]
if x is not None
]
src_video, src_mask, src_ref_images = self.pipe.prepare_source(
[src_video], [src_mask], [src_ref_images],
num_frames=num_frames,
image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"],
device=self.pipe.device)
video = self.pipe.generate(
prompt,
src_video,
src_mask,
src_ref_images,
size=(output_width, output_height),
context_scale=context_scale,
shift=shift_scale,
sampling_steps=sample_steps,
guide_scale=guide_scale,
n_prompt=negative_prompt,
seed=infer_seed,
offload_model=True)
name = '{0:%Y%m%d%-H%M%S}'.format(datetime.datetime.now())
video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
video_frames = (
torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) *
255).cpu().numpy().astype(np.uint8)
try:
writer = imageio.get_writer(
video_path,
fps=frame_rate,
codec='libx264',
quality=8,
macro_block_size=1)
for frame in video_frames:
writer.append_data(frame)
writer.close()
print(video_path)
except Exception as e:
raise gr.Error(f"Video save error: {e}")
if self.gallery_share:
self.gallery_share_data.add(video_path)
return self.gallery_share_data.get()
else:
return [video_path]
def set_callbacks(self, **kwargs):
self.gen_inputs = [
self.output_gallery, self.src_video, self.src_mask,
self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3,
self.prompt, self.negative_prompt, self.shift_scale,
self.sample_steps, self.context_scale, self.guide_scale,
self.infer_seed, self.output_height, self.output_width,
self.frame_rate, self.num_frames
]
self.gen_outputs = [self.output_gallery]
self.generate_button.click(
self.generate,
inputs=self.gen_inputs,
outputs=self.gen_outputs,
queue=True)
self.refresh_button.click(
lambda x: self.gallery_share_data.get()
if self.gallery_share else x,
inputs=[self.output_gallery],
outputs=[self.output_gallery])
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Argparser for VACE-WAN Demo:\n')
parser.add_argument(
'--server_port', dest='server_port', help='', type=int, default=7860)
parser.add_argument(
'--server_name', dest='server_name', help='', default='0.0.0.0')
parser.add_argument('--root_path', dest='root_path', help='', default=None)
parser.add_argument('--save_dir', dest='save_dir', help='', default='cache')
parser.add_argument(
"--mp",
action="store_true",
help="Use Multi-GPUs",
)
parser.add_argument(
"--model_name",
type=str,
default="vace-14B",
choices=list(WAN_CONFIGS.keys()),
help="The model name to run.")
parser.add_argument(
"--ulysses_size",
type=int,
default=1,
help="The size of the ulysses parallelism in DiT.")
parser.add_argument(
"--ring_size",
type=int,
default=1,
help="The size of the ring attention parallelism in DiT.")
parser.add_argument(
"--ckpt_dir",
type=str,
# default='models/VACE-Wan2.1-1.3B-Preview',
default='models/Wan2.1-VACE-14B/',
help="The path to the checkpoint directory.",
)
parser.add_argument(
"--offload_to_cpu",
action="store_true",
help="Offloading unnecessary computations to CPU.",
)
args = parser.parse_args()
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir, exist_ok=True)
with gr.Blocks() as demo:
infer_gr = VACEInference(
args, skip_load=False, gallery_share=True, gallery_share_limit=5)
infer_gr.create_ui()
infer_gr.set_callbacks()
allowed_paths = [args.save_dir]
demo.queue(status_update_rate=1).launch(
server_name=args.server_name,
server_port=args.server_port,
root_path=args.root_path,
allowed_paths=allowed_paths,
show_error=True,
debug=True)

View File

@ -105,9 +105,16 @@ function i2v_14B_720p() {
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
}
function vace_1_3B() {
VACE_1_3B_CKPT_DIR="$MODEL_DIR/VACE-Wan2.1-1.3B-Preview/"
torchrun --nproc_per_node=$GPUS $PY_FILE --ulysses_size $GPUS --task vace-1.3B --size 480*832 --ckpt_dir $VACE_1_3B_CKPT_DIR
}
t2i_14B
t2v_1_3B
t2v_14B
i2v_14B_480p
i2v_14B_720p
vace_1_3B

View File

@ -1,4 +1,5 @@
from . import configs, distributed, modules
from .first_last_frame2video import WanFLF2V
from .image2video import WanI2V
from .text2video import WanT2V
from .first_last_frame2video import WanFLF2V
from .vace import WanVace, WanVaceMP

View File

@ -22,7 +22,9 @@ WAN_CONFIGS = {
't2v-1.3B': t2v_1_3B,
'i2v-14B': i2v_14B,
't2i-14B': t2i_14B,
'flf2v-14B': flf2v_14B
'flf2v-14B': flf2v_14B,
'vace-1.3B': t2v_1_3B,
'vace-14B': t2v_14B,
}
SIZE_CONFIGS = {
@ -46,4 +48,6 @@ SUPPORTED_SIZES = {
'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
't2i-14B': tuple(SIZE_CONFIGS.keys()),
'vace-1.3B': ('480*832', '832*480'),
'vace-14B': ('720*1280', '1280*720', '480*832', '832*480')
}

View File

@ -8,6 +8,7 @@ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage
def shard_model(
model,
device_id,
@ -32,6 +33,7 @@ def shard_model(
sync_module_states=sync_module_states)
return model
def free_model(model):
for m in model.modules():
if isinstance(m, FSDP):

View File

@ -1,9 +1,11 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp
from xfuser.core.distributed import (get_sequence_parallel_rank,
from xfuser.core.distributed import (
get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
get_sp_group,
)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ..modules.model import sinusoidal_embedding_1d
@ -63,12 +65,39 @@ def rope_apply(x, grid_sizes, freqs):
return torch.stack(output).float()
def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs):
# embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
for u in c
])
# arguments
new_kwargs = dict(x=x)
new_kwargs.update(kwargs)
# Context Parallel
c = torch.chunk(
c, get_sequence_parallel_world_size(),
dim=1)[get_sequence_parallel_rank()]
hints = []
for block in self.vace_blocks:
c, c_skip = block(c, **new_kwargs)
hints.append(c_skip)
return hints
def usp_dit_forward(
self,
x,
t,
context,
seq_len,
vace_context=None,
vace_context_scale=1.0,
clip_fea=None,
y=None,
):
@ -84,7 +113,7 @@ def usp_dit_forward(
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
if self.model_type != 'vace' and y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
@ -114,7 +143,7 @@ def usp_dit_forward(
for u in context
]))
if clip_fea is not None:
if self.model_type != 'vace' and clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
@ -132,6 +161,11 @@ def usp_dit_forward(
x, get_sequence_parallel_world_size(),
dim=1)[get_sequence_parallel_rank()]
if self.model_type == 'vace':
hints = self.forward_vace(x, vace_context, seq_len, kwargs)
kwargs['hints'] = hints
kwargs['context_scale'] = vace_context_scale
for block in self.blocks:
x = block(x, **kwargs)

View File

@ -21,8 +21,11 @@ from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
@ -103,11 +106,12 @@ class WanFLF2V:
init_on_cpu = False
if use_usp:
from xfuser.core.distributed import \
get_sequence_parallel_world_size
from xfuser.core.distributed import get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward)
from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward,
)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
@ -181,8 +185,10 @@ class WanFLF2V:
"""
first_frame_size = first_frame.size
last_frame_size = last_frame.size
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(self.device)
last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(self.device)
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
self.device)
last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(
self.device)
F = frame_num
first_frame_h, first_frame_w = first_frame.shape[1:]
@ -199,8 +205,7 @@ class WanFLF2V:
# 1. resize
last_frame_resize_ratio = max(
first_frame_size[0] / last_frame_size[0],
first_frame_size[1] / last_frame_size[1]
)
first_frame_size[1] / last_frame_size[1])
last_frame_size = [
round(last_frame_size[0] * last_frame_resize_ratio),
round(last_frame_size[1] * last_frame_resize_ratio),
@ -216,8 +221,7 @@ class WanFLF2V:
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
noise = torch.randn(
16,
(F - 1) // 4 + 1,
16, (F - 1) // 4 + 1,
lat_h,
lat_w,
dtype=torch.float32,
@ -226,7 +230,10 @@ class WanFLF2V:
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
msk[:, 1:-1] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = torch.concat([
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
],
dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
@ -247,7 +254,8 @@ class WanFLF2V:
context_null = [t.to(self.device) for t in context_null]
self.clip.model.to(self.device)
clip_context = self.clip.visual([first_frame[:, None, :, :], last_frame[:, None, :, :]])
clip_context = self.clip.visual(
[first_frame[:, None, :, :], last_frame[:, None, :, :]])
if offload_model:
self.clip.model.cpu()
@ -256,15 +264,14 @@ class WanFLF2V:
torch.nn.functional.interpolate(
first_frame[None].cpu(),
size=(first_frame_h, first_frame_w),
mode='bicubic'
).transpose(0, 1),
mode='bicubic').transpose(0, 1),
torch.zeros(3, F - 2, first_frame_h, first_frame_w),
torch.nn.functional.interpolate(
last_frame[None].cpu(),
size=(first_frame_h, first_frame_w),
mode='bicubic'
).transpose(0, 1),
], dim=1).to(self.device)
mode='bicubic').transpose(0, 1),
],
dim=1).to(self.device)
])[0]
y = torch.concat([msk, y])

View File

@ -21,8 +21,11 @@ from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
@ -103,11 +106,12 @@ class WanI2V:
init_on_cpu = False
if use_usp:
from xfuser.core.distributed import \
get_sequence_parallel_world_size
from xfuser.core.distributed import get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward)
from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward,
)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
@ -196,8 +200,7 @@ class WanI2V:
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
noise = torch.randn(
16,
(F - 1) // 4 + 1,
16, (F - 1) // 4 + 1,
lat_h,
lat_w,
dtype=torch.float32,

View File

@ -2,11 +2,13 @@ from .attention import flash_attention
from .model import WanModel
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vace_model import VaceWanModel
from .vae import WanVAE
__all__ = [
'WanVAE',
'WanModel',
'VaceWanModel',
'T5Model',
'T5Encoder',
'T5Decoder',

View File

@ -357,7 +357,8 @@ class MLPProj(torch.nn.Module):
torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
torch.nn.LayerNorm(out_dim))
if flf_pos_emb: # NOTE: we only use this for `flf2v`
self.emb_pos = nn.Parameter(torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))
self.emb_pos = nn.Parameter(
torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))
def forward(self, image_embeds):
if hasattr(self, 'emb_pos'):
@ -400,7 +401,7 @@ class WanModel(ModelMixin, ConfigMixin):
Args:
model_type (`str`, *optional*, defaults to 't2v'):
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video)
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
text_len (`int`, *optional*, defaults to 512):
@ -433,7 +434,7 @@ class WanModel(ModelMixin, ConfigMixin):
super().__init__()
assert model_type in ['t2v', 'i2v', 'flf2v']
assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
self.model_type = model_type
self.patch_size = patch_size

250
wan/modules/vace_model.py Normal file
View File

@ -0,0 +1,250 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import register_to_config
from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d
class VaceWanAttentionBlock(WanAttentionBlock):
def __init__(self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
block_id=0):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size,
qk_norm, cross_attn_norm, eps)
self.block_id = block_id
if block_id == 0:
self.before_proj = nn.Linear(self.dim, self.dim)
nn.init.zeros_(self.before_proj.weight)
nn.init.zeros_(self.before_proj.bias)
self.after_proj = nn.Linear(self.dim, self.dim)
nn.init.zeros_(self.after_proj.weight)
nn.init.zeros_(self.after_proj.bias)
def forward(self, c, x, **kwargs):
if self.block_id == 0:
c = self.before_proj(c) + x
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
return c, c_skip
class BaseWanAttentionBlock(WanAttentionBlock):
def __init__(self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
block_id=None):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size,
qk_norm, cross_attn_norm, eps)
self.block_id = block_id
def forward(self, x, hints, context_scale=1.0, **kwargs):
x = super().forward(x, **kwargs)
if self.block_id is not None:
x = x + hints[self.block_id] * context_scale
return x
class VaceWanModel(WanModel):
@register_to_config
def __init__(self,
vace_layers=None,
vace_in_dim=None,
model_type='vace',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6):
super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim,
freq_dim, text_dim, out_dim, num_heads, num_layers,
window_size, qk_norm, cross_attn_norm, eps)
self.vace_layers = [i for i in range(0, self.num_layers, 2)
] if vace_layers is None else vace_layers
self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
assert 0 in self.vace_layers
self.vace_layers_mapping = {
i: n for n, i in enumerate(self.vace_layers)
}
# blocks
self.blocks = nn.ModuleList([
BaseWanAttentionBlock(
't2v_cross_attn',
self.dim,
self.ffn_dim,
self.num_heads,
self.window_size,
self.qk_norm,
self.cross_attn_norm,
self.eps,
block_id=self.vace_layers_mapping[i]
if i in self.vace_layers else None)
for i in range(self.num_layers)
])
# vace blocks
self.vace_blocks = nn.ModuleList([
VaceWanAttentionBlock(
't2v_cross_attn',
self.dim,
self.ffn_dim,
self.num_heads,
self.window_size,
self.qk_norm,
self.cross_attn_norm,
self.eps,
block_id=i) for i in self.vace_layers
])
# vace patch embeddings
self.vace_patch_embedding = nn.Conv3d(
self.vace_in_dim,
self.dim,
kernel_size=self.patch_size,
stride=self.patch_size)
def forward_vace(self, x, vace_context, seq_len, kwargs):
# embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in c
])
# arguments
new_kwargs = dict(x=x)
new_kwargs.update(kwargs)
hints = []
for block in self.vace_blocks:
c, c_skip = block(c, **new_kwargs)
hints.append(c_skip)
return hints
def forward(
self,
x,
t,
vace_context,
context,
seq_len,
vace_context_scale=1.0,
clip_fea=None,
y=None,
):
r"""
Forward pass through the diffusion model
Args:
x (List[Tensor]):
List of input video tensors, each with shape [C_in, F, H, W]
t (Tensor):
Diffusion timesteps tensor of shape [B]
context (List[Tensor]):
List of text embeddings each with shape [L, C]
seq_len (`int`):
Maximum sequence length for positional encoding
clip_fea (Tensor, *optional*):
CLIP image features for image-to-video mode
y (List[Tensor], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
Returns:
List[Tensor]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
# if self.model_type == 'i2v':
# assert clip_fea is not None and y is not None
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
# if y is not None:
# x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in x
])
# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float())
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
# if clip_fea is not None:
# context_clip = self.img_emb(clip_fea) # bs x 257 x dim
# context = torch.concat([context_clip, context], dim=1)
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens)
hints = self.forward_vace(x, vace_context, seq_len, kwargs)
kwargs['hints'] = hints
kwargs['context_scale'] = vace_context_scale
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]

View File

@ -18,8 +18,11 @@ from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
@ -85,11 +88,12 @@ class WanT2V:
self.model.eval().requires_grad_(False)
if use_usp:
from xfuser.core.distributed import \
get_sequence_parallel_world_size
from xfuser.core.distributed import get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward)
from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward,
)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)

View File

@ -1,8 +1,13 @@
from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
retrieve_timesteps)
from .fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
from .vace_processor import VaceVideoProcessor
__all__ = [
'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler',
'VaceVideoProcessor'
]

View File

@ -9,9 +9,11 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
from diffusers.schedulers.scheduling_utils import (
KarrasDiffusionSchedulers,
SchedulerMixin,
SchedulerOutput)
SchedulerOutput,
)
from diffusers.utils import deprecate, is_scipy_available
from diffusers.utils.torch_utils import randn_tensor

View File

@ -8,9 +8,11 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
from diffusers.schedulers.scheduling_utils import (
KarrasDiffusionSchedulers,
SchedulerMixin,
SchedulerOutput)
SchedulerOutput,
)
from diffusers.utils import deprecate, is_scipy_available
if is_scipy_available():

View File

@ -7,7 +7,7 @@ import sys
import tempfile
from dataclasses import dataclass
from http import HTTPStatus
from typing import Optional, Union, List
from typing import List, Optional, Union
import dashscope
import torch
@ -96,7 +96,6 @@ VL_EN_SYS_PROMPT = \
'''4. In the style of an American drama promotional poster, Walter White sits in a metal folding chair wearing a yellow protective suit, with the words "Breaking Bad" written in sans-serif English above him, surrounded by piles of dollar bills and blue plastic storage boxes. He wears glasses, staring forward, dressed in a yellow jumpsuit, with his hands resting on his knees, exuding a calm and confident demeanor. The background shows an abandoned, dim factory with light filtering through the windows. Theres a noticeable grainy texture. A medium shot with a straight-on close-up of the character.\n''' \
'''Directly output the rewritten English text.'''
VL_ZH_SYS_PROMPT_FOR_MULTI_IMAGES = """你是一位Prompt优化师旨在参考用户输入的图像的细节内容把用户输入的Prompt改写为优质Prompt使其更完整、更具表现力同时不改变原意。你需要综合用户输入的照片内容和输入的Prompt进行改写严格参考示例的格式进行改写
任务要求
1. 用户会输入两张图片第一张是视频的第一帧第二张时视频的最后一帧你需要综合两个照片的内容进行优化改写
@ -198,8 +197,8 @@ class PromptExpander:
if system_prompt is None:
system_prompt = self.decide_system_prompt(
tar_lang=tar_lang,
multi_images_input=isinstance(image, (list, tuple)) and len(image) > 1
)
multi_images_input=isinstance(image, (list, tuple)) and
len(image) > 1)
if seed < 0:
seed = random.randint(0, sys.maxsize)
if image is not None and self.is_vl:
@ -289,7 +288,8 @@ class DashScopePromptExpander(PromptExpander):
def extend_with_img(self,
prompt,
system_prompt,
image: Union[List[Image.Image], List[str], Image.Image, str] = None,
image: Union[List[Image.Image], List[str], Image.Image,
str] = None,
seed=-1,
*args,
**kwargs):
@ -308,13 +308,15 @@ class DashScopePromptExpander(PromptExpander):
_image.save(f.name)
image_path = f"file://{f.name}"
return image_path
if not isinstance(image, (list, tuple)):
image = [image]
image_path_list = [ensure_image(_image) for _image in image]
role_content = [
{"text": prompt},
*[{"image": image_path} for image_path in image_path_list]
]
role_content = [{
"text": prompt
}, *[{
"image": image_path
} for image_path in image_path_list]]
system_content = [{"text": system_prompt}]
prompt = f"{prompt}"
messages = [
@ -393,8 +395,11 @@ class QwenPromptExpander(PromptExpander):
if self.is_vl:
# default: Load the model on the available device(s)
from transformers import (AutoProcessor, AutoTokenizer,
Qwen2_5_VLForConditionalGeneration)
from transformers import (
AutoProcessor,
AutoTokenizer,
Qwen2_5_VLForConditionalGeneration,
)
try:
from .qwen_vl_utils import process_vision_info
except:
@ -459,7 +464,8 @@ class QwenPromptExpander(PromptExpander):
def extend_with_img(self,
prompt,
system_prompt,
image: Union[List[Image.Image], List[str], Image.Image, str] = None,
image: Union[List[Image.Image], List[str], Image.Image,
str] = None,
seed=-1,
*args,
**kwargs):
@ -468,26 +474,19 @@ class QwenPromptExpander(PromptExpander):
if not isinstance(image, (list, tuple)):
image = [image]
system_content = [{
"type": "text",
"text": system_prompt
}]
role_content = [
{
system_content = [{"type": "text", "text": system_prompt}]
role_content = [{
"type": "text",
"text": prompt
},
*[
{"image": image_path} for image_path in image
]
]
}, *[{
"image": image_path
} for image_path in image]]
messages = [{
'role': 'system',
'content': system_content,
}, {
"role":
"user",
"role": "user",
"content": role_content,
}]
@ -611,25 +610,38 @@ if __name__ == "__main__":
print("VL qwen vl en result -> en",
qwen_result.prompt) # , qwen_result.system_prompt)
# test multi images
image = ["./examples/flf2v_input_first_frame.png", "./examples/flf2v_input_last_frame.png"]
image = [
"./examples/flf2v_input_first_frame.png",
"./examples/flf2v_input_last_frame.png"
]
prompt = "无人机拍摄,镜头快速推进,然后拉远至全景俯瞰,展示一个宁静美丽的海港。海港内停满了游艇,水面清澈透蓝。周围是起伏的山丘和错落有致的建筑,整体景色宁静而美丽。"
en_prompt = ("Shot from a drone perspective, the camera rapidly zooms in before pulling back to reveal a panoramic "
en_prompt = (
"Shot from a drone perspective, the camera rapidly zooms in before pulling back to reveal a panoramic "
"aerial view of a serene and picturesque harbor. The tranquil bay is dotted with numerous yachts "
"resting on crystal-clear blue waters. Surrounding the harbor are rolling hills and well-spaced "
"architectural structures, combining to create a tranquil and breathtaking coastal landscape.")
"architectural structures, combining to create a tranquil and breathtaking coastal landscape."
)
dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True)
dashscope_result = dashscope_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
dashscope_prompt_expander = DashScopePromptExpander(
model_name=ds_model_name, is_vl=True)
dashscope_result = dashscope_prompt_expander(
prompt, tar_lang="zh", image=image, seed=seed)
print("VL dashscope result -> zh", dashscope_result.prompt)
dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True)
dashscope_result = dashscope_prompt_expander(en_prompt, tar_lang="zh", image=image, seed=seed)
dashscope_prompt_expander = DashScopePromptExpander(
model_name=ds_model_name, is_vl=True)
dashscope_result = dashscope_prompt_expander(
en_prompt, tar_lang="zh", image=image, seed=seed)
print("VL dashscope en result -> zh", dashscope_result.prompt)
qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0)
qwen_result = qwen_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
qwen_prompt_expander = QwenPromptExpander(
model_name=qwen_model_name, is_vl=True, device=0)
qwen_result = qwen_prompt_expander(
prompt, tar_lang="zh", image=image, seed=seed)
print("VL qwen result -> zh", qwen_result.prompt)
qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0)
qwen_result = qwen_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
qwen_prompt_expander = QwenPromptExpander(
model_name=qwen_model_name, is_vl=True, device=0)
qwen_result = qwen_prompt_expander(
prompt, tar_lang="zh", image=image, seed=seed)
print("VL qwen en result -> zh", qwen_result.prompt)

305
wan/utils/vace_processor.py Normal file
View File

@ -0,0 +1,305 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
class VaceImageProcessor(object):
def __init__(self, downsample=None, seq_len=None):
self.downsample = downsample
self.seq_len = seq_len
def _pillow_convert(self, image, cvt_type='RGB'):
if image.mode != cvt_type:
if image.mode == 'P':
image = image.convert(f'{cvt_type}A')
if image.mode == f'{cvt_type}A':
bg = Image.new(
cvt_type,
size=(image.width, image.height),
color=(255, 255, 255))
bg.paste(image, (0, 0), mask=image)
image = bg
else:
image = image.convert(cvt_type)
return image
def _load_image(self, img_path):
if img_path is None or img_path == '':
return None
img = Image.open(img_path)
img = self._pillow_convert(img)
return img
def _resize_crop(self, img, oh, ow, normalize=True):
"""
Resize, center crop, convert to tensor, and normalize.
"""
# resize and crop
iw, ih = img.size
if iw != ow or ih != oh:
# resize
scale = max(ow / iw, oh / ih)
img = img.resize((round(scale * iw), round(scale * ih)),
resample=Image.Resampling.LANCZOS)
assert img.width >= ow and img.height >= oh
# center crop
x1 = (img.width - ow) // 2
y1 = (img.height - oh) // 2
img = img.crop((x1, y1, x1 + ow, y1 + oh))
# normalize
if normalize:
img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1)
return img
def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
return self._resize_crop(img, oh, ow, normalize)
def load_image(self, data_key, **kwargs):
return self.load_image_batch(data_key, **kwargs)
def load_image_pair(self, data_key, data_key2, **kwargs):
return self.load_image_batch(data_key, data_key2, **kwargs)
def load_image_batch(self,
*data_key_batch,
normalize=True,
seq_len=None,
**kwargs):
seq_len = self.seq_len if seq_len is None else seq_len
imgs = []
for data_key in data_key_batch:
img = self._load_image(data_key)
imgs.append(img)
w, h = imgs[0].size
dh, dw = self.downsample[1:]
# compute output size
scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw))))
oh = int(h * scale) // dh * dh
ow = int(w * scale) // dw * dw
assert (oh // dh) * (ow // dw) <= seq_len
imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs]
return *imgs, (oh, ow)
class VaceVideoProcessor(object):
def __init__(self, downsample, min_area, max_area, min_fps, max_fps,
zero_start, seq_len, keep_last, **kwargs):
self.downsample = downsample
self.min_area = min_area
self.max_area = max_area
self.min_fps = min_fps
self.max_fps = max_fps
self.zero_start = zero_start
self.keep_last = keep_last
self.seq_len = seq_len
assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])
def set_area(self, area):
self.min_area = area
self.max_area = area
def set_seq_len(self, seq_len):
self.seq_len = seq_len
@staticmethod
def resize_crop(video: torch.Tensor, oh: int, ow: int):
"""
Resize, center crop and normalize for decord loaded video (torch.Tensor type)
Parameters:
video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C)
oh - target height (int)
ow - target width (int)
Returns:
The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W)
Raises:
"""
# permute ([t, h, w, c] -> [t, c, h, w])
video = video.permute(0, 3, 1, 2)
# resize and crop
ih, iw = video.shape[2:]
if ih != oh or iw != ow:
# resize
scale = max(ow / iw, oh / ih)
video = F.interpolate(
video,
size=(round(scale * ih), round(scale * iw)),
mode='bicubic',
antialias=True)
assert video.size(3) >= ow and video.size(2) >= oh
# center crop
x1 = (video.size(3) - ow) // 2
y1 = (video.size(2) - oh) // 2
video = video[:, :, y1:y1 + oh, x1:x1 + ow]
# permute ([t, c, h, w] -> [c, t, h, w]) and normalize
video = video.transpose(0, 1).float().div_(127.5).sub_(1.)
return video
def _video_preprocess(self, video, oh, ow):
return self.resize_crop(video, oh, ow)
def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box,
rng):
target_fps = min(fps, self.max_fps)
duration = frame_timestamps[-1].mean()
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
ratio = h / w
df, dh, dw = self.downsample
area_z = min(self.seq_len, self.max_area / (dh * dw),
(h // dh) * (w // dw))
of = min((int(duration * target_fps) - 1) // df + 1,
int(self.seq_len / area_z))
# deduce target shape of the [latent video]
target_area_z = min(area_z, int(self.seq_len / of))
oh = round(np.sqrt(target_area_z * ratio))
ow = int(target_area_z / oh)
of = (of - 1) * df + 1
oh *= dh
ow *= dw
# sample frame ids
target_duration = of / target_fps
begin = 0. if self.zero_start else rng.uniform(
0, duration - target_duration)
timestamps = np.linspace(begin, begin + target_duration, of)
frame_ids = np.argmax(
np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
timestamps[:, None] < frame_timestamps[None, :, 1]),
axis=1).tolist()
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w,
crop_box, rng):
duration = frame_timestamps[-1].mean()
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
ratio = h / w
df, dh, dw = self.downsample
area_z = min(self.seq_len, self.max_area / (dh * dw),
(h // dh) * (w // dw))
of = min((len(frame_timestamps) - 1) // df + 1,
int(self.seq_len / area_z))
# deduce target shape of the [latent video]
target_area_z = min(area_z, int(self.seq_len / of))
oh = round(np.sqrt(target_area_z * ratio))
ow = int(target_area_z / oh)
of = (of - 1) * df + 1
oh *= dh
ow *= dw
# sample frame ids
target_duration = duration
target_fps = of / target_duration
timestamps = np.linspace(0., target_duration, of)
frame_ids = np.argmax(
np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
timestamps[:, None] <= frame_timestamps[None, :, 1]),
axis=1).tolist()
# print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids))
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
if self.keep_last:
return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h,
w, crop_box, rng)
else:
return self._get_frameid_bbox_default(fps, frame_timestamps, h, w,
crop_box, rng)
def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(
data_key, crop_box=crop_box, seed=seed, **kwargs)
def load_video_pair(self,
data_key,
data_key2,
crop_box=None,
seed=2024,
**kwargs):
return self.load_video_batch(
data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
def load_video_batch(self,
*data_key_batch,
crop_box=None,
seed=2024,
**kwargs):
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
# read video
import decord
decord.bridge.set_bridge('torch')
readers = []
for data_k in data_key_batch:
reader = decord.VideoReader(data_k)
readers.append(reader)
fps = readers[0].get_avg_fps()
length = min([len(r) for r in readers])
frame_timestamps = [
readers[0].get_frame_timestamp(i) for i in range(length)
]
frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
h, w = readers[0].next().shape[:2]
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(
fps, frame_timestamps, h, w, crop_box, rng)
# preprocess video
videos = [
reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :]
for reader in readers
]
videos = [self._video_preprocess(video, oh, ow) for video in videos]
return *videos, frame_ids, (oh, ow), fps
# return videos if len(videos) > 1 else videos[0]
def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size,
device):
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
if sub_src_video is None and sub_src_mask is None:
src_video[i] = torch.zeros(
(3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones(
(1, num_frames, image_size[0], image_size[1]), device=device)
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
for j, ref_img in enumerate(ref_images):
if ref_img is not None and ref_img.shape[-2:] != image_size:
canvas_height, canvas_width = image_size
ref_height, ref_width = ref_img.shape[-2:]
white_canvas = torch.ones(
(3, 1, canvas_height, canvas_width),
device=device) # [-1, 1]
scale = min(canvas_height / ref_height,
canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
resized_image = F.interpolate(
ref_img.squeeze(1).unsqueeze(0),
size=(new_height, new_width),
mode='bilinear',
align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height,
left:left + new_width] = resized_image
src_ref_images[i][j] = white_canvas
return src_video, src_mask, src_ref_images

797
wan/vace.py Normal file
View File

@ -0,0 +1,797 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import time
import traceback
import types
from contextlib import contextmanager
from functools import partial
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from tqdm import tqdm
from .modules.vace_model import VaceWanModel
from .text2video import (
FlowDPMSolverMultistepScheduler,
FlowUniPCMultistepScheduler,
T5EncoderModel,
WanT2V,
WanVAE,
get_sampling_sigmas,
retrieve_timesteps,
shard_model,
)
from .utils.vace_processor import VaceVideoProcessor
class WanVace(WanT2V):
def __init__(
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
):
r"""
Initializes the Wan text-to-video generation model components.
Args:
config (EasyDict):
Object containing model parameters initialized from config.py
checkpoint_dir (`str`):
Path to directory containing model checkpoints
device_id (`int`, *optional*, defaults to 0):
Id of target GPU device
rank (`int`, *optional*, defaults to 0):
Process rank for distributed training
t5_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for T5 model
dit_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for DiT model
use_usp (`bool`, *optional*, defaults to False):
Enable distribution strategy of USP.
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
"""
self.device = torch.device(f"cuda:{device_id}")
self.config = config
self.rank = rank
self.t5_cpu = t5_cpu
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
shard_fn = partial(shard_model, device_id=device_id)
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
logging.info(f"Creating VaceWanModel from {checkpoint_dir}")
self.model = VaceWanModel.from_pretrained(checkpoint_dir)
self.model.eval().requires_grad_(False)
if use_usp:
from xfuser.core.distributed import get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward,
usp_dit_forward_vace,
)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
for block in self.model.vace_blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
self.model.forward = types.MethodType(usp_dit_forward, self.model)
self.model.forward_vace = types.MethodType(usp_dit_forward_vace,
self.model)
self.sp_size = get_sequence_parallel_world_size()
else:
self.sp_size = 1
if dist.is_initialized():
dist.barrier()
if dit_fsdp:
self.model = shard_fn(self.model)
else:
self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
self.vid_proc = VaceVideoProcessor(
downsample=tuple(
[x * y for x, y in zip(config.vae_stride, self.patch_size)]),
min_area=720 * 1280,
max_area=720 * 1280,
min_fps=config.sample_fps,
max_fps=config.sample_fps,
zero_start=True,
seq_len=75600,
keep_last=True)
def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
vae = self.vae if vae is None else vae
if ref_images is None:
ref_images = [None] * len(frames)
else:
assert len(frames) == len(ref_images)
if masks is None:
latents = vae.encode(frames)
else:
masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks]
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
inactive = vae.encode(inactive)
reactive = vae.encode(reactive)
latents = [
torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)
]
cat_latents = []
for latent, refs in zip(latents, ref_images):
if refs is not None:
if masks is None:
ref_latent = vae.encode(refs)
else:
ref_latent = vae.encode(refs)
ref_latent = [
torch.cat((u, torch.zeros_like(u)), dim=0)
for u in ref_latent
]
assert all([x.shape[1] == 1 for x in ref_latent])
latent = torch.cat([*ref_latent, latent], dim=1)
cat_latents.append(latent)
return cat_latents
def vace_encode_masks(self, masks, ref_images=None, vae_stride=None):
vae_stride = self.vae_stride if vae_stride is None else vae_stride
if ref_images is None:
ref_images = [None] * len(masks)
else:
assert len(masks) == len(ref_images)
result_masks = []
for mask, refs in zip(masks, ref_images):
c, depth, height, width = mask.shape
new_depth = int((depth + 3) // vae_stride[0])
height = 2 * (int(height) // (vae_stride[1] * 2))
width = 2 * (int(width) // (vae_stride[2] * 2))
# reshape
mask = mask[0, :, :, :]
mask = mask.view(depth, height, vae_stride[1], width,
vae_stride[1]) # depth, height, 8, width, 8
mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
mask = mask.reshape(vae_stride[1] * vae_stride[2], depth, height,
width) # 8*8, depth, height, width
# interpolation
mask = F.interpolate(
mask.unsqueeze(0),
size=(new_depth, height, width),
mode='nearest-exact').squeeze(0)
if refs is not None:
length = len(refs)
mask_pad = torch.zeros_like(mask[:, :length, :, :])
mask = torch.cat((mask_pad, mask), dim=1)
result_masks.append(mask)
return result_masks
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames,
image_size, device):
area = image_size[0] * image_size[1]
self.vid_proc.set_area(area)
if area == 720 * 1280:
self.vid_proc.set_seq_len(75600)
elif area == 480 * 832:
self.vid_proc.set_seq_len(32760)
else:
raise NotImplementedError(
f'image_size {image_size} is not supported')
image_size = (image_size[1], image_size[0])
image_sizes = []
for i, (sub_src_video,
sub_src_mask) in enumerate(zip(src_video, src_mask)):
if sub_src_mask is not None and sub_src_video is not None:
src_video[i], src_mask[
i], _, _, _ = self.vid_proc.load_video_pair(
sub_src_video, sub_src_mask)
src_video[i] = src_video[i].to(device)
src_mask[i] = src_mask[i].to(device)
src_mask[i] = torch.clamp(
(src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
image_sizes.append(src_video[i].shape[2:])
elif sub_src_video is None:
src_video[i] = torch.zeros(
(3, num_frames, image_size[0], image_size[1]),
device=device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(image_size)
else:
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video)
src_video[i] = src_video[i].to(device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(src_video[i].shape[2:])
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
image_size = image_sizes[i]
for j, ref_img in enumerate(ref_images):
if ref_img is not None:
ref_img = Image.open(ref_img).convert("RGB")
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(
0.5).unsqueeze(1)
if ref_img.shape[-2:] != image_size:
canvas_height, canvas_width = image_size
ref_height, ref_width = ref_img.shape[-2:]
white_canvas = torch.ones(
(3, 1, canvas_height, canvas_width),
device=device) # [-1, 1]
scale = min(canvas_height / ref_height,
canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
resized_image = F.interpolate(
ref_img.squeeze(1).unsqueeze(0),
size=(new_height, new_width),
mode='bilinear',
align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height,
left:left + new_width] = resized_image
ref_img = white_canvas
src_ref_images[i][j] = ref_img.to(device)
return src_video, src_mask, src_ref_images
def decode_latent(self, zs, ref_images=None, vae=None):
vae = self.vae if vae is None else vae
if ref_images is None:
ref_images = [None] * len(zs)
else:
assert len(zs) == len(ref_images)
trimed_zs = []
for z, refs in zip(zs, ref_images):
if refs is not None:
z = z[:, len(refs):, :, :]
trimed_zs.append(z)
return vae.decode(trimed_zs)
def generate(self,
input_prompt,
input_frames,
input_masks,
input_ref_images,
size=(1280, 720),
frame_num=81,
context_scale=1.0,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation
size (tupele[`int`], *optional*, defaults to (1280,720)):
Controls video resolution, (width,height).
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 40):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed.
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (81)
- H: Frame height (from size)
- W: Frame width from size)
"""
# preprocess
# F = frame_num
# target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
# size[1] // self.vae_stride[1],
# size[0] // self.vae_stride[2])
#
# seq_len = math.ceil((target_shape[2] * target_shape[3]) /
# (self.patch_size[1] * self.patch_size[2]) *
# target_shape[1] / self.sp_size) * self.sp_size
if n_prompt == "":
n_prompt = self.sample_neg_prompt
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
# vace context encode
z0 = self.vace_encode_frames(
input_frames, input_ref_images, masks=input_masks)
m0 = self.vace_encode_masks(input_masks, input_ref_images)
z = self.vace_latent(z0, m0)
target_shape = list(z0[0].shape)
target_shape[0] = int(target_shape[0] / 2)
noise = [
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=torch.float32,
device=self.device,
generator=seed_g)
]
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1] / self.sp_size) * self.sp_size
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
# evaluation mode
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latents = noise
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
self.model.to(self.device)
noise_pred_cond = self.model(
latent_model_input,
t=timestep,
vace_context=z,
vace_context_scale=context_scale,
**arg_c)[0]
noise_pred_uncond = self.model(
latent_model_input,
t=timestep,
vace_context=z,
vace_context_scale=context_scale,
**arg_null)[0]
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latents[0].unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
x0 = latents
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.decode_latent(x0, input_ref_images)
del noise, latents
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None
class WanVaceMP(WanVace):
def __init__(self,
config,
checkpoint_dir,
use_usp=False,
ulysses_size=None,
ring_size=None):
self.config = config
self.checkpoint_dir = checkpoint_dir
self.use_usp = use_usp
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
self.in_q_list = None
self.out_q = None
self.inference_pids = None
self.ulysses_size = ulysses_size
self.ring_size = ring_size
self.dynamic_load()
self.device = 'cpu' if torch.cuda.is_available() else 'cpu'
self.vid_proc = VaceVideoProcessor(
downsample=tuple(
[x * y for x, y in zip(config.vae_stride, config.patch_size)]),
min_area=480 * 832,
max_area=480 * 832,
min_fps=self.config.sample_fps,
max_fps=self.config.sample_fps,
zero_start=True,
seq_len=32760,
keep_last=True)
def dynamic_load(self):
if hasattr(self, 'inference_pids') and self.inference_pids is not None:
return
gpu_infer = os.environ.get(
'LOCAL_WORLD_SIZE') or torch.cuda.device_count()
pmi_rank = int(os.environ['RANK'])
pmi_world_size = int(os.environ['WORLD_SIZE'])
in_q_list = [
torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)
]
out_q = torch.multiprocessing.Manager().Queue()
initialized_events = [
torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)
]
context = mp.spawn(
self.mp_worker,
nprocs=gpu_infer,
args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q,
initialized_events, self),
join=False)
all_initialized = False
while not all_initialized:
all_initialized = all(
event.is_set() for event in initialized_events)
if not all_initialized:
time.sleep(0.1)
print('Inference model is initialized', flush=True)
self.in_q_list = in_q_list
self.out_q = out_q
self.inference_pids = context.pids()
self.initialized_events = initialized_events
def transfer_data_to_cuda(self, data, device):
if data is None:
return None
else:
if isinstance(data, torch.Tensor):
data = data.to(device)
elif isinstance(data, list):
data = [
self.transfer_data_to_cuda(subdata, device)
for subdata in data
]
elif isinstance(data, dict):
data = {
key: self.transfer_data_to_cuda(val, device)
for key, val in data.items()
}
return data
def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list,
out_q, initialized_events, work_env):
try:
world_size = pmi_world_size * gpu_infer
rank = pmi_rank * gpu_infer + gpu
print("world_size", world_size, "rank", rank, flush=True)
torch.cuda.set_device(gpu)
dist.init_process_group(
backend='nccl',
init_method='env://',
rank=rank,
world_size=world_size)
from xfuser.core.distributed import (
init_distributed_environment,
initialize_model_parallel,
)
init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=self.ring_size or 1,
ulysses_degree=self.ulysses_size or 1)
num_train_timesteps = self.config.num_train_timesteps
param_dtype = self.config.param_dtype
shard_fn = partial(shard_model, device_id=gpu)
text_encoder = T5EncoderModel(
text_len=self.config.text_len,
dtype=self.config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(self.checkpoint_dir,
self.config.t5_checkpoint),
tokenizer_path=os.path.join(self.checkpoint_dir,
self.config.t5_tokenizer),
shard_fn=shard_fn if True else None)
text_encoder.model.to(gpu)
vae_stride = self.config.vae_stride
patch_size = self.config.patch_size
vae = WanVAE(
vae_pth=os.path.join(self.checkpoint_dir,
self.config.vae_checkpoint),
device=gpu)
logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}")
model = VaceWanModel.from_pretrained(self.checkpoint_dir)
model.eval().requires_grad_(False)
if self.use_usp:
from xfuser.core.distributed import get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward,
usp_dit_forward_vace,
)
for block in model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
for block in model.vace_blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
model.forward = types.MethodType(usp_dit_forward, model)
model.forward_vace = types.MethodType(usp_dit_forward_vace,
model)
sp_size = get_sequence_parallel_world_size()
else:
sp_size = 1
dist.barrier()
model = shard_fn(model)
sample_neg_prompt = self.config.sample_neg_prompt
torch.cuda.empty_cache()
event = initialized_events[gpu]
in_q = in_q_list[gpu]
event.set()
while True:
item = in_q.get()
input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale, \
shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model = item
input_frames = self.transfer_data_to_cuda(input_frames, gpu)
input_masks = self.transfer_data_to_cuda(input_masks, gpu)
input_ref_images = self.transfer_data_to_cuda(
input_ref_images, gpu)
if n_prompt == "":
n_prompt = sample_neg_prompt
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=gpu)
seed_g.manual_seed(seed)
context = text_encoder([input_prompt], gpu)
context_null = text_encoder([n_prompt], gpu)
# vace context encode
z0 = self.vace_encode_frames(
input_frames, input_ref_images, masks=input_masks, vae=vae)
m0 = self.vace_encode_masks(
input_masks, input_ref_images, vae_stride=vae_stride)
z = self.vace_latent(z0, m0)
target_shape = list(z0[0].shape)
target_shape[0] = int(target_shape[0] / 2)
noise = [
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=torch.float32,
device=gpu,
generator=seed_g)
]
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
(patch_size[1] * patch_size[2]) *
target_shape[1] / sp_size) * sp_size
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(model, 'no_sync', noop_no_sync)
# evaluation mode
with amp.autocast(
dtype=param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=gpu, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(
sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=gpu,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latents = noise
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
model.to(gpu)
noise_pred_cond = model(
latent_model_input,
t=timestep,
vace_context=z,
vace_context_scale=context_scale,
**arg_c)[0]
noise_pred_uncond = model(
latent_model_input,
t=timestep,
vace_context=z,
vace_context_scale=context_scale,
**arg_null)[0]
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latents[0].unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
torch.cuda.empty_cache()
x0 = latents
if rank == 0:
videos = self.decode_latent(
x0, input_ref_images, vae=vae)
del noise, latents
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
if rank == 0:
out_q.put(videos[0].cpu())
except Exception as e:
trace_info = traceback.format_exc()
print(trace_info, flush=True)
print(e, flush=True)
def generate(self,
input_prompt,
input_frames,
input_masks,
input_ref_images,
size=(1280, 720),
frame_num=81,
context_scale=1.0,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
input_data = (input_prompt, input_frames, input_masks, input_ref_images,
size, frame_num, context_scale, shift, sample_solver,
sampling_steps, guide_scale, n_prompt, seed,
offload_model)
for in_q in self.in_q_list:
in_q.put(input_data)
value_output = self.out_q.get()
return value_output

59
Тг подарки Normal file
View File

@ -0,0 +1,59 @@
from manim import *
class NFTPresentation(Scene):
def construct(self):
# 1. Анимация отправки NFT-подарка
phone = SVGMobject("smartphone") # Загрузите SVG-изображение телефона
chat_bubble = Text("Отправляю NFT-подарок!", font_size=24)
nft_gift = ImageMobject("nft_gift.png") # Загрузите изображение NFT-подарка
phone.scale(0.8)
chat_bubble.next_to(phone, UP)
nft_gift.scale(0.5).next_to(chat_bubble, UP)
self.play(DrawBorderThenFill(phone))
self.play(Write(chat_bubble))
self.play(FadeIn(nft_gift))
self.wait(2)
# 2. Примеры уникальных цифровых подарков
art = ImageMobject("art.png") # Загрузите изображение арта
card = ImageMobject("card.png") # Загрузите изображение коллекционной карточки
animation = ImageMobject("animation.gif") # Загрузите GIF-анимацию
art.scale(0.5).to_edge(LEFT)
card.scale(0.5).next_to(art, RIGHT)
animation.scale(0.5).next_to(card, RIGHT)
self.play(FadeIn(art), FadeIn(card), FadeIn(animation))
self.wait(3)
# 3. Преимущества NFT-подарков
advantages = VGroup(
Text("Уникальность", font_size=24),
Text("Возможность перепродажи", font_size=24),
Text("Эмоциональная ценность", font_size=24)
).arrange(DOWN, aligned_edge=LEFT)
advantages.next_to(phone, DOWN)
self.play(Write(advantages))
self.wait(3)
# 4. Призыв
call_to_action = Text("Дарите уникальное! NFT-подарки в Telegram — тренд будущего!", font_size=28)
call_to_action.to_edge(UP)
self.play(Write(call_to_action))
self.wait(2)
# 5. Логотип Telegram и хэштег
telegram_logo = ImageMobject("telegram_logo.png") # Загрузите логотип Telegram
hashtag = Text("#https://t.me/TONNELNFT1", font_size=24)
telegram_logo.scale(0.5).to_edge(DOWN)
hashtag.next_to(telegram_logo, RIGHT)
self.play(FadeIn(telegram_logo), Write(hashtag))
self.wait(3)
# Для запуска анимации используйте команду:
# manim -pql script.py NFTPresentation