Format the code (#402)

* isort the code

* format the code

* Add yapf config file

* Remove torch cuda memory profiler
This commit is contained in:
Ang Wang 2025-05-16 12:35:38 +08:00 committed by GitHub
parent c709fcf0e7
commit 76e9427657
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 1052 additions and 416 deletions

393
.style.yapf Normal file
View File

@ -0,0 +1,393 @@
[style]
# Align closing bracket with visual indentation.
align_closing_bracket_with_visual_indent=False
# Allow dictionary keys to exist on multiple lines. For example:
#
# x = {
# ('this is the first element of a tuple',
# 'this is the second element of a tuple'):
# value,
# }
allow_multiline_dictionary_keys=False
# Allow lambdas to be formatted on more than one line.
allow_multiline_lambdas=False
# Allow splitting before a default / named assignment in an argument list.
allow_split_before_default_or_named_assigns=False
# Allow splits before the dictionary value.
allow_split_before_dict_value=True
# Let spacing indicate operator precedence. For example:
#
# a = 1 * 2 + 3 / 4
# b = 1 / 2 - 3 * 4
# c = (1 + 2) * (3 - 4)
# d = (1 - 2) / (3 + 4)
# e = 1 * 2 - 3
# f = 1 + 2 + 3 + 4
#
# will be formatted as follows to indicate precedence:
#
# a = 1*2 + 3/4
# b = 1/2 - 3*4
# c = (1+2) * (3-4)
# d = (1-2) / (3+4)
# e = 1*2 - 3
# f = 1 + 2 + 3 + 4
#
arithmetic_precedence_indication=False
# Number of blank lines surrounding top-level function and class
# definitions.
blank_lines_around_top_level_definition=2
# Insert a blank line before a class-level docstring.
blank_line_before_class_docstring=False
# Insert a blank line before a module docstring.
blank_line_before_module_docstring=False
# Insert a blank line before a 'def' or 'class' immediately nested
# within another 'def' or 'class'. For example:
#
# class Foo:
# # <------ this blank line
# def method():
# ...
blank_line_before_nested_class_or_def=True
# Do not split consecutive brackets. Only relevant when
# dedent_closing_brackets is set. For example:
#
# call_func_that_takes_a_dict(
# {
# 'key1': 'value1',
# 'key2': 'value2',
# }
# )
#
# would reformat to:
#
# call_func_that_takes_a_dict({
# 'key1': 'value1',
# 'key2': 'value2',
# })
coalesce_brackets=False
# The column limit.
column_limit=80
# The style for continuation alignment. Possible values are:
#
# - SPACE: Use spaces for continuation alignment. This is default behavior.
# - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns
# (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs or
# CONTINUATION_INDENT_WIDTH spaces) for continuation alignment.
# - VALIGN-RIGHT: Vertically align continuation lines to multiple of
# INDENT_WIDTH columns. Slightly right (one tab or a few spaces) if
# cannot vertically align continuation lines with indent characters.
continuation_align_style=SPACE
# Indent width used for line continuations.
continuation_indent_width=4
# Put closing brackets on a separate line, dedented, if the bracketed
# expression can't fit in a single line. Applies to all kinds of brackets,
# including function definitions and calls. For example:
#
# config = {
# 'key1': 'value1',
# 'key2': 'value2',
# } # <--- this bracket is dedented and on a separate line
#
# time_series = self.remote_client.query_entity_counters(
# entity='dev3246.region1',
# key='dns.query_latency_tcp',
# transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
# start_ts=now()-timedelta(days=3),
# end_ts=now(),
# ) # <--- this bracket is dedented and on a separate line
dedent_closing_brackets=False
# Disable the heuristic which places each list element on a separate line
# if the list is comma-terminated.
disable_ending_comma_heuristic=False
# Place each dictionary entry onto its own line.
each_dict_entry_on_separate_line=True
# Require multiline dictionary even if it would normally fit on one line.
# For example:
#
# config = {
# 'key1': 'value1'
# }
force_multiline_dict=False
# The regex for an i18n comment. The presence of this comment stops
# reformatting of that line, because the comments are required to be
# next to the string they translate.
i18n_comment=#\..*
# The i18n function call names. The presence of this function stops
# reformattting on that line, because the string it has cannot be moved
# away from the i18n comment.
i18n_function_call=N_, _
# Indent blank lines.
indent_blank_lines=False
# Put closing brackets on a separate line, indented, if the bracketed
# expression can't fit in a single line. Applies to all kinds of brackets,
# including function definitions and calls. For example:
#
# config = {
# 'key1': 'value1',
# 'key2': 'value2',
# } # <--- this bracket is indented and on a separate line
#
# time_series = self.remote_client.query_entity_counters(
# entity='dev3246.region1',
# key='dns.query_latency_tcp',
# transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
# start_ts=now()-timedelta(days=3),
# end_ts=now(),
# ) # <--- this bracket is indented and on a separate line
indent_closing_brackets=False
# Indent the dictionary value if it cannot fit on the same line as the
# dictionary key. For example:
#
# config = {
# 'key1':
# 'value1',
# 'key2': value1 +
# value2,
# }
indent_dictionary_value=True
# The number of columns to use for indentation.
indent_width=4
# Join short lines into one line. E.g., single line 'if' statements.
join_multiple_lines=False
# Do not include spaces around selected binary operators. For example:
#
# 1 + 2 * 3 - 4 / 5
#
# will be formatted as follows when configured with "*,/":
#
# 1 + 2*3 - 4/5
no_spaces_around_selected_binary_operators=
# Use spaces around default or named assigns.
spaces_around_default_or_named_assign=False
# Adds a space after the opening '{' and before the ending '}' dict delimiters.
#
# {1: 2}
#
# will be formatted as:
#
# { 1: 2 }
spaces_around_dict_delimiters=False
# Adds a space after the opening '[' and before the ending ']' list delimiters.
#
# [1, 2]
#
# will be formatted as:
#
# [ 1, 2 ]
spaces_around_list_delimiters=False
# Use spaces around the power operator.
spaces_around_power_operator=False
# Use spaces around the subscript / slice operator. For example:
#
# my_list[1 : 10 : 2]
spaces_around_subscript_colon=False
# Adds a space after the opening '(' and before the ending ')' tuple delimiters.
#
# (1, 2, 3)
#
# will be formatted as:
#
# ( 1, 2, 3 )
spaces_around_tuple_delimiters=False
# The number of spaces required before a trailing comment.
# This can be a single value (representing the number of spaces
# before each trailing comment) or list of values (representing
# alignment column values; trailing comments within a block will
# be aligned to the first column value that is greater than the maximum
# line length within the block). For example:
#
# With spaces_before_comment=5:
#
# 1 + 1 # Adding values
#
# will be formatted as:
#
# 1 + 1 # Adding values <-- 5 spaces between the end of the statement and comment
#
# With spaces_before_comment=15, 20:
#
# 1 + 1 # Adding values
# two + two # More adding
#
# longer_statement # This is a longer statement
# short # This is a shorter statement
#
# a_very_long_statement_that_extends_beyond_the_final_column # Comment
# short # This is a shorter statement
#
# will be formatted as:
#
# 1 + 1 # Adding values <-- end of line comments in block aligned to col 15
# two + two # More adding
#
# longer_statement # This is a longer statement <-- end of line comments in block aligned to col 20
# short # This is a shorter statement
#
# a_very_long_statement_that_extends_beyond_the_final_column # Comment <-- the end of line comments are aligned based on the line length
# short # This is a shorter statement
#
spaces_before_comment=2
# Insert a space between the ending comma and closing bracket of a list,
# etc.
space_between_ending_comma_and_closing_bracket=False
# Use spaces inside brackets, braces, and parentheses. For example:
#
# method_call( 1 )
# my_dict[ 3 ][ 1 ][ get_index( *args, **kwargs ) ]
# my_set = { 1, 2, 3 }
space_inside_brackets=False
# Split before arguments
split_all_comma_separated_values=False
# Split before arguments, but do not split all subexpressions recursively
# (unless needed).
split_all_top_level_comma_separated_values=False
# Split before arguments if the argument list is terminated by a
# comma.
split_arguments_when_comma_terminated=False
# Set to True to prefer splitting before '+', '-', '*', '/', '//', or '@'
# rather than after.
split_before_arithmetic_operator=False
# Set to True to prefer splitting before '&', '|' or '^' rather than
# after.
split_before_bitwise_operator=False
# Split before the closing bracket if a list or dict literal doesn't fit on
# a single line.
split_before_closing_bracket=True
# Split before a dictionary or set generator (comp_for). For example, note
# the split before the 'for':
#
# foo = {
# variable: 'Hello world, have a nice day!'
# for variable in bar if variable != 42
# }
split_before_dict_set_generator=False
# Split before the '.' if we need to split a longer expression:
#
# foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d))
#
# would reformat to something like:
#
# foo = ('This is a really long string: {}, {}, {}, {}'
# .format(a, b, c, d))
split_before_dot=False
# Split after the opening paren which surrounds an expression if it doesn't
# fit on a single line.
split_before_expression_after_opening_paren=True
# If an argument / parameter list is going to be split, then split before
# the first argument.
split_before_first_argument=False
# Set to True to prefer splitting before 'and' or 'or' rather than
# after.
split_before_logical_operator=False
# Split named assignments onto individual lines.
split_before_named_assigns=True
# Set to True to split list comprehensions and generators that have
# non-trivial expressions and multiple clauses before each of these
# clauses. For example:
#
# result = [
# a_long_var + 100 for a_long_var in xrange(1000)
# if a_long_var % 10]
#
# would reformat to something like:
#
# result = [
# a_long_var + 100
# for a_long_var in xrange(1000)
# if a_long_var % 10]
split_complex_comprehension=True
# The penalty for splitting right after the opening bracket.
split_penalty_after_opening_bracket=300
# The penalty for splitting the line after a unary operator.
split_penalty_after_unary_operator=10000
# The penalty of splitting the line around the '+', '-', '*', '/', '//',
# ``%``, and '@' operators.
split_penalty_arithmetic_operator=300
# The penalty for splitting right before an if expression.
split_penalty_before_if_expr=0
# The penalty of splitting the line around the '&', '|', and '^'
# operators.
split_penalty_bitwise_operator=300
# The penalty for splitting a list comprehension or generator
# expression.
split_penalty_comprehension=2100
# The penalty for characters over the column limit.
split_penalty_excess_character=7000
# The penalty incurred by adding a line split to the unwrapped line. The
# more line splits added the higher the penalty.
split_penalty_for_added_line_split=30
# The penalty of splitting a list of "import as" names. For example:
#
# from a_very_long_or_indented_module_name_yada_yad import (long_argument_1,
# long_argument_2,
# long_argument_3)
#
# would reformat to something like:
#
# from a_very_long_or_indented_module_name_yada_yad import (
# long_argument_1, long_argument_2, long_argument_3)
split_penalty_import_names=0
# The penalty of splitting the line around the 'and' and 'or'
# operators.
split_penalty_logical_operator=300
# Use the Tab character for indentation.
use_tabs=False

5
Makefile Normal file
View File

@ -0,0 +1,5 @@
.PHONY: format
format:
isort generate.py gradio wan
yapf -i -r *.py generate.py gradio wan

View File

@ -1,28 +1,33 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
from datetime import datetime
import logging
import os
import sys
import warnings
from datetime import datetime
warnings.filterwarnings('ignore')
import torch, random
import random
import torch
import torch.distributed as dist
from PIL import Image
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video, cache_image, str2bool
from wan.utils.utils import cache_image, cache_video, str2bool
EXAMPLE_PROMPT = {
"t2v-1.3B": {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"prompt":
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"t2v-14B": {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"prompt":
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"t2i-14B": {
"prompt": "一个朴素端庄的美人",
@ -34,20 +39,24 @@ EXAMPLE_PROMPT = {
"examples/i2v_input.JPG",
},
"flf2v-14B": {
"prompt":
"CG动画风格一只蓝色的小鸟从地面起飞煽动翅膀。小鸟羽毛细腻胸前有独特的花纹背景是蓝天白云阳光明媚。镜跟随小鸟向上移动展现出小鸟飞翔的姿态和天空的广阔。近景仰视视角。",
"first_frame":
"examples/flf2v_input_first_frame.png",
"last_frame":
"examples/flf2v_input_last_frame.png",
"prompt":
"CG动画风格一只蓝色的小鸟从地面起飞煽动翅膀。小鸟羽毛细腻胸前有独特的花纹背景是蓝天白云阳光明媚。镜跟随小鸟向上移动展现出小鸟飞翔的姿态和天空的广阔。近景仰视视角。",
"first_frame":
"examples/flf2v_input_first_frame.png",
"last_frame":
"examples/flf2v_input_last_frame.png",
},
"vace-1.3B": {
"src_ref_images": 'examples/girl.png,examples/snake.png',
"prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
"src_ref_images":
'examples/girl.png,examples/snake.png',
"prompt":
"在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
},
"vace-14B": {
"src_ref_images": 'examples/girl.png,examples/snake.png',
"prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
"src_ref_images":
'examples/girl.png,examples/snake.png',
"prompt":
"在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
}
}
@ -64,7 +73,6 @@ def _validate_args(args):
if "i2v" in args.task:
args.sample_steps = 40
if args.sample_shift is None:
args.sample_shift = 5.0
if "i2v" in args.task and args.size in ["832*480", "480*832"]:
@ -72,7 +80,6 @@ def _validate_args(args):
elif "flf2v" in args.task or "vace" in args.task:
args.sample_shift = 16
# The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
if args.frame_num is None:
args.frame_num = 1 if "t2i" in args.task else 81
@ -167,7 +174,8 @@ def _parse_args():
"--src_ref_images",
type=str,
default=None,
help="The file list of the source reference images. Separated by ','. Default None.")
help="The file list of the source reference images. Separated by ','. Default None."
)
parser.add_argument(
"--prompt",
type=str,
@ -209,12 +217,14 @@ def _parse_args():
"--first_frame",
type=str,
default=None,
help="[first-last frame to video] The image (first frame) to generate the video from.")
help="[first-last frame to video] The image (first frame) to generate the video from."
)
parser.add_argument(
"--last_frame",
type=str,
default=None,
help="[first-last frame to video] The image (last frame) to generate the video from.")
help="[first-last frame to video] The image (last frame) to generate the video from."
)
parser.add_argument(
"--sample_solver",
type=str,
@ -281,8 +291,10 @@ def generate(args):
if args.ulysses_size > 1 or args.ring_size > 1:
assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size."
from xfuser.core.distributed import (initialize_model_parallel,
init_distributed_environment)
from xfuser.core.distributed import (
init_distributed_environment,
initialize_model_parallel,
)
init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size())
@ -295,7 +307,8 @@ def generate(args):
if args.use_prompt_extend:
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl="i2v" in args.task or "flf2v" in args.task)
model_name=args.prompt_extend_model,
is_vl="i2v" in args.task or "flf2v" in args.task)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model,
@ -482,21 +495,22 @@ def generate(args):
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model
)
offload_model=args.offload_model)
elif "vace" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None)
args.src_mask = EXAMPLE_PROMPT[args.task].get("src_mask", None)
args.src_ref_images = EXAMPLE_PROMPT[args.task].get("src_ref_images", None)
args.src_ref_images = EXAMPLE_PROMPT[args.task].get(
"src_ref_images", None)
logging.info(f"Input prompt: {args.prompt}")
if args.use_prompt_extend and args.use_prompt_extend != 'plain':
logging.info("Extending prompt ...")
if rank == 0:
prompt = prompt_expander.forward(args.prompt)
logging.info(f"Prompt extended from '{args.prompt}' to '{prompt}'")
logging.info(
f"Prompt extended from '{args.prompt}' to '{prompt}'")
input_prompt = [prompt]
else:
input_prompt = [None]
@ -517,10 +531,11 @@ def generate(args):
t5_cpu=args.t5_cpu,
)
src_video, src_mask, src_ref_images = wan_vace.prepare_source([args.src_video],
[args.src_mask],
[None if args.src_ref_images is None else args.src_ref_images.split(',')],
args.frame_num, SIZE_CONFIGS[args.size], device)
src_video, src_mask, src_ref_images = wan_vace.prepare_source(
[args.src_video], [args.src_mask], [
None if args.src_ref_images is None else
args.src_ref_images.split(',')
], args.frame_num, SIZE_CONFIGS[args.size], device)
logging.info(f"Generating video...")
video = wan_vace.generate(

View File

@ -1,8 +1,8 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import gc
import os.path as osp
import os
import os.path as osp
import sys
import warnings
@ -11,7 +11,8 @@ import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
sys.path.insert(
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
@ -69,13 +70,13 @@ def prompt_enc(prompt, img_first, img_last, tar_lang):
return prompt_output.prompt
def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps,
guide_scale, shift_scale, seed, n_prompt):
def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
resolution, sd_steps, guide_scale, shift_scale, seed,
n_prompt):
if resolution == '------':
print(
'Please specify the resolution ckpt dir or specify the resolution'
)
'Please specify the resolution ckpt dir or specify the resolution')
return None
else:
@ -94,9 +95,7 @@ def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, re
offload_model=True)
pass
else:
print(
'Sorry, currently only 720P is supported.'
)
print('Sorry, currently only 720P is supported.')
return None
cache_video(
@ -191,14 +190,17 @@ def gradio_interface():
run_p_button.click(
fn=prompt_enc,
inputs=[flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, tar_lang],
inputs=[
flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
tar_lang
],
outputs=[flf2vid_prompt])
run_flf2v_button.click(
fn=flf2v_generation,
inputs=[
flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps,
guide_scale, shift_scale, seed, n_prompt
flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
resolution, sd_steps, guide_scale, shift_scale, seed, n_prompt
],
outputs=[result_gallery],
)

View File

@ -1,8 +1,8 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import gc
import os.path as osp
import os
import os.path as osp
import sys
import warnings
@ -11,7 +11,8 @@ import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
sys.path.insert(
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander

View File

@ -1,7 +1,7 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import os.path as osp
import sys
import warnings
@ -10,7 +10,8 @@ import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
sys.path.insert(
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander

View File

@ -1,7 +1,7 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import os.path as osp
import sys
import warnings
@ -10,7 +10,8 @@ import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
sys.path.insert(
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander

View File

@ -1,7 +1,7 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import os.path as osp
import sys
import warnings
@ -10,7 +10,8 @@ import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
sys.path.insert(
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander

View File

@ -2,36 +2,48 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import datetime
import os
import sys
import datetime
import imageio
import numpy as np
import torch
import gradio as gr
sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
sys.path.insert(
0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan import WanVace, WanVaceMP
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS
from wan.configs import SIZE_CONFIGS, WAN_CONFIGS
class FixedSizeQueue:
def __init__(self, max_size):
self.max_size = max_size
self.queue = []
def add(self, item):
self.queue.insert(0, item)
if len(self.queue) > self.max_size:
self.queue.pop()
def get(self):
return self.queue
def __repr__(self):
return str(self.queue)
class VACEInference:
def __init__(self, cfg, skip_load=False, gallery_share=True, gallery_share_limit=5):
def __init__(self,
cfg,
skip_load=False,
gallery_share=True,
gallery_share_limit=5):
self.cfg = cfg
self.save_dir = cfg.save_dir
self.gallery_share = gallery_share
@ -53,9 +65,7 @@ class VACEInference:
checkpoint_dir=cfg.ckpt_dir,
use_usp=True,
ulysses_size=cfg.ulysses_size,
ring_size=cfg.ring_size
)
ring_size=cfg.ring_size)
def create_ui(self, *args, **kwargs):
gr.Markdown("""
@ -80,30 +90,33 @@ class VACEInference:
with gr.Row(variant='panel', equal_height=True):
with gr.Column(scale=1, min_width=0):
with gr.Row(equal_height=True):
self.src_ref_image_1 = gr.Image(label='src_ref_image_1',
height=200,
interactive=True,
type='filepath',
image_mode='RGB',
sources=['upload'],
elem_id="src_ref_image_1",
format='png')
self.src_ref_image_2 = gr.Image(label='src_ref_image_2',
height=200,
interactive=True,
type='filepath',
image_mode='RGB',
sources=['upload'],
elem_id="src_ref_image_2",
format='png')
self.src_ref_image_3 = gr.Image(label='src_ref_image_3',
height=200,
interactive=True,
type='filepath',
image_mode='RGB',
sources=['upload'],
elem_id="src_ref_image_3",
format='png')
self.src_ref_image_1 = gr.Image(
label='src_ref_image_1',
height=200,
interactive=True,
type='filepath',
image_mode='RGB',
sources=['upload'],
elem_id="src_ref_image_1",
format='png')
self.src_ref_image_2 = gr.Image(
label='src_ref_image_2',
height=200,
interactive=True,
type='filepath',
image_mode='RGB',
sources=['upload'],
elem_id="src_ref_image_2",
format='png')
self.src_ref_image_3 = gr.Image(
label='src_ref_image_3',
height=200,
interactive=True,
type='filepath',
image_mode='RGB',
sources=['upload'],
elem_id="src_ref_image_3",
format='png')
with gr.Row(variant='panel', equal_height=True):
with gr.Column(scale=1):
self.prompt = gr.Textbox(
@ -158,10 +171,8 @@ class VACEInference:
step=0.5,
value=5.0,
interactive=True)
self.infer_seed = gr.Slider(minimum=-1,
maximum=10000000,
value=2025,
label="Seed")
self.infer_seed = gr.Slider(
minimum=-1, maximum=10000000, value=2025, label="Seed")
#
with gr.Accordion(label="Usable without source video", open=False):
with gr.Row(equal_height=True):
@ -176,13 +187,9 @@ class VACEInference:
value=1280,
interactive=True)
self.frame_rate = gr.Textbox(
label='frame_rate',
value=16,
interactive=True)
label='frame_rate', value=16, interactive=True)
self.num_frames = gr.Textbox(
label='num_frames',
value=81,
interactive=True)
label='num_frames', value=81, interactive=True)
#
with gr.Row(equal_height=True):
with gr.Column(scale=5):
@ -201,17 +208,22 @@ class VACEInference:
allow_preview=True,
preview=True)
def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, src_ref_image_2, src_ref_image_3, prompt, negative_prompt, shift_scale, sample_steps, context_scale, guide_scale, infer_seed, output_height, output_width, frame_rate, num_frames):
output_height, output_width, frame_rate, num_frames = int(output_height), int(output_width), int(frame_rate), int(num_frames)
src_ref_images = [x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3] if
x is not None]
src_video, src_mask, src_ref_images = self.pipe.prepare_source([src_video],
[src_mask],
[src_ref_images],
num_frames=num_frames,
image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"],
device=self.pipe.device)
def generate(self, output_gallery, src_video, src_mask, src_ref_image_1,
src_ref_image_2, src_ref_image_3, prompt, negative_prompt,
shift_scale, sample_steps, context_scale, guide_scale,
infer_seed, output_height, output_width, frame_rate,
num_frames):
output_height, output_width, frame_rate, num_frames = int(
output_height), int(output_width), int(frame_rate), int(num_frames)
src_ref_images = [
x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3]
if x is not None
]
src_video, src_mask, src_ref_images = self.pipe.prepare_source(
[src_video], [src_mask], [src_ref_images],
num_frames=num_frames,
image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"],
device=self.pipe.device)
video = self.pipe.generate(
prompt,
src_video,
@ -228,10 +240,17 @@ class VACEInference:
name = '{0:%Y%m%d%-H%M%S}'.format(datetime.datetime.now())
video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
video_frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)
video_frames = (
torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) *
255).cpu().numpy().astype(np.uint8)
try:
writer = imageio.get_writer(video_path, fps=frame_rate, codec='libx264', quality=8, macro_block_size=1)
writer = imageio.get_writer(
video_path,
fps=frame_rate,
codec='libx264',
quality=8,
macro_block_size=1)
for frame in video_frames:
writer.append_data(frame)
writer.close()
@ -246,25 +265,57 @@ class VACEInference:
return [video_path]
def set_callbacks(self, **kwargs):
self.gen_inputs = [self.output_gallery, self.src_video, self.src_mask, self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3, self.prompt, self.negative_prompt, self.shift_scale, self.sample_steps, self.context_scale, self.guide_scale, self.infer_seed, self.output_height, self.output_width, self.frame_rate, self.num_frames]
self.gen_inputs = [
self.output_gallery, self.src_video, self.src_mask,
self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3,
self.prompt, self.negative_prompt, self.shift_scale,
self.sample_steps, self.context_scale, self.guide_scale,
self.infer_seed, self.output_height, self.output_width,
self.frame_rate, self.num_frames
]
self.gen_outputs = [self.output_gallery]
self.generate_button.click(self.generate,
inputs=self.gen_inputs,
outputs=self.gen_outputs,
queue=True)
self.refresh_button.click(lambda x: self.gallery_share_data.get() if self.gallery_share else x, inputs=[self.output_gallery], outputs=[self.output_gallery])
self.generate_button.click(
self.generate,
inputs=self.gen_inputs,
outputs=self.gen_outputs,
queue=True)
self.refresh_button.click(
lambda x: self.gallery_share_data.get()
if self.gallery_share else x,
inputs=[self.output_gallery],
outputs=[self.output_gallery])
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Argparser for VACE-WAN Demo:\n')
parser.add_argument('--server_port', dest='server_port', help='', type=int, default=7860)
parser.add_argument('--server_name', dest='server_name', help='', default='0.0.0.0')
parser = argparse.ArgumentParser(
description='Argparser for VACE-WAN Demo:\n')
parser.add_argument(
'--server_port', dest='server_port', help='', type=int, default=7860)
parser.add_argument(
'--server_name', dest='server_name', help='', default='0.0.0.0')
parser.add_argument('--root_path', dest='root_path', help='', default=None)
parser.add_argument('--save_dir', dest='save_dir', help='', default='cache')
parser.add_argument("--mp", action="store_true", help="Use Multi-GPUs",)
parser.add_argument("--model_name", type=str, default="vace-14B", choices=list(WAN_CONFIGS.keys()), help="The model name to run.")
parser.add_argument("--ulysses_size", type=int, default=1, help="The size of the ulysses parallelism in DiT.")
parser.add_argument("--ring_size", type=int, default=1, help="The size of the ring attention parallelism in DiT.")
parser.add_argument(
"--mp",
action="store_true",
help="Use Multi-GPUs",
)
parser.add_argument(
"--model_name",
type=str,
default="vace-14B",
choices=list(WAN_CONFIGS.keys()),
help="The model name to run.")
parser.add_argument(
"--ulysses_size",
type=int,
default=1,
help="The size of the ulysses parallelism in DiT.")
parser.add_argument(
"--ring_size",
type=int,
default=1,
help="The size of the ring attention parallelism in DiT.")
parser.add_argument(
"--ckpt_dir",
type=str,
@ -284,12 +335,15 @@ if __name__ == '__main__':
os.makedirs(args.save_dir, exist_ok=True)
with gr.Blocks() as demo:
infer_gr = VACEInference(args, skip_load=False, gallery_share=True, gallery_share_limit=5)
infer_gr = VACEInference(
args, skip_load=False, gallery_share=True, gallery_share_limit=5)
infer_gr.create_ui()
infer_gr.set_callbacks()
allowed_paths = [args.save_dir]
demo.queue(status_update_rate=1).launch(server_name=args.server_name,
server_port=args.server_port,
root_path=args.root_path,
allowed_paths=allowed_paths,
show_error=True, debug=True)
demo.queue(status_update_rate=1).launch(
server_name=args.server_name,
server_port=args.server_port,
root_path=args.root_path,
allowed_paths=allowed_paths,
show_error=True,
debug=True)

View File

@ -1,5 +1,5 @@
from . import configs, distributed, modules
from .first_last_frame2video import WanFLF2V
from .image2video import WanI2V
from .text2video import WanT2V
from .first_last_frame2video import WanFLF2V
from .vace import WanVace, WanVaceMP

View File

@ -8,6 +8,7 @@ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage
def shard_model(
model,
device_id,
@ -32,6 +33,7 @@ def shard_model(
sync_module_states=sync_module_states)
return model
def free_model(model):
for m in model.modules():
if isinstance(m, FSDP):

View File

@ -1,9 +1,11 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
from xfuser.core.distributed import (
get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group,
)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ..modules.model import sinusoidal_embedding_1d
@ -63,19 +65,13 @@ def rope_apply(x, grid_sizes, freqs):
return torch.stack(output).float()
def usp_dit_forward_vace(
self,
x,
vace_context,
seq_len,
kwargs
):
def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs):
# embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in c
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
for u in c
])
# arguments

View File

@ -21,8 +21,11 @@ from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
@ -103,11 +106,12 @@ class WanFLF2V:
init_on_cpu = False
if use_usp:
from xfuser.core.distributed import \
get_sequence_parallel_world_size
from xfuser.core.distributed import get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward)
from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward,
)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
@ -181,8 +185,10 @@ class WanFLF2V:
"""
first_frame_size = first_frame.size
last_frame_size = last_frame.size
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(self.device)
last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(self.device)
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
self.device)
last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(
self.device)
F = frame_num
first_frame_h, first_frame_w = first_frame.shape[1:]
@ -199,8 +205,7 @@ class WanFLF2V:
# 1. resize
last_frame_resize_ratio = max(
first_frame_size[0] / last_frame_size[0],
first_frame_size[1] / last_frame_size[1]
)
first_frame_size[1] / last_frame_size[1])
last_frame_size = [
round(last_frame_size[0] * last_frame_resize_ratio),
round(last_frame_size[1] * last_frame_resize_ratio),
@ -216,8 +221,7 @@ class WanFLF2V:
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
noise = torch.randn(
16,
(F - 1) // 4 + 1,
16, (F - 1) // 4 + 1,
lat_h,
lat_w,
dtype=torch.float32,
@ -225,8 +229,11 @@ class WanFLF2V:
device=self.device)
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
msk[:, 1: -1] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk[:, 1:-1] = 0
msk = torch.concat([
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
],
dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
@ -247,7 +254,8 @@ class WanFLF2V:
context_null = [t.to(self.device) for t in context_null]
self.clip.model.to(self.device)
clip_context = self.clip.visual([first_frame[:, None, :, :], last_frame[:, None, :, :]])
clip_context = self.clip.visual(
[first_frame[:, None, :, :], last_frame[:, None, :, :]])
if offload_model:
self.clip.model.cpu()
@ -256,15 +264,14 @@ class WanFLF2V:
torch.nn.functional.interpolate(
first_frame[None].cpu(),
size=(first_frame_h, first_frame_w),
mode='bicubic'
).transpose(0, 1),
mode='bicubic').transpose(0, 1),
torch.zeros(3, F - 2, first_frame_h, first_frame_w),
torch.nn.functional.interpolate(
last_frame[None].cpu(),
size=(first_frame_h, first_frame_w),
mode='bicubic'
).transpose(0, 1),
], dim=1).to(self.device)
mode='bicubic').transpose(0, 1),
],
dim=1).to(self.device)
])[0]
y = torch.concat([msk, y])

View File

@ -21,8 +21,11 @@ from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
@ -103,11 +106,12 @@ class WanI2V:
init_on_cpu = False
if use_usp:
from xfuser.core.distributed import \
get_sequence_parallel_world_size
from xfuser.core.distributed import get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward)
from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward,
)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
@ -196,8 +200,7 @@ class WanI2V:
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
noise = torch.randn(
16,
(F - 1) // 4 + 1,
16, (F - 1) // 4 + 1,
lat_h,
lat_w,
dtype=torch.float32,

View File

@ -273,7 +273,7 @@ class WanAttentionBlock(nn.Module):
nn.Linear(ffn_dim, dim))
# modulation
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5)
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def forward(
self,
@ -332,7 +332,7 @@ class Head(nn.Module):
self.head = nn.Linear(dim, out_dim)
# modulation
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim ** 0.5)
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
def forward(self, x, e):
r"""
@ -357,7 +357,8 @@ class MLPProj(torch.nn.Module):
torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
torch.nn.LayerNorm(out_dim))
if flf_pos_emb: # NOTE: we only use this for `flf2v`
self.emb_pos = nn.Parameter(torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))
self.emb_pos = nn.Parameter(
torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))
def forward(self, image_embeds):
if hasattr(self, 'emb_pos'):

View File

@ -3,23 +3,24 @@ import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import register_to_config
from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d
class VaceWanAttentionBlock(WanAttentionBlock):
def __init__(
self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
block_id=0
):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
def __init__(self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
block_id=0):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size,
qk_norm, cross_attn_norm, eps)
self.block_id = block_id
if block_id == 0:
self.before_proj = nn.Linear(self.dim, self.dim)
@ -39,19 +40,19 @@ class VaceWanAttentionBlock(WanAttentionBlock):
class BaseWanAttentionBlock(WanAttentionBlock):
def __init__(
self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
block_id=None
):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
def __init__(self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
block_id=None):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size,
qk_norm, cross_attn_norm, eps)
self.block_id = block_id
def forward(self, x, hints, context_scale=1.0, **kwargs):
@ -62,6 +63,7 @@ class BaseWanAttentionBlock(WanAttentionBlock):
class VaceWanModel(WanModel):
@register_to_config
def __init__(self,
vace_layers=None,
@ -81,42 +83,57 @@ class VaceWanModel(WanModel):
qk_norm=True,
cross_attn_norm=True,
eps=1e-6):
super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim,
freq_dim, text_dim, out_dim, num_heads, num_layers,
window_size, qk_norm, cross_attn_norm, eps)
self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
self.vace_layers = [i for i in range(0, self.num_layers, 2)
] if vace_layers is None else vace_layers
self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
assert 0 in self.vace_layers
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
self.vace_layers_mapping = {
i: n for n, i in enumerate(self.vace_layers)
}
# blocks
self.blocks = nn.ModuleList([
BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
self.cross_attn_norm, self.eps,
block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
BaseWanAttentionBlock(
't2v_cross_attn',
self.dim,
self.ffn_dim,
self.num_heads,
self.window_size,
self.qk_norm,
self.cross_attn_norm,
self.eps,
block_id=self.vace_layers_mapping[i]
if i in self.vace_layers else None)
for i in range(self.num_layers)
])
# vace blocks
self.vace_blocks = nn.ModuleList([
VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
self.cross_attn_norm, self.eps, block_id=i)
for i in self.vace_layers
VaceWanAttentionBlock(
't2v_cross_attn',
self.dim,
self.ffn_dim,
self.num_heads,
self.window_size,
self.qk_norm,
self.cross_attn_norm,
self.eps,
block_id=i) for i in self.vace_layers
])
# vace patch embeddings
self.vace_patch_embedding = nn.Conv3d(
self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
)
self.vace_in_dim,
self.dim,
kernel_size=self.patch_size,
stride=self.patch_size)
def forward_vace(
self,
x,
vace_context,
seq_len,
kwargs
):
def forward_vace(self, x, vace_context, seq_len, kwargs):
# embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
@ -230,4 +247,4 @@ class VaceWanModel(WanModel):
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
return [u.float() for u in x]

View File

@ -18,8 +18,11 @@ from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
@ -85,11 +88,12 @@ class WanT2V:
self.model.eval().requires_grad_(False)
if use_usp:
from xfuser.core.distributed import \
get_sequence_parallel_world_size
from xfuser.core.distributed import get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward)
from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward,
)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)

View File

@ -1,5 +1,8 @@
from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
retrieve_timesteps)
from .fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
from .vace_processor import VaceVideoProcessor

View File

@ -9,9 +9,11 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
SchedulerMixin,
SchedulerOutput)
from diffusers.schedulers.scheduling_utils import (
KarrasDiffusionSchedulers,
SchedulerMixin,
SchedulerOutput,
)
from diffusers.utils import deprecate, is_scipy_available
from diffusers.utils.torch_utils import randn_tensor

View File

@ -8,9 +8,11 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
SchedulerMixin,
SchedulerOutput)
from diffusers.schedulers.scheduling_utils import (
KarrasDiffusionSchedulers,
SchedulerMixin,
SchedulerOutput,
)
from diffusers.utils import deprecate, is_scipy_available
if is_scipy_available():

View File

@ -7,7 +7,7 @@ import sys
import tempfile
from dataclasses import dataclass
from http import HTTPStatus
from typing import Optional, Union, List
from typing import List, Optional, Union
import dashscope
import torch
@ -96,7 +96,6 @@ VL_EN_SYS_PROMPT = \
'''4. In the style of an American drama promotional poster, Walter White sits in a metal folding chair wearing a yellow protective suit, with the words "Breaking Bad" written in sans-serif English above him, surrounded by piles of dollar bills and blue plastic storage boxes. He wears glasses, staring forward, dressed in a yellow jumpsuit, with his hands resting on his knees, exuding a calm and confident demeanor. The background shows an abandoned, dim factory with light filtering through the windows. Theres a noticeable grainy texture. A medium shot with a straight-on close-up of the character.\n''' \
'''Directly output the rewritten English text.'''
VL_ZH_SYS_PROMPT_FOR_MULTI_IMAGES = """你是一位Prompt优化师旨在参考用户输入的图像的细节内容把用户输入的Prompt改写为优质Prompt使其更完整、更具表现力同时不改变原意。你需要综合用户输入的照片内容和输入的Prompt进行改写严格参考示例的格式进行改写
任务要求
1. 用户会输入两张图片第一张是视频的第一帧第二张时视频的最后一帧你需要综合两个照片的内容进行优化改写
@ -198,8 +197,8 @@ class PromptExpander:
if system_prompt is None:
system_prompt = self.decide_system_prompt(
tar_lang=tar_lang,
multi_images_input=isinstance(image, (list, tuple)) and len(image) > 1
)
multi_images_input=isinstance(image, (list, tuple)) and
len(image) > 1)
if seed < 0:
seed = random.randint(0, sys.maxsize)
if image is not None and self.is_vl:
@ -289,7 +288,8 @@ class DashScopePromptExpander(PromptExpander):
def extend_with_img(self,
prompt,
system_prompt,
image: Union[List[Image.Image], List[str], Image.Image, str] = None,
image: Union[List[Image.Image], List[str], Image.Image,
str] = None,
seed=-1,
*args,
**kwargs):
@ -308,13 +308,15 @@ class DashScopePromptExpander(PromptExpander):
_image.save(f.name)
image_path = f"file://{f.name}"
return image_path
if not isinstance(image, (list, tuple)):
image = [image]
image_path_list = [ensure_image(_image) for _image in image]
role_content = [
{"text": prompt},
*[{"image": image_path} for image_path in image_path_list]
]
role_content = [{
"text": prompt
}, *[{
"image": image_path
} for image_path in image_path_list]]
system_content = [{"text": system_prompt}]
prompt = f"{prompt}"
messages = [
@ -393,8 +395,11 @@ class QwenPromptExpander(PromptExpander):
if self.is_vl:
# default: Load the model on the available device(s)
from transformers import (AutoProcessor, AutoTokenizer,
Qwen2_5_VLForConditionalGeneration)
from transformers import (
AutoProcessor,
AutoTokenizer,
Qwen2_5_VLForConditionalGeneration,
)
try:
from .qwen_vl_utils import process_vision_info
except:
@ -459,7 +464,8 @@ class QwenPromptExpander(PromptExpander):
def extend_with_img(self,
prompt,
system_prompt,
image: Union[List[Image.Image], List[str], Image.Image, str] = None,
image: Union[List[Image.Image], List[str], Image.Image,
str] = None,
seed=-1,
*args,
**kwargs):
@ -468,26 +474,19 @@ class QwenPromptExpander(PromptExpander):
if not isinstance(image, (list, tuple)):
image = [image]
system_content = [{
system_content = [{"type": "text", "text": system_prompt}]
role_content = [{
"type": "text",
"text": system_prompt
}]
role_content = [
{
"type": "text",
"text": prompt
},
*[
{"image": image_path} for image_path in image
]
]
"text": prompt
}, *[{
"image": image_path
} for image_path in image]]
messages = [{
'role': 'system',
'content': system_content,
}, {
"role":
"user",
"role": "user",
"content": role_content,
}]
@ -611,25 +610,38 @@ if __name__ == "__main__":
print("VL qwen vl en result -> en",
qwen_result.prompt) # , qwen_result.system_prompt)
# test multi images
image = ["./examples/flf2v_input_first_frame.png", "./examples/flf2v_input_last_frame.png"]
image = [
"./examples/flf2v_input_first_frame.png",
"./examples/flf2v_input_last_frame.png"
]
prompt = "无人机拍摄,镜头快速推进,然后拉远至全景俯瞰,展示一个宁静美丽的海港。海港内停满了游艇,水面清澈透蓝。周围是起伏的山丘和错落有致的建筑,整体景色宁静而美丽。"
en_prompt = ("Shot from a drone perspective, the camera rapidly zooms in before pulling back to reveal a panoramic "
"aerial view of a serene and picturesque harbor. The tranquil bay is dotted with numerous yachts "
"resting on crystal-clear blue waters. Surrounding the harbor are rolling hills and well-spaced "
"architectural structures, combining to create a tranquil and breathtaking coastal landscape.")
en_prompt = (
"Shot from a drone perspective, the camera rapidly zooms in before pulling back to reveal a panoramic "
"aerial view of a serene and picturesque harbor. The tranquil bay is dotted with numerous yachts "
"resting on crystal-clear blue waters. Surrounding the harbor are rolling hills and well-spaced "
"architectural structures, combining to create a tranquil and breathtaking coastal landscape."
)
dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True)
dashscope_result = dashscope_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
dashscope_prompt_expander = DashScopePromptExpander(
model_name=ds_model_name, is_vl=True)
dashscope_result = dashscope_prompt_expander(
prompt, tar_lang="zh", image=image, seed=seed)
print("VL dashscope result -> zh", dashscope_result.prompt)
dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True)
dashscope_result = dashscope_prompt_expander(en_prompt, tar_lang="zh", image=image, seed=seed)
dashscope_prompt_expander = DashScopePromptExpander(
model_name=ds_model_name, is_vl=True)
dashscope_result = dashscope_prompt_expander(
en_prompt, tar_lang="zh", image=image, seed=seed)
print("VL dashscope en result -> zh", dashscope_result.prompt)
qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0)
qwen_result = qwen_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
qwen_prompt_expander = QwenPromptExpander(
model_name=qwen_model_name, is_vl=True, device=0)
qwen_result = qwen_prompt_expander(
prompt, tar_lang="zh", image=image, seed=seed)
print("VL qwen result -> zh", qwen_result.prompt)
qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0)
qwen_result = qwen_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
qwen_prompt_expander = QwenPromptExpander(
model_name=qwen_model_name, is_vl=True, device=0)
qwen_result = qwen_prompt_expander(
prompt, tar_lang="zh", image=image, seed=seed)
print("VL qwen en result -> zh", qwen_result.prompt)

View File

@ -1,12 +1,13 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
class VaceImageProcessor(object):
def __init__(self, downsample=None, seq_len=None):
self.downsample = downsample
self.seq_len = seq_len
@ -16,9 +17,10 @@ class VaceImageProcessor(object):
if image.mode == 'P':
image = image.convert(f'{cvt_type}A')
if image.mode == f'{cvt_type}A':
bg = Image.new(cvt_type,
size=(image.width, image.height),
color=(255, 255, 255))
bg = Image.new(
cvt_type,
size=(image.width, image.height),
color=(255, 255, 255))
bg.paste(image, (0, 0), mask=image)
image = bg
else:
@ -41,10 +43,8 @@ class VaceImageProcessor(object):
if iw != ow or ih != oh:
# resize
scale = max(ow / iw, oh / ih)
img = img.resize(
(round(scale * iw), round(scale * ih)),
resample=Image.Resampling.LANCZOS
)
img = img.resize((round(scale * iw), round(scale * ih)),
resample=Image.Resampling.LANCZOS)
assert img.width >= ow and img.height >= oh
# center crop
@ -66,7 +66,11 @@ class VaceImageProcessor(object):
def load_image_pair(self, data_key, data_key2, **kwargs):
return self.load_image_batch(data_key, data_key2, **kwargs)
def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs):
def load_image_batch(self,
*data_key_batch,
normalize=True,
seq_len=None,
**kwargs):
seq_len = self.seq_len if seq_len is None else seq_len
imgs = []
for data_key in data_key_batch:
@ -85,7 +89,9 @@ class VaceImageProcessor(object):
class VaceVideoProcessor(object):
def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs):
def __init__(self, downsample, min_area, max_area, min_fps, max_fps,
zero_start, seq_len, keep_last, **kwargs):
self.downsample = downsample
self.min_area = min_area
self.max_area = max_area
@ -130,8 +136,7 @@ class VaceVideoProcessor(object):
video,
size=(round(scale * ih), round(scale * iw)),
mode='bicubic',
antialias=True
)
antialias=True)
assert video.size(3) >= ow and video.size(2) >= oh
# center crop
@ -146,7 +151,8 @@ class VaceVideoProcessor(object):
def _video_preprocess(self, video, oh, ow):
return self.resize_crop(video, oh, ow)
def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng):
def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box,
rng):
target_fps = min(fps, self.max_fps)
duration = frame_timestamps[-1].mean()
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
@ -154,11 +160,10 @@ class VaceVideoProcessor(object):
ratio = h / w
df, dh, dw = self.downsample
area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
of = min(
(int(duration * target_fps) - 1) // df + 1,
int(self.seq_len / area_z)
)
area_z = min(self.seq_len, self.max_area / (dh * dw),
(h // dh) * (w // dw))
of = min((int(duration * target_fps) - 1) // df + 1,
int(self.seq_len / area_z))
# deduce target shape of the [latent video]
target_area_z = min(area_z, int(self.seq_len / of))
@ -170,26 +175,27 @@ class VaceVideoProcessor(object):
# sample frame ids
target_duration = of / target_fps
begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration)
begin = 0. if self.zero_start else rng.uniform(
0, duration - target_duration)
timestamps = np.linspace(begin, begin + target_duration, of)
frame_ids = np.argmax(np.logical_and(
timestamps[:, None] >= frame_timestamps[None, :, 0],
timestamps[:, None] < frame_timestamps[None, :, 1]
), axis=1).tolist()
frame_ids = np.argmax(
np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
timestamps[:, None] < frame_timestamps[None, :, 1]),
axis=1).tolist()
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng):
def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w,
crop_box, rng):
duration = frame_timestamps[-1].mean()
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
ratio = h / w
df, dh, dw = self.downsample
area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
of = min(
(len(frame_timestamps) - 1) // df + 1,
int(self.seq_len / area_z)
)
area_z = min(self.seq_len, self.max_area / (dh * dw),
(h // dh) * (w // dw))
of = min((len(frame_timestamps) - 1) // df + 1,
int(self.seq_len / area_z))
# deduce target shape of the [latent video]
target_area_z = min(area_z, int(self.seq_len / of))
@ -203,27 +209,39 @@ class VaceVideoProcessor(object):
target_duration = duration
target_fps = of / target_duration
timestamps = np.linspace(0., target_duration, of)
frame_ids = np.argmax(np.logical_and(
timestamps[:, None] >= frame_timestamps[None, :, 0],
timestamps[:, None] <= frame_timestamps[None, :, 1]
), axis=1).tolist()
frame_ids = np.argmax(
np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
timestamps[:, None] <= frame_timestamps[None, :, 1]),
axis=1).tolist()
# print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids))
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
if self.keep_last:
return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng)
return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h,
w, crop_box, rng)
else:
return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng)
return self._get_frameid_bbox_default(fps, frame_timestamps, h, w,
crop_box, rng)
def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
return self.load_video_batch(
data_key, crop_box=crop_box, seed=seed, **kwargs)
def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
def load_video_pair(self,
data_key,
data_key2,
crop_box=None,
seed=2024,
**kwargs):
return self.load_video_batch(
data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs):
def load_video_batch(self,
*data_key_batch,
crop_box=None,
seed=2024,
**kwargs):
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
# read video
import decord
@ -235,36 +253,53 @@ class VaceVideoProcessor(object):
fps = readers[0].get_avg_fps()
length = min([len(r) for r in readers])
frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
frame_timestamps = [
readers[0].get_frame_timestamp(i) for i in range(length)
]
frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
h, w = readers[0].next().shape[:2]
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng)
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(
fps, frame_timestamps, h, w, crop_box, rng)
# preprocess video
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
videos = [
reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :]
for reader in readers
]
videos = [self._video_preprocess(video, oh, ow) for video in videos]
return *videos, frame_ids, (oh, ow), fps
# return videos if len(videos) > 1 else videos[0]
def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size,
device):
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
if sub_src_video is None and sub_src_mask is None:
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device)
src_video[i] = torch.zeros(
(3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones(
(1, num_frames, image_size[0], image_size[1]), device=device)
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
for j, ref_img in enumerate(ref_images):
if ref_img is not None and ref_img.shape[-2:] != image_size:
canvas_height, canvas_width = image_size
ref_height, ref_width = ref_img.shape[-2:]
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
scale = min(canvas_height / ref_height, canvas_width / ref_width)
white_canvas = torch.ones(
(3, 1, canvas_height, canvas_width),
device=device) # [-1, 1]
scale = min(canvas_height / ref_height,
canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
resized_image = F.interpolate(
ref_img.squeeze(1).unsqueeze(0),
size=(new_height, new_width),
mode='bilinear',
align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
white_canvas[:, :, top:top + new_height,
left:left + new_width] = resized_image
src_ref_images[i][j] = white_canvas
return src_video, src_mask, src_ref_images

View File

@ -1,32 +1,41 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import sys
import gc
import math
import time
import random
import types
import logging
import math
import os
import random
import sys
import time
import traceback
import types
from contextlib import contextmanager
from functools import partial
from PIL import Image
import torchvision.transforms.functional as TF
import torch
import torch.nn.functional as F
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from tqdm import tqdm
from .text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler)
from .modules.vace_model import VaceWanModel
from .text2video import (
FlowDPMSolverMultistepScheduler,
FlowUniPCMultistepScheduler,
T5EncoderModel,
WanT2V,
WanVAE,
get_sampling_sigmas,
retrieve_timesteps,
shard_model,
)
from .utils.vace_processor import VaceVideoProcessor
class WanVace(WanT2V):
def __init__(
self,
config,
@ -87,12 +96,13 @@ class WanVace(WanT2V):
self.model.eval().requires_grad_(False)
if use_usp:
from xfuser.core.distributed import \
get_sequence_parallel_world_size
from xfuser.core.distributed import get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward,
usp_dit_forward_vace)
from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward,
usp_dit_forward_vace,
)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
@ -100,7 +110,8 @@ class WanVace(WanT2V):
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
self.model.forward = types.MethodType(usp_dit_forward, self.model)
self.model.forward_vace = types.MethodType(usp_dit_forward_vace, self.model)
self.model.forward_vace = types.MethodType(usp_dit_forward_vace,
self.model)
self.sp_size = get_sequence_parallel_world_size()
else:
self.sp_size = 1
@ -114,14 +125,16 @@ class WanVace(WanT2V):
self.sample_neg_prompt = config.sample_neg_prompt
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
min_area=720*1280,
max_area=720*1280,
min_fps=config.sample_fps,
max_fps=config.sample_fps,
zero_start=True,
seq_len=75600,
keep_last=True)
self.vid_proc = VaceVideoProcessor(
downsample=tuple(
[x * y for x, y in zip(config.vae_stride, self.patch_size)]),
min_area=720 * 1280,
max_area=720 * 1280,
min_fps=config.sample_fps,
max_fps=config.sample_fps,
zero_start=True,
seq_len=75600,
keep_last=True)
def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
vae = self.vae if vae is None else vae
@ -138,7 +151,9 @@ class WanVace(WanT2V):
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
inactive = vae.encode(inactive)
reactive = vae.encode(reactive)
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
latents = [
torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)
]
cat_latents = []
for latent, refs in zip(latents, ref_images):
@ -147,7 +162,10 @@ class WanVace(WanT2V):
ref_latent = vae.encode(refs)
else:
ref_latent = vae.encode(refs)
ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
ref_latent = [
torch.cat((u, torch.zeros_like(u)), dim=0)
for u in ref_latent
]
assert all([x.shape[1] == 1 for x in ref_latent])
latent = torch.cat([*ref_latent, latent], dim=1)
cat_latents.append(latent)
@ -169,16 +187,17 @@ class WanVace(WanT2V):
# reshape
mask = mask[0, :, :, :]
mask = mask.view(
depth, height, vae_stride[1], width, vae_stride[1]
) # depth, height, 8, width, 8
mask = mask.view(depth, height, vae_stride[1], width,
vae_stride[1]) # depth, height, 8, width, 8
mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
mask = mask.reshape(
vae_stride[1] * vae_stride[2], depth, height, width
) # 8*8, depth, height, width
mask = mask.reshape(vae_stride[1] * vae_stride[2], depth, height,
width) # 8*8, depth, height, width
# interpolation
mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
mask = F.interpolate(
mask.unsqueeze(0),
size=(new_depth, height, width),
mode='nearest-exact').squeeze(0)
if refs is not None:
length = len(refs)
@ -190,27 +209,35 @@ class WanVace(WanT2V):
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device):
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames,
image_size, device):
area = image_size[0] * image_size[1]
self.vid_proc.set_area(area)
if area == 720*1280:
if area == 720 * 1280:
self.vid_proc.set_seq_len(75600)
elif area == 480*832:
elif area == 480 * 832:
self.vid_proc.set_seq_len(32760)
else:
raise NotImplementedError(f'image_size {image_size} is not supported')
raise NotImplementedError(
f'image_size {image_size} is not supported')
image_size = (image_size[1], image_size[0])
image_sizes = []
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
for i, (sub_src_video,
sub_src_mask) in enumerate(zip(src_video, src_mask)):
if sub_src_mask is not None and sub_src_video is not None:
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask)
src_video[i], src_mask[
i], _, _, _ = self.vid_proc.load_video_pair(
sub_src_video, sub_src_mask)
src_video[i] = src_video[i].to(device)
src_mask[i] = src_mask[i].to(device)
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
src_mask[i] = torch.clamp(
(src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
image_sizes.append(src_video[i].shape[2:])
elif sub_src_video is None:
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
src_video[i] = torch.zeros(
(3, num_frames, image_size[0], image_size[1]),
device=device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(image_size)
else:
@ -225,18 +252,27 @@ class WanVace(WanT2V):
for j, ref_img in enumerate(ref_images):
if ref_img is not None:
ref_img = Image.open(ref_img).convert("RGB")
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(
0.5).unsqueeze(1)
if ref_img.shape[-2:] != image_size:
canvas_height, canvas_width = image_size
ref_height, ref_width = ref_img.shape[-2:]
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
scale = min(canvas_height / ref_height, canvas_width / ref_width)
white_canvas = torch.ones(
(3, 1, canvas_height, canvas_width),
device=device) # [-1, 1]
scale = min(canvas_height / ref_height,
canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
resized_image = F.interpolate(
ref_img.squeeze(1).unsqueeze(0),
size=(new_height, new_width),
mode='bilinear',
align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
white_canvas[:, :, top:top + new_height,
left:left + new_width] = resized_image
ref_img = white_canvas
src_ref_images[i][j] = ref_img.to(device)
return src_video, src_mask, src_ref_images
@ -256,8 +292,6 @@ class WanVace(WanT2V):
return vae.decode(trimed_zs)
def generate(self,
input_prompt,
input_frames,
@ -335,7 +369,8 @@ class WanVace(WanT2V):
context_null = [t.to(self.device) for t in context_null]
# vace context encode
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks)
z0 = self.vace_encode_frames(
input_frames, input_ref_images, masks=input_masks)
m0 = self.vace_encode_masks(input_masks, input_ref_images)
z = self.vace_latent(z0, m0)
@ -399,9 +434,17 @@ class WanVace(WanT2V):
self.model.to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[0]
latent_model_input,
t=timestep,
vace_context=z,
vace_context_scale=context_scale,
**arg_c)[0]
noise_pred_uncond = self.model(
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,**arg_null)[0]
latent_model_input,
t=timestep,
vace_context=z,
vace_context_scale=context_scale,
**arg_null)[0]
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
@ -433,14 +476,13 @@ class WanVace(WanT2V):
class WanVaceMP(WanVace):
def __init__(
self,
config,
checkpoint_dir,
use_usp=False,
ulysses_size=None,
ring_size=None
):
def __init__(self,
config,
checkpoint_dir,
use_usp=False,
ulysses_size=None,
ring_size=None):
self.config = config
self.checkpoint_dir = checkpoint_dir
self.use_usp = use_usp
@ -457,7 +499,8 @@ class WanVaceMP(WanVace):
self.device = 'cpu' if torch.cuda.is_available() else 'cpu'
self.vid_proc = VaceVideoProcessor(
downsample=tuple([x * y for x, y in zip(config.vae_stride, config.patch_size)]),
downsample=tuple(
[x * y for x, y in zip(config.vae_stride, config.patch_size)]),
min_area=480 * 832,
max_area=480 * 832,
min_fps=self.config.sample_fps,
@ -466,20 +509,30 @@ class WanVaceMP(WanVace):
seq_len=32760,
keep_last=True)
def dynamic_load(self):
if hasattr(self, 'inference_pids') and self.inference_pids is not None:
return
gpu_infer = os.environ.get('LOCAL_WORLD_SIZE') or torch.cuda.device_count()
gpu_infer = os.environ.get(
'LOCAL_WORLD_SIZE') or torch.cuda.device_count()
pmi_rank = int(os.environ['RANK'])
pmi_world_size = int(os.environ['WORLD_SIZE'])
in_q_list = [torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)]
in_q_list = [
torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)
]
out_q = torch.multiprocessing.Manager().Queue()
initialized_events = [torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)]
context = mp.spawn(self.mp_worker, nprocs=gpu_infer, args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, self), join=False)
initialized_events = [
torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)
]
context = mp.spawn(
self.mp_worker,
nprocs=gpu_infer,
args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q,
initialized_events, self),
join=False)
all_initialized = False
while not all_initialized:
all_initialized = all(event.is_set() for event in initialized_events)
all_initialized = all(
event.is_set() for event in initialized_events)
if not all_initialized:
time.sleep(0.1)
print('Inference model is initialized', flush=True)
@ -495,12 +548,19 @@ class WanVaceMP(WanVace):
if isinstance(data, torch.Tensor):
data = data.to(device)
elif isinstance(data, list):
data = [self.transfer_data_to_cuda(subdata, device) for subdata in data]
data = [
self.transfer_data_to_cuda(subdata, device)
for subdata in data
]
elif isinstance(data, dict):
data = {key: self.transfer_data_to_cuda(val, device) for key, val in data.items()}
data = {
key: self.transfer_data_to_cuda(val, device)
for key, val in data.items()
}
return data
def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, work_env):
def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list,
out_q, initialized_events, work_env):
try:
world_size = pmi_world_size * gpu_infer
rank = pmi_rank * gpu_infer + gpu
@ -511,19 +571,19 @@ class WanVaceMP(WanVace):
backend='nccl',
init_method='env://',
rank=rank,
world_size=world_size
)
world_size=world_size)
from xfuser.core.distributed import (initialize_model_parallel,
init_distributed_environment)
from xfuser.core.distributed import (
init_distributed_environment,
initialize_model_parallel,
)
init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=self.ring_size or 1,
ulysses_degree=self.ulysses_size or 1
)
ulysses_degree=self.ulysses_size or 1)
num_train_timesteps = self.config.num_train_timesteps
param_dtype = self.config.param_dtype
@ -532,14 +592,17 @@ class WanVaceMP(WanVace):
text_len=self.config.text_len,
dtype=self.config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(self.checkpoint_dir, self.config.t5_checkpoint),
tokenizer_path=os.path.join(self.checkpoint_dir, self.config.t5_tokenizer),
checkpoint_path=os.path.join(self.checkpoint_dir,
self.config.t5_checkpoint),
tokenizer_path=os.path.join(self.checkpoint_dir,
self.config.t5_tokenizer),
shard_fn=shard_fn if True else None)
text_encoder.model.to(gpu)
vae_stride = self.config.vae_stride
patch_size = self.config.patch_size
vae = WanVAE(
vae_pth=os.path.join(self.checkpoint_dir, self.config.vae_checkpoint),
vae_pth=os.path.join(self.checkpoint_dir,
self.config.vae_checkpoint),
device=gpu)
logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}")
model = VaceWanModel.from_pretrained(self.checkpoint_dir)
@ -547,9 +610,12 @@ class WanVaceMP(WanVace):
if self.use_usp:
from xfuser.core.distributed import get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward,
usp_dit_forward_vace)
from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward,
usp_dit_forward_vace,
)
for block in model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
@ -557,7 +623,8 @@ class WanVaceMP(WanVace):
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
model.forward = types.MethodType(usp_dit_forward, model)
model.forward_vace = types.MethodType(usp_dit_forward_vace, model)
model.forward_vace = types.MethodType(usp_dit_forward_vace,
model)
sp_size = get_sequence_parallel_world_size()
else:
sp_size = 1
@ -577,7 +644,8 @@ class WanVaceMP(WanVace):
shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model = item
input_frames = self.transfer_data_to_cuda(input_frames, gpu)
input_masks = self.transfer_data_to_cuda(input_masks, gpu)
input_ref_images = self.transfer_data_to_cuda(input_ref_images, gpu)
input_ref_images = self.transfer_data_to_cuda(
input_ref_images, gpu)
if n_prompt == "":
n_prompt = sample_neg_prompt
@ -589,8 +657,10 @@ class WanVaceMP(WanVace):
context_null = text_encoder([n_prompt], gpu)
# vace context encode
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, vae=vae)
m0 = self.vace_encode_masks(input_masks, input_ref_images, vae_stride=vae_stride)
z0 = self.vace_encode_frames(
input_frames, input_ref_images, masks=input_masks, vae=vae)
m0 = self.vace_encode_masks(
input_masks, input_ref_images, vae_stride=vae_stride)
z = self.vace_latent(z0, m0)
target_shape = list(z0[0].shape)
@ -616,7 +686,8 @@ class WanVaceMP(WanVace):
no_sync = getattr(model, 'no_sync', noop_no_sync)
# evaluation mode
with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync():
with amp.autocast(
dtype=param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
@ -631,7 +702,8 @@ class WanVaceMP(WanVace):
num_train_timesteps=num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
sampling_sigmas = get_sampling_sigmas(
sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=gpu,
@ -653,14 +725,20 @@ class WanVaceMP(WanVace):
model.to(gpu)
noise_pred_cond = model(
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[
0]
latent_model_input,
t=timestep,
vace_context=z,
vace_context_scale=context_scale,
**arg_c)[0]
noise_pred_uncond = model(
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,
latent_model_input,
t=timestep,
vace_context=z,
vace_context_scale=context_scale,
**arg_null)[0]
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
noise_pred_cond - noise_pred_uncond)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
@ -673,7 +751,8 @@ class WanVaceMP(WanVace):
torch.cuda.empty_cache()
x0 = latents
if rank == 0:
videos = self.decode_latent(x0, input_ref_images, vae=vae)
videos = self.decode_latent(
x0, input_ref_images, vae=vae)
del noise, latents
del sample_scheduler
@ -691,8 +770,6 @@ class WanVaceMP(WanVace):
print(trace_info, flush=True)
print(e, flush=True)
def generate(self,
input_prompt,
input_frames,
@ -709,8 +786,10 @@ class WanVaceMP(WanVace):
seed=-1,
offload_model=True):
input_data = (input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale,
shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model)
input_data = (input_prompt, input_frames, input_masks, input_ref_images,
size, frame_num, context_scale, shift, sample_solver,
sampling_steps, guide_scale, n_prompt, seed,
offload_model)
for in_q in self.in_q_list:
in_q.put(input_data)
value_output = self.out_q.get()