Format the code (#402)

* isort the code

* format the code

* Add yapf config file

* Remove torch cuda memory profiler
Ang Wang, 2025-05-16 12:35:38 +08:00, committed via GitHub
parent c709fcf0e7
commit 76e9427657
23 changed files with 1052 additions and 416 deletions

.style.yapf (new file)

@@ -0,0 +1,393 @@
[style]
# Align closing bracket with visual indentation.
align_closing_bracket_with_visual_indent=False
# Allow dictionary keys to exist on multiple lines. For example:
#
# x = {
# ('this is the first element of a tuple',
# 'this is the second element of a tuple'):
# value,
# }
allow_multiline_dictionary_keys=False
# Allow lambdas to be formatted on more than one line.
allow_multiline_lambdas=False
# Allow splitting before a default / named assignment in an argument list.
allow_split_before_default_or_named_assigns=False
# Allow splits before the dictionary value.
allow_split_before_dict_value=True
# Let spacing indicate operator precedence. For example:
#
# a = 1 * 2 + 3 / 4
# b = 1 / 2 - 3 * 4
# c = (1 + 2) * (3 - 4)
# d = (1 - 2) / (3 + 4)
# e = 1 * 2 - 3
# f = 1 + 2 + 3 + 4
#
# will be formatted as follows to indicate precedence:
#
# a = 1*2 + 3/4
# b = 1/2 - 3*4
# c = (1+2) * (3-4)
# d = (1-2) / (3+4)
# e = 1*2 - 3
# f = 1 + 2 + 3 + 4
#
arithmetic_precedence_indication=False
# Number of blank lines surrounding top-level function and class
# definitions.
blank_lines_around_top_level_definition=2
# Insert a blank line before a class-level docstring.
blank_line_before_class_docstring=False
# Insert a blank line before a module docstring.
blank_line_before_module_docstring=False
# Insert a blank line before a 'def' or 'class' immediately nested
# within another 'def' or 'class'. For example:
#
# class Foo:
# # <------ this blank line
# def method():
# ...
blank_line_before_nested_class_or_def=True
# Do not split consecutive brackets. Only relevant when
# dedent_closing_brackets is set. For example:
#
# call_func_that_takes_a_dict(
# {
# 'key1': 'value1',
# 'key2': 'value2',
# }
# )
#
# would reformat to:
#
# call_func_that_takes_a_dict({
# 'key1': 'value1',
# 'key2': 'value2',
# })
coalesce_brackets=False
# The column limit.
column_limit=80
# The style for continuation alignment. Possible values are:
#
# - SPACE: Use spaces for continuation alignment. This is default behavior.
# - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns
#   (i.e., CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs or
# CONTINUATION_INDENT_WIDTH spaces) for continuation alignment.
# - VALIGN-RIGHT: Vertically align continuation lines to multiple of
# INDENT_WIDTH columns. Slightly right (one tab or a few spaces) if
# cannot vertically align continuation lines with indent characters.
continuation_align_style=SPACE
# Indent width used for line continuations.
continuation_indent_width=4
# Put closing brackets on a separate line, dedented, if the bracketed
# expression can't fit in a single line. Applies to all kinds of brackets,
# including function definitions and calls. For example:
#
# config = {
# 'key1': 'value1',
# 'key2': 'value2',
# } # <--- this bracket is dedented and on a separate line
#
# time_series = self.remote_client.query_entity_counters(
# entity='dev3246.region1',
# key='dns.query_latency_tcp',
# transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
# start_ts=now()-timedelta(days=3),
# end_ts=now(),
# ) # <--- this bracket is dedented and on a separate line
dedent_closing_brackets=False
# Disable the heuristic which places each list element on a separate line
# if the list is comma-terminated.
disable_ending_comma_heuristic=False
# Place each dictionary entry onto its own line.
each_dict_entry_on_separate_line=True
# Require multiline dictionary even if it would normally fit on one line.
# For example:
#
# config = {
# 'key1': 'value1'
# }
force_multiline_dict=False
# The regex for an i18n comment. The presence of this comment stops
# reformatting of that line, because the comments are required to be
# next to the string they translate.
i18n_comment=#\..*
# The i18n function call names. The presence of this function stops
# reformatting on that line, because the string it has cannot be moved
# away from the i18n comment.
i18n_function_call=N_, _
# Indent blank lines.
indent_blank_lines=False
# Put closing brackets on a separate line, indented, if the bracketed
# expression can't fit in a single line. Applies to all kinds of brackets,
# including function definitions and calls. For example:
#
# config = {
# 'key1': 'value1',
# 'key2': 'value2',
# } # <--- this bracket is indented and on a separate line
#
# time_series = self.remote_client.query_entity_counters(
# entity='dev3246.region1',
# key='dns.query_latency_tcp',
# transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
# start_ts=now()-timedelta(days=3),
# end_ts=now(),
# ) # <--- this bracket is indented and on a separate line
indent_closing_brackets=False
# Indent the dictionary value if it cannot fit on the same line as the
# dictionary key. For example:
#
# config = {
# 'key1':
# 'value1',
# 'key2': value1 +
# value2,
# }
indent_dictionary_value=True
# The number of columns to use for indentation.
indent_width=4
# Join short lines into one line. E.g., single line 'if' statements.
join_multiple_lines=False
# Do not include spaces around selected binary operators. For example:
#
# 1 + 2 * 3 - 4 / 5
#
# will be formatted as follows when configured with "*,/":
#
# 1 + 2*3 - 4/5
no_spaces_around_selected_binary_operators=
# Use spaces around default or named assigns.
spaces_around_default_or_named_assign=False
# Adds a space after the opening '{' and before the ending '}' dict delimiters.
#
# {1: 2}
#
# will be formatted as:
#
# { 1: 2 }
spaces_around_dict_delimiters=False
# Adds a space after the opening '[' and before the ending ']' list delimiters.
#
# [1, 2]
#
# will be formatted as:
#
# [ 1, 2 ]
spaces_around_list_delimiters=False
# Use spaces around the power operator.
spaces_around_power_operator=False
# Use spaces around the subscript / slice operator. For example:
#
# my_list[1 : 10 : 2]
spaces_around_subscript_colon=False
# Adds a space after the opening '(' and before the ending ')' tuple delimiters.
#
# (1, 2, 3)
#
# will be formatted as:
#
# ( 1, 2, 3 )
spaces_around_tuple_delimiters=False
# The number of spaces required before a trailing comment.
# This can be a single value (representing the number of spaces
# before each trailing comment) or list of values (representing
# alignment column values; trailing comments within a block will
# be aligned to the first column value that is greater than the maximum
# line length within the block). For example:
#
# With spaces_before_comment=5:
#
# 1 + 1 # Adding values
#
# will be formatted as:
#
# 1 + 1 # Adding values <-- 5 spaces between the end of the statement and comment
#
# With spaces_before_comment=15, 20:
#
# 1 + 1 # Adding values
# two + two # More adding
#
# longer_statement # This is a longer statement
# short # This is a shorter statement
#
# a_very_long_statement_that_extends_beyond_the_final_column # Comment
# short # This is a shorter statement
#
# will be formatted as:
#
# 1 + 1 # Adding values <-- end of line comments in block aligned to col 15
# two + two # More adding
#
# longer_statement # This is a longer statement <-- end of line comments in block aligned to col 20
# short # This is a shorter statement
#
# a_very_long_statement_that_extends_beyond_the_final_column # Comment <-- the end of line comments are aligned based on the line length
# short # This is a shorter statement
#
spaces_before_comment=2
# Insert a space between the ending comma and closing bracket of a list,
# etc.
space_between_ending_comma_and_closing_bracket=False
# Use spaces inside brackets, braces, and parentheses. For example:
#
# method_call( 1 )
# my_dict[ 3 ][ 1 ][ get_index( *args, **kwargs ) ]
# my_set = { 1, 2, 3 }
space_inside_brackets=False
# Split before arguments
split_all_comma_separated_values=False
# Split before arguments, but do not split all subexpressions recursively
# (unless needed).
split_all_top_level_comma_separated_values=False
# Split before arguments if the argument list is terminated by a
# comma.
split_arguments_when_comma_terminated=False
# Set to True to prefer splitting before '+', '-', '*', '/', '//', or '@'
# rather than after.
split_before_arithmetic_operator=False
# Set to True to prefer splitting before '&', '|' or '^' rather than
# after.
split_before_bitwise_operator=False
# Split before the closing bracket if a list or dict literal doesn't fit on
# a single line.
split_before_closing_bracket=True
# Split before a dictionary or set generator (comp_for). For example, note
# the split before the 'for':
#
# foo = {
# variable: 'Hello world, have a nice day!'
# for variable in bar if variable != 42
# }
split_before_dict_set_generator=False
# Split before the '.' if we need to split a longer expression:
#
# foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d))
#
# would reformat to something like:
#
# foo = ('This is a really long string: {}, {}, {}, {}'
# .format(a, b, c, d))
split_before_dot=False
# Split after the opening paren which surrounds an expression if it doesn't
# fit on a single line.
split_before_expression_after_opening_paren=True
# If an argument / parameter list is going to be split, then split before
# the first argument.
split_before_first_argument=False
# Set to True to prefer splitting before 'and' or 'or' rather than
# after.
split_before_logical_operator=False
# Split named assignments onto individual lines.
split_before_named_assigns=True
# Set to True to split list comprehensions and generators that have
# non-trivial expressions and multiple clauses before each of these
# clauses. For example:
#
# result = [
# a_long_var + 100 for a_long_var in xrange(1000)
# if a_long_var % 10]
#
# would reformat to something like:
#
# result = [
# a_long_var + 100
# for a_long_var in xrange(1000)
# if a_long_var % 10]
split_complex_comprehension=True
# The penalty for splitting right after the opening bracket.
split_penalty_after_opening_bracket=300
# The penalty for splitting the line after a unary operator.
split_penalty_after_unary_operator=10000
# The penalty of splitting the line around the '+', '-', '*', '/', '//',
# ``%``, and '@' operators.
split_penalty_arithmetic_operator=300
# The penalty for splitting right before an if expression.
split_penalty_before_if_expr=0
# The penalty of splitting the line around the '&', '|', and '^'
# operators.
split_penalty_bitwise_operator=300
# The penalty for splitting a list comprehension or generator
# expression.
split_penalty_comprehension=2100
# The penalty for characters over the column limit.
split_penalty_excess_character=7000
# The penalty incurred by adding a line split to the unwrapped line. The
# more line splits added the higher the penalty.
split_penalty_for_added_line_split=30
# The penalty of splitting a list of "import as" names. For example:
#
# from a_very_long_or_indented_module_name_yada_yad import (long_argument_1,
# long_argument_2,
# long_argument_3)
#
# would reformat to something like:
#
# from a_very_long_or_indented_module_name_yada_yad import (
# long_argument_1, long_argument_2, long_argument_3)
split_penalty_import_names=0
# The penalty of splitting the line around the 'and' and 'or'
# operators.
split_penalty_logical_operator=300
# Use the Tab character for indentation.
use_tabs=False
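A quick illustration of how several of these options interact (the function and argument names below are invented for illustration, not taken from the repository): with column_limit=80, split_before_named_assigns=True, each_dict_entry_on_separate_line=True and indent_dictionary_value=True, yapf splits an overlong call onto one keyword argument per line, keeps the closing bracket attached (dedent_closing_brackets and indent_closing_brackets are both False), and drops an overlong dictionary value onto its own indented line:

# Before formatting (illustrative only; names are hypothetical).
result = run_pipeline(checkpoint_dir="./weights", sampling_steps=50, guide_scale=5.0, offload_model=True)
config = {"prompt": "a long example prompt that would not fit within the 80-column limit", "size": "1280*720"}

# After yapf with the style above: one named assign per line, closing bracket
# kept attached to the last argument, and the overlong dict value indented
# under its key while short entries stay on one line.
result = run_pipeline(
    checkpoint_dir="./weights",
    sampling_steps=50,
    guide_scale=5.0,
    offload_model=True)
config = {
    "prompt":
        "a long example prompt that would not fit within the 80-column limit",
    "size": "1280*720"
}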

Makefile (new file)

@@ -0,0 +1,5 @@
.PHONY: format
format:
	isort generate.py gradio wan
	yapf -i -r *.py generate.py gradio wan
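The format target simply runs isort and then yapf over the listed paths. The same pass can also be driven from Python, for example in a pre-commit hook; the sketch below is an assumption-laden illustration (it presumes the isort and yapf packages are importable, that .style.yapf sits in the repository root, and that this file list only loosely mirrors the Makefile):

from pathlib import Path

import isort
from yapf.yapflib.yapf_api import FormatFile

# Mirror `make format`: sort imports first, then reformat in place.
targets = [
    Path("generate.py"),
    *Path("gradio").rglob("*.py"),
    *Path("wan").rglob("*.py"),
]

for path in targets:
    isort.file(path)  # rewrites the file with sorted imports
    # FormatFile returns (diff_or_none, encoding, changed); in_place=True rewrites the file.
    _, _, changed = FormatFile(str(path), in_place=True, style_config=".style.yapf")
    if changed:
        print(f"reformatted {path}")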

View File

@@ -1,28 +1,33 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
+from datetime import datetime
 import logging
 import os
 import sys
 import warnings
-from datetime import datetime
 warnings.filterwarnings('ignore')
-import torch, random
+import random
+import torch
 import torch.distributed as dist
 from PIL import Image
 import wan
-from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
+from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
-from wan.utils.utils import cache_video, cache_image, str2bool
+from wan.utils.utils import cache_image, cache_video, str2bool
 EXAMPLE_PROMPT = {
     "t2v-1.3B": {
-        "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
+        "prompt":
+            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
     },
     "t2v-14B": {
-        "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
+        "prompt":
+            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
     },
     "t2i-14B": {
         "prompt": "一个朴素端庄的美人",
@@ -42,12 +47,16 @@ EXAMPLE_PROMPT = {
             "examples/flf2v_input_last_frame.png",
     },
     "vace-1.3B": {
-        "src_ref_images": 'examples/girl.png,examples/snake.png',
-        "prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
+        "src_ref_images":
+            'examples/girl.png,examples/snake.png',
+        "prompt":
+            "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
     },
     "vace-14B": {
-        "src_ref_images": 'examples/girl.png,examples/snake.png',
-        "prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
+        "src_ref_images":
+            'examples/girl.png,examples/snake.png',
+        "prompt":
+            "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
     }
 }
@@ -64,7 +73,6 @@ def _validate_args(args):
         if "i2v" in args.task:
             args.sample_steps = 40
     if args.sample_shift is None:
         args.sample_shift = 5.0
         if "i2v" in args.task and args.size in ["832*480", "480*832"]:
@@ -72,7 +80,6 @@ def _validate_args(args):
         elif "flf2v" in args.task or "vace" in args.task:
             args.sample_shift = 16
     # The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
     if args.frame_num is None:
         args.frame_num = 1 if "t2i" in args.task else 81
@@ -167,7 +174,8 @@ def _parse_args():
         "--src_ref_images",
         type=str,
         default=None,
-        help="The file list of the source reference images. Separated by ','. Default None.")
+        help="The file list of the source reference images. Separated by ','. Default None."
+    )
     parser.add_argument(
         "--prompt",
         type=str,
@@ -209,12 +217,14 @@ def _parse_args():
         "--first_frame",
         type=str,
         default=None,
-        help="[first-last frame to video] The image (first frame) to generate the video from.")
+        help="[first-last frame to video] The image (first frame) to generate the video from."
+    )
     parser.add_argument(
         "--last_frame",
         type=str,
         default=None,
-        help="[first-last frame to video] The image (last frame) to generate the video from.")
+        help="[first-last frame to video] The image (last frame) to generate the video from."
+    )
     parser.add_argument(
         "--sample_solver",
         type=str,
@@ -281,8 +291,10 @@ def generate(args):
     if args.ulysses_size > 1 or args.ring_size > 1:
         assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size."
-        from xfuser.core.distributed import (initialize_model_parallel,
-                                             init_distributed_environment)
+        from xfuser.core.distributed import (
+            init_distributed_environment,
+            initialize_model_parallel,
+        )
         init_distributed_environment(
             rank=dist.get_rank(), world_size=dist.get_world_size())
@@ -295,7 +307,8 @@ def generate(args):
     if args.use_prompt_extend:
         if args.prompt_extend_method == "dashscope":
             prompt_expander = DashScopePromptExpander(
-                model_name=args.prompt_extend_model, is_vl="i2v" in args.task or "flf2v" in args.task)
+                model_name=args.prompt_extend_model,
+                is_vl="i2v" in args.task or "flf2v" in args.task)
         elif args.prompt_extend_method == "local_qwen":
             prompt_expander = QwenPromptExpander(
                 model_name=args.prompt_extend_model,
@@ -482,21 +495,22 @@ def generate(args):
             sampling_steps=args.sample_steps,
             guide_scale=args.sample_guide_scale,
             seed=args.base_seed,
-            offload_model=args.offload_model
-        )
+            offload_model=args.offload_model)
     elif "vace" in args.task:
         if args.prompt is None:
             args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
             args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None)
             args.src_mask = EXAMPLE_PROMPT[args.task].get("src_mask", None)
-            args.src_ref_images = EXAMPLE_PROMPT[args.task].get("src_ref_images", None)
+            args.src_ref_images = EXAMPLE_PROMPT[args.task].get(
+                "src_ref_images", None)
         logging.info(f"Input prompt: {args.prompt}")
         if args.use_prompt_extend and args.use_prompt_extend != 'plain':
             logging.info("Extending prompt ...")
             if rank == 0:
                 prompt = prompt_expander.forward(args.prompt)
-                logging.info(f"Prompt extended from '{args.prompt}' to '{prompt}'")
+                logging.info(
+                    f"Prompt extended from '{args.prompt}' to '{prompt}'")
                 input_prompt = [prompt]
             else:
                 input_prompt = [None]
@@ -517,10 +531,11 @@ def generate(args):
             t5_cpu=args.t5_cpu,
         )
-        src_video, src_mask, src_ref_images = wan_vace.prepare_source([args.src_video],
-                                                                      [args.src_mask],
-                                                                      [None if args.src_ref_images is None else args.src_ref_images.split(',')],
-                                                                      args.frame_num, SIZE_CONFIGS[args.size], device)
+        src_video, src_mask, src_ref_images = wan_vace.prepare_source(
+            [args.src_video], [args.src_mask], [
+                None if args.src_ref_images is None else
+                args.src_ref_images.split(',')
+            ], args.frame_num, SIZE_CONFIGS[args.size], device)
         logging.info(f"Generating video...")
         video = wan_vace.generate(

View File

@ -1,8 +1,8 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse import argparse
import gc import gc
import os.path as osp
import os import os
import os.path as osp
import sys import sys
import warnings import warnings
@ -11,7 +11,8 @@ import gradio as gr
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
# Model # Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) sys.path.insert(
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
@ -69,13 +70,13 @@ def prompt_enc(prompt, img_first, img_last, tar_lang):
return prompt_output.prompt return prompt_output.prompt
def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps, def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
guide_scale, shift_scale, seed, n_prompt): resolution, sd_steps, guide_scale, shift_scale, seed,
n_prompt):
if resolution == '------': if resolution == '------':
print( print(
'Please specify the resolution ckpt dir or specify the resolution' 'Please specify the resolution ckpt dir or specify the resolution')
)
return None return None
else: else:
@ -94,9 +95,7 @@ def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, re
offload_model=True) offload_model=True)
pass pass
else: else:
print( print('Sorry, currently only 720P is supported.')
'Sorry, currently only 720P is supported.'
)
return None return None
cache_video( cache_video(
@ -191,14 +190,17 @@ def gradio_interface():
run_p_button.click( run_p_button.click(
fn=prompt_enc, fn=prompt_enc,
inputs=[flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, tar_lang], inputs=[
flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
tar_lang
],
outputs=[flf2vid_prompt]) outputs=[flf2vid_prompt])
run_flf2v_button.click( run_flf2v_button.click(
fn=flf2v_generation, fn=flf2v_generation,
inputs=[ inputs=[
flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps, flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
guide_scale, shift_scale, seed, n_prompt resolution, sd_steps, guide_scale, shift_scale, seed, n_prompt
], ],
outputs=[result_gallery], outputs=[result_gallery],
) )

View File

@@ -1,8 +1,8 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
 import gc
-import os.path as osp
 import os
+import os.path as osp
 import sys
 import warnings
@@ -11,7 +11,8 @@ import gradio as gr
 warnings.filterwarnings('ignore')
 # Model
-sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
+sys.path.insert(
+    0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
 import wan
 from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander

View File

@@ -1,7 +1,7 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
-import os.path as osp
 import os
+import os.path as osp
 import sys
 import warnings
@@ -10,7 +10,8 @@ import gradio as gr
 warnings.filterwarnings('ignore')
 # Model
-sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
+sys.path.insert(
+    0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
 import wan
 from wan.configs import WAN_CONFIGS
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander

View File

@@ -1,7 +1,7 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
-import os.path as osp
 import os
+import os.path as osp
 import sys
 import warnings
@@ -10,7 +10,8 @@ import gradio as gr
 warnings.filterwarnings('ignore')
 # Model
-sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
+sys.path.insert(
+    0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
 import wan
 from wan.configs import WAN_CONFIGS
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander

View File

@@ -1,7 +1,7 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
-import os.path as osp
 import os
+import os.path as osp
 import sys
 import warnings
@@ -10,7 +10,8 @@ import gradio as gr
 warnings.filterwarnings('ignore')
 # Model
-sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
+sys.path.insert(
+    0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
 import wan
 from wan.configs import WAN_CONFIGS
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander

View File

@ -2,36 +2,48 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import argparse import argparse
import datetime
import os import os
import sys import sys
import datetime
import imageio import imageio
import numpy as np import numpy as np
import torch import torch
import gradio as gr import gradio as gr
sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2])) sys.path.insert(
0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
import wan import wan
from wan import WanVace, WanVaceMP from wan import WanVace, WanVaceMP
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS from wan.configs import SIZE_CONFIGS, WAN_CONFIGS
class FixedSizeQueue: class FixedSizeQueue:
def __init__(self, max_size): def __init__(self, max_size):
self.max_size = max_size self.max_size = max_size
self.queue = [] self.queue = []
def add(self, item): def add(self, item):
self.queue.insert(0, item) self.queue.insert(0, item)
if len(self.queue) > self.max_size: if len(self.queue) > self.max_size:
self.queue.pop() self.queue.pop()
def get(self): def get(self):
return self.queue return self.queue
def __repr__(self): def __repr__(self):
return str(self.queue) return str(self.queue)
class VACEInference: class VACEInference:
def __init__(self, cfg, skip_load=False, gallery_share=True, gallery_share_limit=5):
def __init__(self,
cfg,
skip_load=False,
gallery_share=True,
gallery_share_limit=5):
self.cfg = cfg self.cfg = cfg
self.save_dir = cfg.save_dir self.save_dir = cfg.save_dir
self.gallery_share = gallery_share self.gallery_share = gallery_share
@ -53,9 +65,7 @@ class VACEInference:
checkpoint_dir=cfg.ckpt_dir, checkpoint_dir=cfg.ckpt_dir,
use_usp=True, use_usp=True,
ulysses_size=cfg.ulysses_size, ulysses_size=cfg.ulysses_size,
ring_size=cfg.ring_size ring_size=cfg.ring_size)
)
def create_ui(self, *args, **kwargs): def create_ui(self, *args, **kwargs):
gr.Markdown(""" gr.Markdown("""
@ -80,7 +90,8 @@ class VACEInference:
with gr.Row(variant='panel', equal_height=True): with gr.Row(variant='panel', equal_height=True):
with gr.Column(scale=1, min_width=0): with gr.Column(scale=1, min_width=0):
with gr.Row(equal_height=True): with gr.Row(equal_height=True):
self.src_ref_image_1 = gr.Image(label='src_ref_image_1', self.src_ref_image_1 = gr.Image(
label='src_ref_image_1',
height=200, height=200,
interactive=True, interactive=True,
type='filepath', type='filepath',
@ -88,7 +99,8 @@ class VACEInference:
sources=['upload'], sources=['upload'],
elem_id="src_ref_image_1", elem_id="src_ref_image_1",
format='png') format='png')
self.src_ref_image_2 = gr.Image(label='src_ref_image_2', self.src_ref_image_2 = gr.Image(
label='src_ref_image_2',
height=200, height=200,
interactive=True, interactive=True,
type='filepath', type='filepath',
@ -96,7 +108,8 @@ class VACEInference:
sources=['upload'], sources=['upload'],
elem_id="src_ref_image_2", elem_id="src_ref_image_2",
format='png') format='png')
self.src_ref_image_3 = gr.Image(label='src_ref_image_3', self.src_ref_image_3 = gr.Image(
label='src_ref_image_3',
height=200, height=200,
interactive=True, interactive=True,
type='filepath', type='filepath',
@ -158,10 +171,8 @@ class VACEInference:
step=0.5, step=0.5,
value=5.0, value=5.0,
interactive=True) interactive=True)
self.infer_seed = gr.Slider(minimum=-1, self.infer_seed = gr.Slider(
maximum=10000000, minimum=-1, maximum=10000000, value=2025, label="Seed")
value=2025,
label="Seed")
# #
with gr.Accordion(label="Usable without source video", open=False): with gr.Accordion(label="Usable without source video", open=False):
with gr.Row(equal_height=True): with gr.Row(equal_height=True):
@ -176,13 +187,9 @@ class VACEInference:
value=1280, value=1280,
interactive=True) interactive=True)
self.frame_rate = gr.Textbox( self.frame_rate = gr.Textbox(
label='frame_rate', label='frame_rate', value=16, interactive=True)
value=16,
interactive=True)
self.num_frames = gr.Textbox( self.num_frames = gr.Textbox(
label='num_frames', label='num_frames', value=81, interactive=True)
value=81,
interactive=True)
# #
with gr.Row(equal_height=True): with gr.Row(equal_height=True):
with gr.Column(scale=5): with gr.Column(scale=5):
@ -201,14 +208,19 @@ class VACEInference:
allow_preview=True, allow_preview=True,
preview=True) preview=True)
def generate(self, output_gallery, src_video, src_mask, src_ref_image_1,
def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, src_ref_image_2, src_ref_image_3, prompt, negative_prompt, shift_scale, sample_steps, context_scale, guide_scale, infer_seed, output_height, output_width, frame_rate, num_frames): src_ref_image_2, src_ref_image_3, prompt, negative_prompt,
output_height, output_width, frame_rate, num_frames = int(output_height), int(output_width), int(frame_rate), int(num_frames) shift_scale, sample_steps, context_scale, guide_scale,
src_ref_images = [x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3] if infer_seed, output_height, output_width, frame_rate,
x is not None] num_frames):
src_video, src_mask, src_ref_images = self.pipe.prepare_source([src_video], output_height, output_width, frame_rate, num_frames = int(
[src_mask], output_height), int(output_width), int(frame_rate), int(num_frames)
[src_ref_images], src_ref_images = [
x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3]
if x is not None
]
src_video, src_mask, src_ref_images = self.pipe.prepare_source(
[src_video], [src_mask], [src_ref_images],
num_frames=num_frames, num_frames=num_frames,
image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"], image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"],
device=self.pipe.device) device=self.pipe.device)
@ -228,10 +240,17 @@ class VACEInference:
name = '{0:%Y%m%d%-H%M%S}'.format(datetime.datetime.now()) name = '{0:%Y%m%d%-H%M%S}'.format(datetime.datetime.now())
video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4') video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
video_frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8) video_frames = (
torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) *
255).cpu().numpy().astype(np.uint8)
try: try:
writer = imageio.get_writer(video_path, fps=frame_rate, codec='libx264', quality=8, macro_block_size=1) writer = imageio.get_writer(
video_path,
fps=frame_rate,
codec='libx264',
quality=8,
macro_block_size=1)
for frame in video_frames: for frame in video_frames:
writer.append_data(frame) writer.append_data(frame)
writer.close() writer.close()
@ -246,25 +265,57 @@ class VACEInference:
return [video_path] return [video_path]
def set_callbacks(self, **kwargs): def set_callbacks(self, **kwargs):
self.gen_inputs = [self.output_gallery, self.src_video, self.src_mask, self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3, self.prompt, self.negative_prompt, self.shift_scale, self.sample_steps, self.context_scale, self.guide_scale, self.infer_seed, self.output_height, self.output_width, self.frame_rate, self.num_frames] self.gen_inputs = [
self.output_gallery, self.src_video, self.src_mask,
self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3,
self.prompt, self.negative_prompt, self.shift_scale,
self.sample_steps, self.context_scale, self.guide_scale,
self.infer_seed, self.output_height, self.output_width,
self.frame_rate, self.num_frames
]
self.gen_outputs = [self.output_gallery] self.gen_outputs = [self.output_gallery]
self.generate_button.click(self.generate, self.generate_button.click(
self.generate,
inputs=self.gen_inputs, inputs=self.gen_inputs,
outputs=self.gen_outputs, outputs=self.gen_outputs,
queue=True) queue=True)
self.refresh_button.click(lambda x: self.gallery_share_data.get() if self.gallery_share else x, inputs=[self.output_gallery], outputs=[self.output_gallery]) self.refresh_button.click(
lambda x: self.gallery_share_data.get()
if self.gallery_share else x,
inputs=[self.output_gallery],
outputs=[self.output_gallery])
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Argparser for VACE-WAN Demo:\n') parser = argparse.ArgumentParser(
parser.add_argument('--server_port', dest='server_port', help='', type=int, default=7860) description='Argparser for VACE-WAN Demo:\n')
parser.add_argument('--server_name', dest='server_name', help='', default='0.0.0.0') parser.add_argument(
'--server_port', dest='server_port', help='', type=int, default=7860)
parser.add_argument(
'--server_name', dest='server_name', help='', default='0.0.0.0')
parser.add_argument('--root_path', dest='root_path', help='', default=None) parser.add_argument('--root_path', dest='root_path', help='', default=None)
parser.add_argument('--save_dir', dest='save_dir', help='', default='cache') parser.add_argument('--save_dir', dest='save_dir', help='', default='cache')
parser.add_argument("--mp", action="store_true", help="Use Multi-GPUs",) parser.add_argument(
parser.add_argument("--model_name", type=str, default="vace-14B", choices=list(WAN_CONFIGS.keys()), help="The model name to run.") "--mp",
parser.add_argument("--ulysses_size", type=int, default=1, help="The size of the ulysses parallelism in DiT.") action="store_true",
parser.add_argument("--ring_size", type=int, default=1, help="The size of the ring attention parallelism in DiT.") help="Use Multi-GPUs",
)
parser.add_argument(
"--model_name",
type=str,
default="vace-14B",
choices=list(WAN_CONFIGS.keys()),
help="The model name to run.")
parser.add_argument(
"--ulysses_size",
type=int,
default=1,
help="The size of the ulysses parallelism in DiT.")
parser.add_argument(
"--ring_size",
type=int,
default=1,
help="The size of the ring attention parallelism in DiT.")
parser.add_argument( parser.add_argument(
"--ckpt_dir", "--ckpt_dir",
type=str, type=str,
@ -284,12 +335,15 @@ if __name__ == '__main__':
os.makedirs(args.save_dir, exist_ok=True) os.makedirs(args.save_dir, exist_ok=True)
with gr.Blocks() as demo: with gr.Blocks() as demo:
infer_gr = VACEInference(args, skip_load=False, gallery_share=True, gallery_share_limit=5) infer_gr = VACEInference(
args, skip_load=False, gallery_share=True, gallery_share_limit=5)
infer_gr.create_ui() infer_gr.create_ui()
infer_gr.set_callbacks() infer_gr.set_callbacks()
allowed_paths = [args.save_dir] allowed_paths = [args.save_dir]
demo.queue(status_update_rate=1).launch(server_name=args.server_name, demo.queue(status_update_rate=1).launch(
server_name=args.server_name,
server_port=args.server_port, server_port=args.server_port,
root_path=args.root_path, root_path=args.root_path,
allowed_paths=allowed_paths, allowed_paths=allowed_paths,
show_error=True, debug=True) show_error=True,
debug=True)

View File

@@ -1,5 +1,5 @@
 from . import configs, distributed, modules
+from .first_last_frame2video import WanFLF2V
 from .image2video import WanI2V
 from .text2video import WanT2V
-from .first_last_frame2video import WanFLF2V
 from .vace import WanVace, WanVaceMP

View File

@@ -8,6 +8,7 @@ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
 from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
 from torch.distributed.utils import _free_storage
 def shard_model(
     model,
     device_id,
@@ -32,6 +33,7 @@ def shard_model(
         sync_module_states=sync_module_states)
     return model
 def free_model(model):
     for m in model.modules():
         if isinstance(m, FSDP):

View File

@@ -1,9 +1,11 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import torch
 import torch.cuda.amp as amp
-from xfuser.core.distributed import (get_sequence_parallel_rank,
-                                     get_sequence_parallel_world_size,
-                                     get_sp_group)
+from xfuser.core.distributed import (
+    get_sequence_parallel_rank,
+    get_sequence_parallel_world_size,
+    get_sp_group,
+)
 from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 from ..modules.model import sinusoidal_embedding_1d
@@ -63,19 +65,13 @@ def rope_apply(x, grid_sizes, freqs):
     return torch.stack(output).float()
-def usp_dit_forward_vace(
-    self,
-    x,
-    vace_context,
-    seq_len,
-    kwargs
-):
+def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs):
     # embeddings
     c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
     c = [u.flatten(2).transpose(1, 2) for u in c]
     c = torch.cat([
-        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
-                  dim=1) for u in c
+        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
+        for u in c
     ])
     # arguments

View File

@ -21,8 +21,11 @@ from .modules.clip import CLIPModel
from .modules.model import WanModel from .modules.model import WanModel
from .modules.t5 import T5EncoderModel from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, from .utils.fm_solvers import (
get_sampling_sigmas, retrieve_timesteps) FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
@ -103,11 +106,12 @@ class WanFLF2V:
init_on_cpu = False init_on_cpu = False
if use_usp: if use_usp:
from xfuser.core.distributed import \ from xfuser.core.distributed import get_sequence_parallel_world_size
get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward, from .distributed.xdit_context_parallel import (
usp_dit_forward) usp_attn_forward,
usp_dit_forward,
)
for block in self.model.blocks: for block in self.model.blocks:
block.self_attn.forward = types.MethodType( block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn) usp_attn_forward, block.self_attn)
@ -181,8 +185,10 @@ class WanFLF2V:
""" """
first_frame_size = first_frame.size first_frame_size = first_frame.size
last_frame_size = last_frame.size last_frame_size = last_frame.size
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(self.device) first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(self.device) self.device)
last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(
self.device)
F = frame_num F = frame_num
first_frame_h, first_frame_w = first_frame.shape[1:] first_frame_h, first_frame_w = first_frame.shape[1:]
@ -199,8 +205,7 @@ class WanFLF2V:
# 1. resize # 1. resize
last_frame_resize_ratio = max( last_frame_resize_ratio = max(
first_frame_size[0] / last_frame_size[0], first_frame_size[0] / last_frame_size[0],
first_frame_size[1] / last_frame_size[1] first_frame_size[1] / last_frame_size[1])
)
last_frame_size = [ last_frame_size = [
round(last_frame_size[0] * last_frame_resize_ratio), round(last_frame_size[0] * last_frame_resize_ratio),
round(last_frame_size[1] * last_frame_resize_ratio), round(last_frame_size[1] * last_frame_resize_ratio),
@ -216,8 +221,7 @@ class WanFLF2V:
seed_g = torch.Generator(device=self.device) seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed) seed_g.manual_seed(seed)
noise = torch.randn( noise = torch.randn(
16, 16, (F - 1) // 4 + 1,
(F - 1) // 4 + 1,
lat_h, lat_h,
lat_w, lat_w,
dtype=torch.float32, dtype=torch.float32,
@ -226,7 +230,10 @@ class WanFLF2V:
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device) msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
msk[:, 1:-1] = 0 msk[:, 1:-1] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) msk = torch.concat([
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
],
dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0] msk = msk.transpose(1, 2)[0]
@ -247,7 +254,8 @@ class WanFLF2V:
context_null = [t.to(self.device) for t in context_null] context_null = [t.to(self.device) for t in context_null]
self.clip.model.to(self.device) self.clip.model.to(self.device)
clip_context = self.clip.visual([first_frame[:, None, :, :], last_frame[:, None, :, :]]) clip_context = self.clip.visual(
[first_frame[:, None, :, :], last_frame[:, None, :, :]])
if offload_model: if offload_model:
self.clip.model.cpu() self.clip.model.cpu()
@ -256,15 +264,14 @@ class WanFLF2V:
torch.nn.functional.interpolate( torch.nn.functional.interpolate(
first_frame[None].cpu(), first_frame[None].cpu(),
size=(first_frame_h, first_frame_w), size=(first_frame_h, first_frame_w),
mode='bicubic' mode='bicubic').transpose(0, 1),
).transpose(0, 1),
torch.zeros(3, F - 2, first_frame_h, first_frame_w), torch.zeros(3, F - 2, first_frame_h, first_frame_w),
torch.nn.functional.interpolate( torch.nn.functional.interpolate(
last_frame[None].cpu(), last_frame[None].cpu(),
size=(first_frame_h, first_frame_w), size=(first_frame_h, first_frame_w),
mode='bicubic' mode='bicubic').transpose(0, 1),
).transpose(0, 1), ],
], dim=1).to(self.device) dim=1).to(self.device)
])[0] ])[0]
y = torch.concat([msk, y]) y = torch.concat([msk, y])

View File

@@ -21,8 +21,11 @@ from .modules.clip import CLIPModel
 from .modules.model import WanModel
 from .modules.t5 import T5EncoderModel
 from .modules.vae import WanVAE
-from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
-                               get_sampling_sigmas, retrieve_timesteps)
+from .utils.fm_solvers import (
+    FlowDPMSolverMultistepScheduler,
+    get_sampling_sigmas,
+    retrieve_timesteps,
+)
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
@@ -103,11 +106,12 @@ class WanI2V:
             init_on_cpu = False
         if use_usp:
-            from xfuser.core.distributed import \
-                get_sequence_parallel_world_size
+            from xfuser.core.distributed import get_sequence_parallel_world_size
-            from .distributed.xdit_context_parallel import (usp_attn_forward,
-                                                            usp_dit_forward)
+            from .distributed.xdit_context_parallel import (
+                usp_attn_forward,
+                usp_dit_forward,
+            )
             for block in self.model.blocks:
                 block.self_attn.forward = types.MethodType(
                     usp_attn_forward, block.self_attn)
@@ -196,8 +200,7 @@ class WanI2V:
         seed_g = torch.Generator(device=self.device)
         seed_g.manual_seed(seed)
         noise = torch.randn(
-            16,
-            (F - 1) // 4 + 1,
+            16, (F - 1) // 4 + 1,
             lat_h,
             lat_w,
             dtype=torch.float32,

View File

@@ -357,7 +357,8 @@ class MLPProj(torch.nn.Module):
             torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
             torch.nn.LayerNorm(out_dim))
         if flf_pos_emb:  # NOTE: we only use this for `flf2v`
-            self.emb_pos = nn.Parameter(torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))
+            self.emb_pos = nn.Parameter(
+                torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))
     def forward(self, image_embeds):
         if hasattr(self, 'emb_pos'):

View File

@ -3,12 +3,13 @@ import torch
import torch.cuda.amp as amp import torch.cuda.amp as amp
import torch.nn as nn import torch.nn as nn
from diffusers.configuration_utils import register_to_config from diffusers.configuration_utils import register_to_config
from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d
class VaceWanAttentionBlock(WanAttentionBlock): class VaceWanAttentionBlock(WanAttentionBlock):
def __init__(
self, def __init__(self,
cross_attn_type, cross_attn_type,
dim, dim,
ffn_dim, ffn_dim,
@ -17,9 +18,9 @@ class VaceWanAttentionBlock(WanAttentionBlock):
qk_norm=True, qk_norm=True,
cross_attn_norm=False, cross_attn_norm=False,
eps=1e-6, eps=1e-6,
block_id=0 block_id=0):
): super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size,
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) qk_norm, cross_attn_norm, eps)
self.block_id = block_id self.block_id = block_id
if block_id == 0: if block_id == 0:
self.before_proj = nn.Linear(self.dim, self.dim) self.before_proj = nn.Linear(self.dim, self.dim)
@ -39,8 +40,8 @@ class VaceWanAttentionBlock(WanAttentionBlock):
class BaseWanAttentionBlock(WanAttentionBlock): class BaseWanAttentionBlock(WanAttentionBlock):
def __init__(
self, def __init__(self,
cross_attn_type, cross_attn_type,
dim, dim,
ffn_dim, ffn_dim,
@ -49,9 +50,9 @@ class BaseWanAttentionBlock(WanAttentionBlock):
qk_norm=True, qk_norm=True,
cross_attn_norm=False, cross_attn_norm=False,
eps=1e-6, eps=1e-6,
block_id=None block_id=None):
): super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size,
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) qk_norm, cross_attn_norm, eps)
self.block_id = block_id self.block_id = block_id
def forward(self, x, hints, context_scale=1.0, **kwargs): def forward(self, x, hints, context_scale=1.0, **kwargs):
@ -62,6 +63,7 @@ class BaseWanAttentionBlock(WanAttentionBlock):
class VaceWanModel(WanModel): class VaceWanModel(WanModel):
@register_to_config @register_to_config
def __init__(self, def __init__(self,
vace_layers=None, vace_layers=None,
@ -81,42 +83,57 @@ class VaceWanModel(WanModel):
qk_norm=True, qk_norm=True,
cross_attn_norm=True, cross_attn_norm=True,
eps=1e-6): eps=1e-6):
super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim, super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim,
num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps) freq_dim, text_dim, out_dim, num_heads, num_layers,
window_size, qk_norm, cross_attn_norm, eps)
self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers self.vace_layers = [i for i in range(0, self.num_layers, 2)
] if vace_layers is None else vace_layers
self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
assert 0 in self.vace_layers assert 0 in self.vace_layers
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)} self.vace_layers_mapping = {
i: n for n, i in enumerate(self.vace_layers)
}
# blocks # blocks
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([
BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, BaseWanAttentionBlock(
self.cross_attn_norm, self.eps, 't2v_cross_attn',
block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None) self.dim,
self.ffn_dim,
self.num_heads,
self.window_size,
self.qk_norm,
self.cross_attn_norm,
self.eps,
block_id=self.vace_layers_mapping[i]
if i in self.vace_layers else None)
for i in range(self.num_layers) for i in range(self.num_layers)
]) ])
# vace blocks # vace blocks
self.vace_blocks = nn.ModuleList([ self.vace_blocks = nn.ModuleList([
VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, VaceWanAttentionBlock(
self.cross_attn_norm, self.eps, block_id=i) 't2v_cross_attn',
for i in self.vace_layers self.dim,
self.ffn_dim,
self.num_heads,
self.window_size,
self.qk_norm,
self.cross_attn_norm,
self.eps,
block_id=i) for i in self.vace_layers
]) ])
# vace patch embeddings # vace patch embeddings
self.vace_patch_embedding = nn.Conv3d( self.vace_patch_embedding = nn.Conv3d(
self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size self.vace_in_dim,
) self.dim,
kernel_size=self.patch_size,
stride=self.patch_size)
def forward_vace( def forward_vace(self, x, vace_context, seq_len, kwargs):
self,
x,
vace_context,
seq_len,
kwargs
):
# embeddings # embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c] c = [u.flatten(2).transpose(1, 2) for u in c]

View File

@@ -18,8 +18,11 @@ from .distributed.fsdp import shard_model
 from .modules.model import WanModel
 from .modules.t5 import T5EncoderModel
 from .modules.vae import WanVAE
-from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
-                               get_sampling_sigmas, retrieve_timesteps)
+from .utils.fm_solvers import (
+    FlowDPMSolverMultistepScheduler,
+    get_sampling_sigmas,
+    retrieve_timesteps,
+)
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
@@ -85,11 +88,12 @@ class WanT2V:
         self.model.eval().requires_grad_(False)
         if use_usp:
-            from xfuser.core.distributed import \
-                get_sequence_parallel_world_size
+            from xfuser.core.distributed import get_sequence_parallel_world_size
-            from .distributed.xdit_context_parallel import (usp_attn_forward,
-                                                            usp_dit_forward)
+            from .distributed.xdit_context_parallel import (
+                usp_attn_forward,
+                usp_dit_forward,
+            )
             for block in self.model.blocks:
                 block.self_attn.forward = types.MethodType(
                     usp_attn_forward, block.self_attn)

View File

@@ -1,5 +1,8 @@
-from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
-                         retrieve_timesteps)
+from .fm_solvers import (
+    FlowDPMSolverMultistepScheduler,
+    get_sampling_sigmas,
+    retrieve_timesteps,
+)
 from .fm_solvers_unipc import FlowUniPCMultistepScheduler
 from .vace_processor import VaceVideoProcessor

View File

@@ -9,9 +9,11 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import torch
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
-                                                    SchedulerMixin,
-                                                    SchedulerOutput)
+from diffusers.schedulers.scheduling_utils import (
+    KarrasDiffusionSchedulers,
+    SchedulerMixin,
+    SchedulerOutput,
+)
 from diffusers.utils import deprecate, is_scipy_available
 from diffusers.utils.torch_utils import randn_tensor

View File

@@ -8,9 +8,11 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import torch
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
-                                                    SchedulerMixin,
-                                                    SchedulerOutput)
+from diffusers.schedulers.scheduling_utils import (
+    KarrasDiffusionSchedulers,
+    SchedulerMixin,
+    SchedulerOutput,
+)
 from diffusers.utils import deprecate, is_scipy_available
 if is_scipy_available():

View File

@@ -7,7 +7,7 @@ import sys
 import tempfile
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Optional, Union, List
+from typing import List, Optional, Union
 
 import dashscope
 import torch
@@ -96,7 +96,6 @@ VL_EN_SYS_PROMPT = \
     '''4. In the style of an American drama promotional poster, Walter White sits in a metal folding chair wearing a yellow protective suit, with the words "Breaking Bad" written in sans-serif English above him, surrounded by piles of dollar bills and blue plastic storage boxes. He wears glasses, staring forward, dressed in a yellow jumpsuit, with his hands resting on his knees, exuding a calm and confident demeanor. The background shows an abandoned, dim factory with light filtering through the windows. There's a noticeable grainy texture. A medium shot with a straight-on close-up of the character.\n''' \
     '''Directly output the rewritten English text.'''
-
 
 VL_ZH_SYS_PROMPT_FOR_MULTI_IMAGES = """你是一位Prompt优化师旨在参考用户输入的图像的细节内容把用户输入的Prompt改写为优质Prompt使其更完整、更具表现力同时不改变原意。你需要综合用户输入的照片内容和输入的Prompt进行改写严格参考示例的格式进行改写
 任务要求
 1. 用户会输入两张图片第一张是视频的第一帧第二张时视频的最后一帧你需要综合两个照片的内容进行优化改写
@@ -198,8 +197,8 @@ class PromptExpander:
         if system_prompt is None:
             system_prompt = self.decide_system_prompt(
                 tar_lang=tar_lang,
-                multi_images_input=isinstance(image, (list, tuple)) and len(image) > 1
-            )
+                multi_images_input=isinstance(image, (list, tuple)) and
+                len(image) > 1)
         if seed < 0:
             seed = random.randint(0, sys.maxsize)
         if image is not None and self.is_vl:
@@ -289,7 +288,8 @@ class DashScopePromptExpander(PromptExpander):
     def extend_with_img(self,
                         prompt,
                         system_prompt,
-                        image: Union[List[Image.Image], List[str], Image.Image, str] = None,
+                        image: Union[List[Image.Image], List[str], Image.Image,
+                                     str] = None,
                         seed=-1,
                         *args,
                         **kwargs):
@@ -308,13 +308,15 @@ class DashScopePromptExpander(PromptExpander):
                 _image.save(f.name)
                 image_path = f"file://{f.name}"
             return image_path
 
         if not isinstance(image, (list, tuple)):
             image = [image]
         image_path_list = [ensure_image(_image) for _image in image]
-        role_content = [
-            {"text": prompt},
-            *[{"image": image_path} for image_path in image_path_list]
-        ]
+        role_content = [{
+            "text": prompt
+        }, *[{
+            "image": image_path
+        } for image_path in image_path_list]]
         system_content = [{"text": system_prompt}]
         prompt = f"{prompt}"
         messages = [
@@ -393,8 +395,11 @@ class QwenPromptExpander(PromptExpander):
         if self.is_vl:
             # default: Load the model on the available device(s)
-            from transformers import (AutoProcessor, AutoTokenizer,
-                                       Qwen2_5_VLForConditionalGeneration)
+            from transformers import (
+                AutoProcessor,
+                AutoTokenizer,
+                Qwen2_5_VLForConditionalGeneration,
+            )
             try:
                 from .qwen_vl_utils import process_vision_info
             except:
@@ -459,7 +464,8 @@ class QwenPromptExpander(PromptExpander):
     def extend_with_img(self,
                         prompt,
                         system_prompt,
-                        image: Union[List[Image.Image], List[str], Image.Image, str] = None,
+                        image: Union[List[Image.Image], List[str], Image.Image,
+                                     str] = None,
                         seed=-1,
                         *args,
                         **kwargs):
@@ -468,26 +474,19 @@ class QwenPromptExpander(PromptExpander):
         if not isinstance(image, (list, tuple)):
             image = [image]
-        system_content = [{
-            "type": "text",
-            "text": system_prompt
-        }]
-        role_content = [
-            {
-                "type": "text",
-                "text": prompt
-            },
-            *[
-                {"image": image_path} for image_path in image
-            ]
-        ]
+        system_content = [{"type": "text", "text": system_prompt}]
+        role_content = [{
+            "type": "text",
+            "text": prompt
+        }, *[{
+            "image": image_path
+        } for image_path in image]]
         messages = [{
             'role': 'system',
             'content': system_content,
         }, {
-            "role":
-            "user",
+            "role": "user",
             "content": role_content,
         }]
@@ -611,25 +610,38 @@ if __name__ == "__main__":
     print("VL qwen vl en result -> en",
           qwen_result.prompt)  # , qwen_result.system_prompt)
     # test multi images
-    image = ["./examples/flf2v_input_first_frame.png", "./examples/flf2v_input_last_frame.png"]
+    image = [
+        "./examples/flf2v_input_first_frame.png",
+        "./examples/flf2v_input_last_frame.png"
+    ]
     prompt = "无人机拍摄，镜头快速推进，然后拉远至全景俯瞰，展示一个宁静美丽的海港。海港内停满了游艇，水面清澈透蓝。周围是起伏的山丘和错落有致的建筑，整体景色宁静而美丽。"
-    en_prompt = ("Shot from a drone perspective, the camera rapidly zooms in before pulling back to reveal a panoramic "
-                 "aerial view of a serene and picturesque harbor. The tranquil bay is dotted with numerous yachts "
-                 "resting on crystal-clear blue waters. Surrounding the harbor are rolling hills and well-spaced "
-                 "architectural structures, combining to create a tranquil and breathtaking coastal landscape.")
+    en_prompt = (
+        "Shot from a drone perspective, the camera rapidly zooms in before pulling back to reveal a panoramic "
+        "aerial view of a serene and picturesque harbor. The tranquil bay is dotted with numerous yachts "
+        "resting on crystal-clear blue waters. Surrounding the harbor are rolling hills and well-spaced "
+        "architectural structures, combining to create a tranquil and breathtaking coastal landscape."
+    )
-    dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True)
-    dashscope_result = dashscope_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
+    dashscope_prompt_expander = DashScopePromptExpander(
+        model_name=ds_model_name, is_vl=True)
+    dashscope_result = dashscope_prompt_expander(
+        prompt, tar_lang="zh", image=image, seed=seed)
     print("VL dashscope result -> zh", dashscope_result.prompt)
-    dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True)
-    dashscope_result = dashscope_prompt_expander(en_prompt, tar_lang="zh", image=image, seed=seed)
+    dashscope_prompt_expander = DashScopePromptExpander(
+        model_name=ds_model_name, is_vl=True)
+    dashscope_result = dashscope_prompt_expander(
+        en_prompt, tar_lang="zh", image=image, seed=seed)
     print("VL dashscope en result -> zh", dashscope_result.prompt)
-    qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0)
-    qwen_result = qwen_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
+    qwen_prompt_expander = QwenPromptExpander(
+        model_name=qwen_model_name, is_vl=True, device=0)
+    qwen_result = qwen_prompt_expander(
+        prompt, tar_lang="zh", image=image, seed=seed)
     print("VL qwen result -> zh", qwen_result.prompt)
-    qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0)
-    qwen_result = qwen_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
+    qwen_prompt_expander = QwenPromptExpander(
+        model_name=qwen_model_name, is_vl=True, device=0)
+    qwen_result = qwen_prompt_expander(
+        prompt, tar_lang="zh", image=image, seed=seed)
     print("VL qwen en result -> zh", qwen_result.prompt)

View File

@@ -1,12 +1,13 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import numpy as np
-from PIL import Image
 import torch
 import torch.nn.functional as F
 import torchvision.transforms.functional as TF
+from PIL import Image
 
 
 class VaceImageProcessor(object):
+
     def __init__(self, downsample=None, seq_len=None):
         self.downsample = downsample
         self.seq_len = seq_len
@@ -16,7 +17,8 @@ class VaceImageProcessor(object):
         if image.mode == 'P':
             image = image.convert(f'{cvt_type}A')
         if image.mode == f'{cvt_type}A':
-            bg = Image.new(cvt_type,
-                           size=(image.width, image.height),
-                           color=(255, 255, 255))
+            bg = Image.new(
+                cvt_type,
+                size=(image.width, image.height),
+                color=(255, 255, 255))
             bg.paste(image, (0, 0), mask=image)
@@ -41,10 +43,8 @@ class VaceImageProcessor(object):
         if iw != ow or ih != oh:
             # resize
             scale = max(ow / iw, oh / ih)
-            img = img.resize(
-                (round(scale * iw), round(scale * ih)),
-                resample=Image.Resampling.LANCZOS
-            )
+            img = img.resize((round(scale * iw), round(scale * ih)),
+                             resample=Image.Resampling.LANCZOS)
             assert img.width >= ow and img.height >= oh
 
             # center crop
@@ -66,7 +66,11 @@ class VaceImageProcessor(object):
     def load_image_pair(self, data_key, data_key2, **kwargs):
         return self.load_image_batch(data_key, data_key2, **kwargs)
 
-    def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs):
+    def load_image_batch(self,
+                         *data_key_batch,
+                         normalize=True,
+                         seq_len=None,
+                         **kwargs):
         seq_len = self.seq_len if seq_len is None else seq_len
         imgs = []
         for data_key in data_key_batch:
@@ -85,7 +89,9 @@ class VaceImageProcessor(object):
 
 
 class VaceVideoProcessor(object):
-    def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs):
+
+    def __init__(self, downsample, min_area, max_area, min_fps, max_fps,
+                 zero_start, seq_len, keep_last, **kwargs):
         self.downsample = downsample
         self.min_area = min_area
         self.max_area = max_area
@@ -130,8 +136,7 @@ class VaceVideoProcessor(object):
                 video,
                 size=(round(scale * ih), round(scale * iw)),
                 mode='bicubic',
-                antialias=True
-            )
+                antialias=True)
             assert video.size(3) >= ow and video.size(2) >= oh
 
             # center crop
@@ -146,7 +151,8 @@ class VaceVideoProcessor(object):
     def _video_preprocess(self, video, oh, ow):
         return self.resize_crop(video, oh, ow)
 
-    def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng):
+    def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box,
+                                  rng):
         target_fps = min(fps, self.max_fps)
         duration = frame_timestamps[-1].mean()
         x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
@@ -154,11 +160,10 @@ class VaceVideoProcessor(object):
         ratio = h / w
         df, dh, dw = self.downsample
-        area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
-        of = min(
-            (int(duration * target_fps) - 1) // df + 1,
-            int(self.seq_len / area_z)
-        )
+        area_z = min(self.seq_len, self.max_area / (dh * dw),
+                     (h // dh) * (w // dw))
+        of = min((int(duration * target_fps) - 1) // df + 1,
+                 int(self.seq_len / area_z))
 
         # deduce target shape of the [latent video]
         target_area_z = min(area_z, int(self.seq_len / of))
@@ -170,26 +175,27 @@ class VaceVideoProcessor(object):
         # sample frame ids
         target_duration = of / target_fps
-        begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration)
+        begin = 0. if self.zero_start else rng.uniform(
+            0, duration - target_duration)
         timestamps = np.linspace(begin, begin + target_duration, of)
-        frame_ids = np.argmax(np.logical_and(
-            timestamps[:, None] >= frame_timestamps[None, :, 0],
-            timestamps[:, None] < frame_timestamps[None, :, 1]
-        ), axis=1).tolist()
+        frame_ids = np.argmax(
+            np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
+                           timestamps[:, None] < frame_timestamps[None, :, 1]),
+            axis=1).tolist()
         return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
 
-    def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng):
+    def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w,
+                                      crop_box, rng):
         duration = frame_timestamps[-1].mean()
         x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
         h, w = y2 - y1, x2 - x1
         ratio = h / w
         df, dh, dw = self.downsample
-        area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
-        of = min(
-            (len(frame_timestamps) - 1) // df + 1,
-            int(self.seq_len / area_z)
-        )
+        area_z = min(self.seq_len, self.max_area / (dh * dw),
+                     (h // dh) * (w // dw))
+        of = min((len(frame_timestamps) - 1) // df + 1,
+                 int(self.seq_len / area_z))
 
         # deduce target shape of the [latent video]
         target_area_z = min(area_z, int(self.seq_len / of))
@@ -203,27 +209,39 @@ class VaceVideoProcessor(object):
         target_duration = duration
         target_fps = of / target_duration
         timestamps = np.linspace(0., target_duration, of)
-        frame_ids = np.argmax(np.logical_and(
-            timestamps[:, None] >= frame_timestamps[None, :, 0],
-            timestamps[:, None] <= frame_timestamps[None, :, 1]
-        ), axis=1).tolist()
+        frame_ids = np.argmax(
+            np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
+                           timestamps[:, None] <= frame_timestamps[None, :, 1]),
+            axis=1).tolist()
         # print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids))
         return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
 
     def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
         if self.keep_last:
-            return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng)
+            return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h,
+                                                      w, crop_box, rng)
         else:
-            return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng)
+            return self._get_frameid_bbox_default(fps, frame_timestamps, h, w,
+                                                  crop_box, rng)
 
     def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
-        return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
+        return self.load_video_batch(
+            data_key, crop_box=crop_box, seed=seed, **kwargs)
 
-    def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
-        return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
+    def load_video_pair(self,
+                        data_key,
+                        data_key2,
+                        crop_box=None,
+                        seed=2024,
+                        **kwargs):
+        return self.load_video_batch(
+            data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
 
-    def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs):
+    def load_video_batch(self,
+                         *data_key_batch,
+                         crop_box=None,
+                         seed=2024,
+                         **kwargs):
         rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
         # read video
         import decord
@@ -235,36 +253,53 @@ class VaceVideoProcessor(object):
         fps = readers[0].get_avg_fps()
         length = min([len(r) for r in readers])
-        frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
+        frame_timestamps = [
+            readers[0].get_frame_timestamp(i) for i in range(length)
+        ]
         frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
         h, w = readers[0].next().shape[:2]
-        frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng)
+        frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(
+            fps, frame_timestamps, h, w, crop_box, rng)
 
         # preprocess video
-        videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
+        videos = [
+            reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :]
+            for reader in readers
+        ]
         videos = [self._video_preprocess(video, oh, ow) for video in videos]
         return *videos, frame_ids, (oh, ow), fps
         # return videos if len(videos) > 1 else videos[0]
 
 
-def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
+def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size,
+                   device):
     for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
         if sub_src_video is None and sub_src_mask is None:
-            src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
-            src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device)
+            src_video[i] = torch.zeros(
+                (3, num_frames, image_size[0], image_size[1]), device=device)
+            src_mask[i] = torch.ones(
+                (1, num_frames, image_size[0], image_size[1]), device=device)
    for i, ref_images in enumerate(src_ref_images):
        if ref_images is not None:
            for j, ref_img in enumerate(ref_images):
                if ref_img is not None and ref_img.shape[-2:] != image_size:
                    canvas_height, canvas_width = image_size
                    ref_height, ref_width = ref_img.shape[-2:]
-                    white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device)  # [-1, 1]
-                    scale = min(canvas_height / ref_height, canvas_width / ref_width)
+                    white_canvas = torch.ones(
+                        (3, 1, canvas_height, canvas_width),
+                        device=device)  # [-1, 1]
+                    scale = min(canvas_height / ref_height,
+                                canvas_width / ref_width)
                    new_height = int(ref_height * scale)
                    new_width = int(ref_width * scale)
-                    resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
+                    resized_image = F.interpolate(
+                        ref_img.squeeze(1).unsqueeze(0),
+                        size=(new_height, new_width),
+                        mode='bilinear',
+                        align_corners=False).squeeze(0).unsqueeze(1)
                    top = (canvas_height - new_height) // 2
                    left = (canvas_width - new_width) // 2
-                    white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
+                    white_canvas[:, :, top:top + new_height,
+                                 left:left + new_width] = resized_image
                    src_ref_images[i][j] = white_canvas
    return src_video, src_mask, src_ref_images
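The prepare_source hunk above only re-wraps long lines; the logic is unchanged. Because the reference-image branch is now split across many continuation lines, here is a condensed, self-contained sketch of that letterboxing step (written for this note as an illustration, not code taken from the commit): the reference image is scaled to fit the target canvas and pasted, centered, onto a white background in the [-1, 1] value range.

import torch
import torch.nn.functional as F


def letterbox_reference(ref_img: torch.Tensor, canvas_height: int,
                        canvas_width: int) -> torch.Tensor:
    # ref_img: (3, 1, H, W) tensor with values in [-1, 1].
    ref_height, ref_width = ref_img.shape[-2:]
    white_canvas = torch.ones((3, 1, canvas_height, canvas_width))  # white
    # Scale so the image fits inside the canvas while keeping its aspect ratio.
    scale = min(canvas_height / ref_height, canvas_width / ref_width)
    new_height, new_width = int(ref_height * scale), int(ref_width * scale)
    resized = F.interpolate(
        ref_img.squeeze(1).unsqueeze(0),  # (1, 3, H, W) layout for interpolate
        size=(new_height, new_width),
        mode='bilinear',
        align_corners=False).squeeze(0).unsqueeze(1)  # back to (3, 1, h, w)
    # Paste the resized image into the center of the white canvas.
    top = (canvas_height - new_height) // 2
    left = (canvas_width - new_width) // 2
    white_canvas[:, :, top:top + new_height, left:left + new_width] = resized
    return white_canvas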

View File

@@ -1,32 +1,41 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
-import os
-import sys
 import gc
-import math
-import time
-import random
-import types
 import logging
+import math
+import os
+import random
+import sys
+import time
 import traceback
+import types
 from contextlib import contextmanager
 from functools import partial
 
-from PIL import Image
-import torchvision.transforms.functional as TF
 import torch
-import torch.nn.functional as F
 import torch.cuda.amp as amp
 import torch.distributed as dist
 import torch.multiprocessing as mp
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+from PIL import Image
 from tqdm import tqdm
 
-from .text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler,
-                         get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler)
 from .modules.vace_model import VaceWanModel
+from .text2video import (
+    FlowDPMSolverMultistepScheduler,
+    FlowUniPCMultistepScheduler,
+    T5EncoderModel,
+    WanT2V,
+    WanVAE,
+    get_sampling_sigmas,
+    retrieve_timesteps,
+    shard_model,
+)
 from .utils.vace_processor import VaceVideoProcessor
 
 
 class WanVace(WanT2V):
+
     def __init__(
         self,
         config,
@@ -87,12 +96,13 @@ class WanVace(WanT2V):
         self.model.eval().requires_grad_(False)
 
         if use_usp:
-            from xfuser.core.distributed import \
-                get_sequence_parallel_world_size
-            from .distributed.xdit_context_parallel import (usp_attn_forward,
-                                                            usp_dit_forward,
-                                                            usp_dit_forward_vace)
+            from xfuser.core.distributed import get_sequence_parallel_world_size
+
+            from .distributed.xdit_context_parallel import (
+                usp_attn_forward,
+                usp_dit_forward,
+                usp_dit_forward_vace,
+            )
             for block in self.model.blocks:
                 block.self_attn.forward = types.MethodType(
                     usp_attn_forward, block.self_attn)
@@ -100,7 +110,8 @@ class WanVace(WanT2V):
                 block.self_attn.forward = types.MethodType(
                     usp_attn_forward, block.self_attn)
             self.model.forward = types.MethodType(usp_dit_forward, self.model)
-            self.model.forward_vace = types.MethodType(usp_dit_forward_vace, self.model)
+            self.model.forward_vace = types.MethodType(usp_dit_forward_vace,
+                                                       self.model)
             self.sp_size = get_sequence_parallel_world_size()
         else:
             self.sp_size = 1
@@ -114,7 +125,9 @@ class WanVace(WanT2V):
         self.sample_neg_prompt = config.sample_neg_prompt
 
-        self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
+        self.vid_proc = VaceVideoProcessor(
+            downsample=tuple(
+                [x * y for x, y in zip(config.vae_stride, self.patch_size)]),
             min_area=720 * 1280,
             max_area=720 * 1280,
             min_fps=config.sample_fps,
@@ -138,7 +151,9 @@ class WanVace(WanT2V):
         reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
         inactive = vae.encode(inactive)
         reactive = vae.encode(reactive)
-        latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
+        latents = [
+            torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)
+        ]
 
         cat_latents = []
         for latent, refs in zip(latents, ref_images):
@@ -147,7 +162,10 @@ class WanVace(WanT2V):
                     ref_latent = vae.encode(refs)
                 else:
                     ref_latent = vae.encode(refs)
-                ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
+                ref_latent = [
+                    torch.cat((u, torch.zeros_like(u)), dim=0)
+                    for u in ref_latent
+                ]
                 assert all([x.shape[1] == 1 for x in ref_latent])
                 latent = torch.cat([*ref_latent, latent], dim=1)
             cat_latents.append(latent)
@@ -169,16 +187,17 @@ class WanVace(WanT2V):
             # reshape
             mask = mask[0, :, :, :]
-            mask = mask.view(
-                depth, height, vae_stride[1], width, vae_stride[1]
-            )  # depth, height, 8, width, 8
+            mask = mask.view(depth, height, vae_stride[1], width,
+                             vae_stride[1])  # depth, height, 8, width, 8
             mask = mask.permute(2, 4, 0, 1, 3)  # 8, 8, depth, height, width
-            mask = mask.reshape(
-                vae_stride[1] * vae_stride[2], depth, height, width
-            )  # 8*8, depth, height, width
+            mask = mask.reshape(vae_stride[1] * vae_stride[2], depth, height,
+                                width)  # 8*8, depth, height, width
 
             # interpolation
-            mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
+            mask = F.interpolate(
+                mask.unsqueeze(0),
+                size=(new_depth, height, width),
+                mode='nearest-exact').squeeze(0)
 
             if refs is not None:
                 length = len(refs)
@@ -190,7 +209,8 @@ class WanVace(WanT2V):
     def vace_latent(self, z, m):
         return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
 
-    def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device):
+    def prepare_source(self, src_video, src_mask, src_ref_images, num_frames,
+                       image_size, device):
         area = image_size[0] * image_size[1]
         self.vid_proc.set_area(area)
         if area == 720 * 1280:
@@ -198,19 +218,26 @@ class WanVace(WanT2V):
         elif area == 480 * 832:
             self.vid_proc.set_seq_len(32760)
         else:
-            raise NotImplementedError(f'image_size {image_size} is not supported')
+            raise NotImplementedError(
+                f'image_size {image_size} is not supported')
 
         image_size = (image_size[1], image_size[0])
         image_sizes = []
-        for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
+        for i, (sub_src_video,
+                sub_src_mask) in enumerate(zip(src_video, src_mask)):
             if sub_src_mask is not None and sub_src_video is not None:
-                src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask)
+                src_video[i], src_mask[
+                    i], _, _, _ = self.vid_proc.load_video_pair(
+                        sub_src_video, sub_src_mask)
                 src_video[i] = src_video[i].to(device)
                 src_mask[i] = src_mask[i].to(device)
-                src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
+                src_mask[i] = torch.clamp(
+                    (src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
                 image_sizes.append(src_video[i].shape[2:])
             elif sub_src_video is None:
-                src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
+                src_video[i] = torch.zeros(
+                    (3, num_frames, image_size[0], image_size[1]),
+                    device=device)
                 src_mask[i] = torch.ones_like(src_video[i], device=device)
                 image_sizes.append(image_size)
             else:
@@ -225,18 +252,27 @@ class WanVace(WanT2V):
                 for j, ref_img in enumerate(ref_images):
                     if ref_img is not None:
                         ref_img = Image.open(ref_img).convert("RGB")
-                        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
+                        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(
+                            0.5).unsqueeze(1)
                         if ref_img.shape[-2:] != image_size:
                             canvas_height, canvas_width = image_size
                             ref_height, ref_width = ref_img.shape[-2:]
-                            white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device)  # [-1, 1]
-                            scale = min(canvas_height / ref_height, canvas_width / ref_width)
+                            white_canvas = torch.ones(
+                                (3, 1, canvas_height, canvas_width),
+                                device=device)  # [-1, 1]
+                            scale = min(canvas_height / ref_height,
+                                        canvas_width / ref_width)
                             new_height = int(ref_height * scale)
                             new_width = int(ref_width * scale)
-                            resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
+                            resized_image = F.interpolate(
+                                ref_img.squeeze(1).unsqueeze(0),
+                                size=(new_height, new_width),
+                                mode='bilinear',
+                                align_corners=False).squeeze(0).unsqueeze(1)
                             top = (canvas_height - new_height) // 2
                             left = (canvas_width - new_width) // 2
-                            white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
+                            white_canvas[:, :, top:top + new_height,
+                                         left:left + new_width] = resized_image
                             ref_img = white_canvas
                         src_ref_images[i][j] = ref_img.to(device)
         return src_video, src_mask, src_ref_images
@@ -256,8 +292,6 @@ class WanVace(WanT2V):
         return vae.decode(trimed_zs)
 
-
-
     def generate(self,
                  input_prompt,
                  input_frames,
@@ -335,7 +369,8 @@ class WanVace(WanT2V):
             context_null = [t.to(self.device) for t in context_null]
 
         # vace context encode
-        z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks)
+        z0 = self.vace_encode_frames(
+            input_frames, input_ref_images, masks=input_masks)
         m0 = self.vace_encode_masks(input_masks, input_ref_images)
         z = self.vace_latent(z0, m0)
@@ -399,9 +434,17 @@ class WanVace(WanT2V):
                 self.model.to(self.device)
                 noise_pred_cond = self.model(
-                    latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[0]
+                    latent_model_input,
+                    t=timestep,
+                    vace_context=z,
+                    vace_context_scale=context_scale,
+                    **arg_c)[0]
                 noise_pred_uncond = self.model(
-                    latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,**arg_null)[0]
+                    latent_model_input,
+                    t=timestep,
+                    vace_context=z,
+                    vace_context_scale=context_scale,
+                    **arg_null)[0]
                 noise_pred = noise_pred_uncond + guide_scale * (
                     noise_pred_cond - noise_pred_uncond)
@@ -433,14 +476,13 @@ class WanVace(WanT2V):
 class WanVaceMP(WanVace):
 
-    def __init__(
-        self,
-        config,
-        checkpoint_dir,
-        use_usp=False,
-        ulysses_size=None,
-        ring_size=None
-    ):
+    def __init__(self,
+                 config,
+                 checkpoint_dir,
+                 use_usp=False,
+                 ulysses_size=None,
+                 ring_size=None):
         self.config = config
         self.checkpoint_dir = checkpoint_dir
         self.use_usp = use_usp
@@ -457,7 +499,8 @@ class WanVaceMP(WanVace):
         self.device = 'cpu' if torch.cuda.is_available() else 'cpu'
 
         self.vid_proc = VaceVideoProcessor(
-            downsample=tuple([x * y for x, y in zip(config.vae_stride, config.patch_size)]),
+            downsample=tuple(
+                [x * y for x, y in zip(config.vae_stride, config.patch_size)]),
             min_area=480 * 832,
             max_area=480 * 832,
             min_fps=self.config.sample_fps,
@@ -466,20 +509,30 @@ class WanVaceMP(WanVace):
             seq_len=32760,
             keep_last=True)
 
     def dynamic_load(self):
         if hasattr(self, 'inference_pids') and self.inference_pids is not None:
             return
-        gpu_infer = os.environ.get('LOCAL_WORLD_SIZE') or torch.cuda.device_count()
+        gpu_infer = os.environ.get(
+            'LOCAL_WORLD_SIZE') or torch.cuda.device_count()
         pmi_rank = int(os.environ['RANK'])
         pmi_world_size = int(os.environ['WORLD_SIZE'])
-        in_q_list = [torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)]
+        in_q_list = [
+            torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)
+        ]
         out_q = torch.multiprocessing.Manager().Queue()
-        initialized_events = [torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)]
-        context = mp.spawn(self.mp_worker, nprocs=gpu_infer, args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, self), join=False)
+        initialized_events = [
+            torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)
+        ]
+        context = mp.spawn(
+            self.mp_worker,
+            nprocs=gpu_infer,
+            args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q,
+                  initialized_events, self),
+            join=False)
         all_initialized = False
         while not all_initialized:
-            all_initialized = all(event.is_set() for event in initialized_events)
+            all_initialized = all(
+                event.is_set() for event in initialized_events)
             if not all_initialized:
                 time.sleep(0.1)
         print('Inference model is initialized', flush=True)
@@ -495,12 +548,19 @@ class WanVaceMP(WanVace):
         if isinstance(data, torch.Tensor):
             data = data.to(device)
         elif isinstance(data, list):
-            data = [self.transfer_data_to_cuda(subdata, device) for subdata in data]
+            data = [
+                self.transfer_data_to_cuda(subdata, device)
+                for subdata in data
+            ]
         elif isinstance(data, dict):
-            data = {key: self.transfer_data_to_cuda(val, device) for key, val in data.items()}
+            data = {
+                key: self.transfer_data_to_cuda(val, device)
+                for key, val in data.items()
+            }
         return data
 
-    def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, work_env):
+    def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list,
+                  out_q, initialized_events, work_env):
         try:
             world_size = pmi_world_size * gpu_infer
             rank = pmi_rank * gpu_infer + gpu
@@ -511,19 +571,19 @@ class WanVaceMP(WanVace):
                 backend='nccl',
                 init_method='env://',
                 rank=rank,
-                world_size=world_size
-            )
+                world_size=world_size)
 
-            from xfuser.core.distributed import (initialize_model_parallel,
-                                                 init_distributed_environment)
+            from xfuser.core.distributed import (
+                init_distributed_environment,
+                initialize_model_parallel,
+            )
             init_distributed_environment(
                 rank=dist.get_rank(), world_size=dist.get_world_size())
             initialize_model_parallel(
                 sequence_parallel_degree=dist.get_world_size(),
                 ring_degree=self.ring_size or 1,
-                ulysses_degree=self.ulysses_size or 1
-            )
+                ulysses_degree=self.ulysses_size or 1)
 
             num_train_timesteps = self.config.num_train_timesteps
             param_dtype = self.config.param_dtype
@@ -532,14 +592,17 @@ class WanVaceMP(WanVace):
                 text_len=self.config.text_len,
                 dtype=self.config.t5_dtype,
                 device=torch.device('cpu'),
-                checkpoint_path=os.path.join(self.checkpoint_dir, self.config.t5_checkpoint),
-                tokenizer_path=os.path.join(self.checkpoint_dir, self.config.t5_tokenizer),
+                checkpoint_path=os.path.join(self.checkpoint_dir,
+                                             self.config.t5_checkpoint),
+                tokenizer_path=os.path.join(self.checkpoint_dir,
+                                            self.config.t5_tokenizer),
                 shard_fn=shard_fn if True else None)
             text_encoder.model.to(gpu)
             vae_stride = self.config.vae_stride
             patch_size = self.config.patch_size
             vae = WanVAE(
-                vae_pth=os.path.join(self.checkpoint_dir, self.config.vae_checkpoint),
+                vae_pth=os.path.join(self.checkpoint_dir,
+                                     self.config.vae_checkpoint),
                 device=gpu)
             logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}")
             model = VaceWanModel.from_pretrained(self.checkpoint_dir)
@@ -547,9 +610,12 @@ class WanVaceMP(WanVace):
             if self.use_usp:
                 from xfuser.core.distributed import get_sequence_parallel_world_size
-                from .distributed.xdit_context_parallel import (usp_attn_forward,
-                                                                usp_dit_forward,
-                                                                usp_dit_forward_vace)
+
+                from .distributed.xdit_context_parallel import (
+                    usp_attn_forward,
+                    usp_dit_forward,
+                    usp_dit_forward_vace,
+                )
                 for block in model.blocks:
                     block.self_attn.forward = types.MethodType(
                         usp_attn_forward, block.self_attn)
@@ -557,7 +623,8 @@ class WanVaceMP(WanVace):
                     block.self_attn.forward = types.MethodType(
                         usp_attn_forward, block.self_attn)
                 model.forward = types.MethodType(usp_dit_forward, model)
-                model.forward_vace = types.MethodType(usp_dit_forward_vace, model)
+                model.forward_vace = types.MethodType(usp_dit_forward_vace,
+                                                      model)
                 sp_size = get_sequence_parallel_world_size()
             else:
                 sp_size = 1
@@ -577,7 +644,8 @@ class WanVaceMP(WanVace):
                 shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model = item
                 input_frames = self.transfer_data_to_cuda(input_frames, gpu)
                 input_masks = self.transfer_data_to_cuda(input_masks, gpu)
-                input_ref_images = self.transfer_data_to_cuda(input_ref_images, gpu)
+                input_ref_images = self.transfer_data_to_cuda(
+                    input_ref_images, gpu)
 
                 if n_prompt == "":
                     n_prompt = sample_neg_prompt
@@ -589,8 +657,10 @@ class WanVaceMP(WanVace):
                     context_null = text_encoder([n_prompt], gpu)
 
                 # vace context encode
-                z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, vae=vae)
-                m0 = self.vace_encode_masks(input_masks, input_ref_images, vae_stride=vae_stride)
+                z0 = self.vace_encode_frames(
+                    input_frames, input_ref_images, masks=input_masks, vae=vae)
+                m0 = self.vace_encode_masks(
+                    input_masks, input_ref_images, vae_stride=vae_stride)
                 z = self.vace_latent(z0, m0)
 
                 target_shape = list(z0[0].shape)
@@ -616,7 +686,8 @@ class WanVaceMP(WanVace):
                 no_sync = getattr(model, 'no_sync', noop_no_sync)
 
                 # evaluation mode
-                with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync():
+                with amp.autocast(
+                        dtype=param_dtype), torch.no_grad(), no_sync():
 
                     if sample_solver == 'unipc':
@@ -631,7 +702,8 @@ class WanVaceMP(WanVace):
                             num_train_timesteps=num_train_timesteps,
                             shift=1,
                             use_dynamic_shifting=False)
-                        sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
+                        sampling_sigmas = get_sampling_sigmas(
+                            sampling_steps, shift)
                         timesteps, _ = retrieve_timesteps(
                             sample_scheduler,
                             device=gpu,
@@ -653,10 +725,16 @@ class WanVaceMP(WanVace):
                        model.to(gpu)
                        noise_pred_cond = model(
-                            latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[
-                            0]
+                            latent_model_input,
+                            t=timestep,
+                            vace_context=z,
+                            vace_context_scale=context_scale,
+                            **arg_c)[0]
                        noise_pred_uncond = model(
-                            latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,
+                            latent_model_input,
+                            t=timestep,
+                            vace_context=z,
+                            vace_context_scale=context_scale,
                             **arg_null)[0]
 
                        noise_pred = noise_pred_uncond + guide_scale * (
@@ -673,7 +751,8 @@ class WanVaceMP(WanVace):
                         torch.cuda.empty_cache()
                     x0 = latents
                     if rank == 0:
-                        videos = self.decode_latent(x0, input_ref_images, vae=vae)
+                        videos = self.decode_latent(
+                            x0, input_ref_images, vae=vae)
 
                 del noise, latents
                 del sample_scheduler
@@ -691,8 +770,6 @@ class WanVaceMP(WanVace):
             print(trace_info, flush=True)
             print(e, flush=True)
 
-
-
     def generate(self,
                  input_prompt,
                  input_frames,
@@ -709,8 +786,10 @@ class WanVaceMP(WanVace):
                  seed=-1,
                  offload_model=True):
-        input_data = (input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale,
-                      shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model)
+        input_data = (input_prompt, input_frames, input_masks, input_ref_images,
+                      size, frame_num, context_scale, shift, sample_solver,
+                      sampling_steps, guide_scale, n_prompt, seed,
+                      offload_model)
         for in_q in self.in_q_list:
             in_q.put(input_data)
         value_output = self.out_q.get()