diff --git a/.style.yapf b/.style.yapf new file mode 100644 index 0000000..75b65d9 --- /dev/null +++ b/.style.yapf @@ -0,0 +1,393 @@ +[style] +# Align closing bracket with visual indentation. +align_closing_bracket_with_visual_indent=False + +# Allow dictionary keys to exist on multiple lines. For example: +# +# x = { +# ('this is the first element of a tuple', +# 'this is the second element of a tuple'): +# value, +# } +allow_multiline_dictionary_keys=False + +# Allow lambdas to be formatted on more than one line. +allow_multiline_lambdas=False + +# Allow splitting before a default / named assignment in an argument list. +allow_split_before_default_or_named_assigns=False + +# Allow splits before the dictionary value. +allow_split_before_dict_value=True + +# Let spacing indicate operator precedence. For example: +# +# a = 1 * 2 + 3 / 4 +# b = 1 / 2 - 3 * 4 +# c = (1 + 2) * (3 - 4) +# d = (1 - 2) / (3 + 4) +# e = 1 * 2 - 3 +# f = 1 + 2 + 3 + 4 +# +# will be formatted as follows to indicate precedence: +# +# a = 1*2 + 3/4 +# b = 1/2 - 3*4 +# c = (1+2) * (3-4) +# d = (1-2) / (3+4) +# e = 1*2 - 3 +# f = 1 + 2 + 3 + 4 +# +arithmetic_precedence_indication=False + +# Number of blank lines surrounding top-level function and class +# definitions. +blank_lines_around_top_level_definition=2 + +# Insert a blank line before a class-level docstring. +blank_line_before_class_docstring=False + +# Insert a blank line before a module docstring. +blank_line_before_module_docstring=False + +# Insert a blank line before a 'def' or 'class' immediately nested +# within another 'def' or 'class'. For example: +# +# class Foo: +# # <------ this blank line +# def method(): +# ... +blank_line_before_nested_class_or_def=True + +# Do not split consecutive brackets. Only relevant when +# dedent_closing_brackets is set. For example: +# +# call_func_that_takes_a_dict( +# { +# 'key1': 'value1', +# 'key2': 'value2', +# } +# ) +# +# would reformat to: +# +# call_func_that_takes_a_dict({ +# 'key1': 'value1', +# 'key2': 'value2', +# }) +coalesce_brackets=False + +# The column limit. +column_limit=80 + +# The style for continuation alignment. Possible values are: +# +# - SPACE: Use spaces for continuation alignment. This is default behavior. +# - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns +# (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs or +# CONTINUATION_INDENT_WIDTH spaces) for continuation alignment. +# - VALIGN-RIGHT: Vertically align continuation lines to multiple of +# INDENT_WIDTH columns. Slightly right (one tab or a few spaces) if +# cannot vertically align continuation lines with indent characters. +continuation_align_style=SPACE + +# Indent width used for line continuations. +continuation_indent_width=4 + +# Put closing brackets on a separate line, dedented, if the bracketed +# expression can't fit in a single line. Applies to all kinds of brackets, +# including function definitions and calls. For example: +# +# config = { +# 'key1': 'value1', +# 'key2': 'value2', +# } # <--- this bracket is dedented and on a separate line +# +# time_series = self.remote_client.query_entity_counters( +# entity='dev3246.region1', +# key='dns.query_latency_tcp', +# transform=Transformation.AVERAGE(window=timedelta(seconds=60)), +# start_ts=now()-timedelta(days=3), +# end_ts=now(), +# ) # <--- this bracket is dedented and on a separate line +dedent_closing_brackets=False + +# Disable the heuristic which places each list element on a separate line +# if the list is comma-terminated. 
+disable_ending_comma_heuristic=False + +# Place each dictionary entry onto its own line. +each_dict_entry_on_separate_line=True + +# Require multiline dictionary even if it would normally fit on one line. +# For example: +# +# config = { +# 'key1': 'value1' +# } +force_multiline_dict=False + +# The regex for an i18n comment. The presence of this comment stops +# reformatting of that line, because the comments are required to be +# next to the string they translate. +i18n_comment=#\..* + +# The i18n function call names. The presence of this function stops +# reformatting on that line, because the string it has cannot be moved +# away from the i18n comment. +i18n_function_call=N_, _ + +# Indent blank lines. +indent_blank_lines=False + +# Put closing brackets on a separate line, indented, if the bracketed +# expression can't fit in a single line. Applies to all kinds of brackets, +# including function definitions and calls. For example: +# +# config = { +# 'key1': 'value1', +# 'key2': 'value2', +# } # <--- this bracket is indented and on a separate line +# +# time_series = self.remote_client.query_entity_counters( +# entity='dev3246.region1', +# key='dns.query_latency_tcp', +# transform=Transformation.AVERAGE(window=timedelta(seconds=60)), +# start_ts=now()-timedelta(days=3), +# end_ts=now(), +# ) # <--- this bracket is indented and on a separate line +indent_closing_brackets=False + +# Indent the dictionary value if it cannot fit on the same line as the +# dictionary key. For example: +# +# config = { +# 'key1': +# 'value1', +# 'key2': value1 + +# value2, +# } +indent_dictionary_value=True + +# The number of columns to use for indentation. +indent_width=4 + +# Join short lines into one line. E.g., single line 'if' statements. +join_multiple_lines=False + +# Do not include spaces around selected binary operators. For example: +# +# 1 + 2 * 3 - 4 / 5 +# +# will be formatted as follows when configured with "*,/": +# +# 1 + 2*3 - 4/5 +no_spaces_around_selected_binary_operators= + +# Use spaces around default or named assigns. +spaces_around_default_or_named_assign=False + +# Adds a space after the opening '{' and before the ending '}' dict delimiters. +# +# {1: 2} +# +# will be formatted as: +# +# { 1: 2 } +spaces_around_dict_delimiters=False + +# Adds a space after the opening '[' and before the ending ']' list delimiters. +# +# [1, 2] +# +# will be formatted as: +# +# [ 1, 2 ] +spaces_around_list_delimiters=False + +# Use spaces around the power operator. +spaces_around_power_operator=False + +# Use spaces around the subscript / slice operator. For example: +# +# my_list[1 : 10 : 2] +spaces_around_subscript_colon=False + +# Adds a space after the opening '(' and before the ending ')' tuple delimiters. +# +# (1, 2, 3) +# +# will be formatted as: +# +# ( 1, 2, 3 ) +spaces_around_tuple_delimiters=False + +# The number of spaces required before a trailing comment. +# This can be a single value (representing the number of spaces +# before each trailing comment) or list of values (representing +# alignment column values; trailing comments within a block will +# be aligned to the first column value that is greater than the maximum +# line length within the block).
For example: +# +# With spaces_before_comment=5: +# +# 1 + 1 # Adding values +# +# will be formatted as: +# +# 1 + 1 # Adding values <-- 5 spaces between the end of the statement and comment +# +# With spaces_before_comment=15, 20: +# +# 1 + 1 # Adding values +# two + two # More adding +# +# longer_statement # This is a longer statement +# short # This is a shorter statement +# +# a_very_long_statement_that_extends_beyond_the_final_column # Comment +# short # This is a shorter statement +# +# will be formatted as: +# +# 1 + 1 # Adding values <-- end of line comments in block aligned to col 15 +# two + two # More adding +# +# longer_statement # This is a longer statement <-- end of line comments in block aligned to col 20 +# short # This is a shorter statement +# +# a_very_long_statement_that_extends_beyond_the_final_column # Comment <-- the end of line comments are aligned based on the line length +# short # This is a shorter statement +# +spaces_before_comment=2 + +# Insert a space between the ending comma and closing bracket of a list, +# etc. +space_between_ending_comma_and_closing_bracket=False + +# Use spaces inside brackets, braces, and parentheses. For example: +# +# method_call( 1 ) +# my_dict[ 3 ][ 1 ][ get_index( *args, **kwargs ) ] +# my_set = { 1, 2, 3 } +space_inside_brackets=False + +# Split before arguments +split_all_comma_separated_values=False + +# Split before arguments, but do not split all subexpressions recursively +# (unless needed). +split_all_top_level_comma_separated_values=False + +# Split before arguments if the argument list is terminated by a +# comma. +split_arguments_when_comma_terminated=False + +# Set to True to prefer splitting before '+', '-', '*', '/', '//', or '@' +# rather than after. +split_before_arithmetic_operator=False + +# Set to True to prefer splitting before '&', '|' or '^' rather than +# after. +split_before_bitwise_operator=False + +# Split before the closing bracket if a list or dict literal doesn't fit on +# a single line. +split_before_closing_bracket=True + +# Split before a dictionary or set generator (comp_for). For example, note +# the split before the 'for': +# +# foo = { +# variable: 'Hello world, have a nice day!' +# for variable in bar if variable != 42 +# } +split_before_dict_set_generator=False + +# Split before the '.' if we need to split a longer expression: +# +# foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d)) +# +# would reformat to something like: +# +# foo = ('This is a really long string: {}, {}, {}, {}' +# .format(a, b, c, d)) +split_before_dot=False + +# Split after the opening paren which surrounds an expression if it doesn't +# fit on a single line. +split_before_expression_after_opening_paren=True + +# If an argument / parameter list is going to be split, then split before +# the first argument. +split_before_first_argument=False + +# Set to True to prefer splitting before 'and' or 'or' rather than +# after. +split_before_logical_operator=False + +# Split named assignments onto individual lines. +split_before_named_assigns=True + +# Set to True to split list comprehensions and generators that have +# non-trivial expressions and multiple clauses before each of these +# clauses. 
For example: +# +# result = [ +# a_long_var + 100 for a_long_var in xrange(1000) +# if a_long_var % 10] +# +# would reformat to something like: +# +# result = [ +# a_long_var + 100 +# for a_long_var in xrange(1000) +# if a_long_var % 10] +split_complex_comprehension=True + +# The penalty for splitting right after the opening bracket. +split_penalty_after_opening_bracket=300 + +# The penalty for splitting the line after a unary operator. +split_penalty_after_unary_operator=10000 + +# The penalty of splitting the line around the '+', '-', '*', '/', '//', +# ``%``, and '@' operators. +split_penalty_arithmetic_operator=300 + +# The penalty for splitting right before an if expression. +split_penalty_before_if_expr=0 + +# The penalty of splitting the line around the '&', '|', and '^' +# operators. +split_penalty_bitwise_operator=300 + +# The penalty for splitting a list comprehension or generator +# expression. +split_penalty_comprehension=2100 + +# The penalty for characters over the column limit. +split_penalty_excess_character=7000 + +# The penalty incurred by adding a line split to the unwrapped line. The +# more line splits added the higher the penalty. +split_penalty_for_added_line_split=30 + +# The penalty of splitting a list of "import as" names. For example: +# +# from a_very_long_or_indented_module_name_yada_yad import (long_argument_1, +# long_argument_2, +# long_argument_3) +# +# would reformat to something like: +# +# from a_very_long_or_indented_module_name_yada_yad import ( +# long_argument_1, long_argument_2, long_argument_3) +split_penalty_import_names=0 + +# The penalty of splitting the line around the 'and' and 'or' +# operators. +split_penalty_logical_operator=300 + +# Use the Tab character for indentation. +use_tabs=False \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..0345128 --- /dev/null +++ b/Makefile @@ -0,0 +1,5 @@ +.PHONY: format + +format: + isort generate.py gradio wan + yapf -i -r *.py generate.py gradio wan diff --git a/generate.py b/generate.py index 2e6b35c..c841c19 100644 --- a/generate.py +++ b/generate.py @@ -1,28 +1,33 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
import argparse -from datetime import datetime import logging import os import sys import warnings +from datetime import datetime warnings.filterwarnings('ignore') -import torch, random +import random + +import torch import torch.distributed as dist from PIL import Image import wan -from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES +from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander -from wan.utils.utils import cache_video, cache_image, str2bool +from wan.utils.utils import cache_image, cache_video, str2bool + EXAMPLE_PROMPT = { "t2v-1.3B": { - "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + "prompt": + "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", }, "t2v-14B": { - "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + "prompt": + "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", }, "t2i-14B": { "prompt": "一个朴素端庄的美人", @@ -34,20 +39,24 @@ EXAMPLE_PROMPT = { "examples/i2v_input.JPG", }, "flf2v-14B": { - "prompt": - "CG动画风格,一只蓝色的小鸟从地面起飞,煽动翅膀。小鸟羽毛细腻,胸前有独特的花纹,背景是蓝天白云,阳光明媚。镜跟随小鸟向上移动,展现出小鸟飞翔的姿态和天空的广阔。近景,仰视视角。", - "first_frame": - "examples/flf2v_input_first_frame.png", - "last_frame": - "examples/flf2v_input_last_frame.png", + "prompt": + "CG动画风格,一只蓝色的小鸟从地面起飞,煽动翅膀。小鸟羽毛细腻,胸前有独特的花纹,背景是蓝天白云,阳光明媚。镜跟随小鸟向上移动,展现出小鸟飞翔的姿态和天空的广阔。近景,仰视视角。", + "first_frame": + "examples/flf2v_input_first_frame.png", + "last_frame": + "examples/flf2v_input_last_frame.png", }, "vace-1.3B": { - "src_ref_images": 'examples/girl.png,examples/snake.png', - "prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。" + "src_ref_images": + 'examples/girl.png,examples/snake.png', + "prompt": + "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。" }, "vace-14B": { - "src_ref_images": 'examples/girl.png,examples/snake.png', - "prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。" + "src_ref_images": + 'examples/girl.png,examples/snake.png', + "prompt": + "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。" } } @@ -64,7 +73,6 @@ def _validate_args(args): if "i2v" in args.task: args.sample_steps = 40 - if args.sample_shift is None: args.sample_shift = 5.0 if "i2v" in args.task and args.size in ["832*480", "480*832"]: @@ -72,7 +80,6 @@ def _validate_args(args): elif "flf2v" in args.task or "vace" in args.task: args.sample_shift = 16 - # The default number of frames are 1 for text-to-image tasks and 81 for other tasks. if args.frame_num is None: args.frame_num = 1 if "t2i" in args.task else 81 @@ -167,7 +174,8 @@ def _parse_args(): "--src_ref_images", type=str, default=None, - help="The file list of the source reference images. Separated by ','. 
Default None.") + help="The file list of the source reference images. Separated by ','. Default None." + ) parser.add_argument( "--prompt", type=str, @@ -209,12 +217,14 @@ def _parse_args(): "--first_frame", type=str, default=None, - help="[first-last frame to video] The image (first frame) to generate the video from.") + help="[first-last frame to video] The image (first frame) to generate the video from." + ) parser.add_argument( "--last_frame", type=str, default=None, - help="[first-last frame to video] The image (last frame) to generate the video from.") + help="[first-last frame to video] The image (last frame) to generate the video from." + ) parser.add_argument( "--sample_solver", type=str, @@ -281,8 +291,10 @@ def generate(args): if args.ulysses_size > 1 or args.ring_size > 1: assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size." - from xfuser.core.distributed import (initialize_model_parallel, - init_distributed_environment) + from xfuser.core.distributed import ( + init_distributed_environment, + initialize_model_parallel, + ) init_distributed_environment( rank=dist.get_rank(), world_size=dist.get_world_size()) @@ -295,7 +307,8 @@ def generate(args): if args.use_prompt_extend: if args.prompt_extend_method == "dashscope": prompt_expander = DashScopePromptExpander( - model_name=args.prompt_extend_model, is_vl="i2v" in args.task or "flf2v" in args.task) + model_name=args.prompt_extend_model, + is_vl="i2v" in args.task or "flf2v" in args.task) elif args.prompt_extend_method == "local_qwen": prompt_expander = QwenPromptExpander( model_name=args.prompt_extend_model, @@ -482,21 +495,22 @@ def generate(args): sampling_steps=args.sample_steps, guide_scale=args.sample_guide_scale, seed=args.base_seed, - offload_model=args.offload_model - ) + offload_model=args.offload_model) elif "vace" in args.task: if args.prompt is None: args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None) args.src_mask = EXAMPLE_PROMPT[args.task].get("src_mask", None) - args.src_ref_images = EXAMPLE_PROMPT[args.task].get("src_ref_images", None) + args.src_ref_images = EXAMPLE_PROMPT[args.task].get( + "src_ref_images", None) logging.info(f"Input prompt: {args.prompt}") if args.use_prompt_extend and args.use_prompt_extend != 'plain': logging.info("Extending prompt ...") if rank == 0: prompt = prompt_expander.forward(args.prompt) - logging.info(f"Prompt extended from '{args.prompt}' to '{prompt}'") + logging.info( + f"Prompt extended from '{args.prompt}' to '{prompt}'") input_prompt = [prompt] else: input_prompt = [None] @@ -517,10 +531,11 @@ def generate(args): t5_cpu=args.t5_cpu, ) - src_video, src_mask, src_ref_images = wan_vace.prepare_source([args.src_video], - [args.src_mask], - [None if args.src_ref_images is None else args.src_ref_images.split(',')], - args.frame_num, SIZE_CONFIGS[args.size], device) + src_video, src_mask, src_ref_images = wan_vace.prepare_source( + [args.src_video], [args.src_mask], [ + None if args.src_ref_images is None else + args.src_ref_images.split(',') + ], args.frame_num, SIZE_CONFIGS[args.size], device) logging.info(f"Generating video...") video = wan_vace.generate( diff --git a/gradio/fl2v_14B_singleGPU.py b/gradio/fl2v_14B_singleGPU.py index 476a136..c55ed0c 100644 --- a/gradio/fl2v_14B_singleGPU.py +++ b/gradio/fl2v_14B_singleGPU.py @@ -1,8 +1,8 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
import argparse import gc -import os.path as osp import os +import os.path as osp import sys import warnings @@ -11,7 +11,8 @@ import gradio as gr warnings.filterwarnings('ignore') # Model -sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) +sys.path.insert( + 0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) import wan from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander @@ -69,13 +70,13 @@ def prompt_enc(prompt, img_first, img_last, tar_lang): return prompt_output.prompt -def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps, - guide_scale, shift_scale, seed, n_prompt): +def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, + resolution, sd_steps, guide_scale, shift_scale, seed, + n_prompt): if resolution == '------': print( - 'Please specify the resolution ckpt dir or specify the resolution' - ) + 'Please specify the resolution ckpt dir or specify the resolution') return None else: @@ -94,9 +95,7 @@ def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, re offload_model=True) pass else: - print( - 'Sorry, currently only 720P is supported.' - ) + print('Sorry, currently only 720P is supported.') return None cache_video( @@ -191,14 +190,17 @@ def gradio_interface(): run_p_button.click( fn=prompt_enc, - inputs=[flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, tar_lang], + inputs=[ + flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, + tar_lang + ], outputs=[flf2vid_prompt]) run_flf2v_button.click( fn=flf2v_generation, inputs=[ - flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps, - guide_scale, shift_scale, seed, n_prompt + flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, + resolution, sd_steps, guide_scale, shift_scale, seed, n_prompt ], outputs=[result_gallery], ) diff --git a/gradio/i2v_14B_singleGPU.py b/gradio/i2v_14B_singleGPU.py index 35c1e08..2e7bcf6 100644 --- a/gradio/i2v_14B_singleGPU.py +++ b/gradio/i2v_14B_singleGPU.py @@ -1,8 +1,8 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import argparse import gc -import os.path as osp import os +import os.path as osp import sys import warnings @@ -11,7 +11,8 @@ import gradio as gr warnings.filterwarnings('ignore') # Model -sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) +sys.path.insert( + 0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) import wan from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander diff --git a/gradio/t2i_14B_singleGPU.py b/gradio/t2i_14B_singleGPU.py index 1ccc229..e2b6d65 100644 --- a/gradio/t2i_14B_singleGPU.py +++ b/gradio/t2i_14B_singleGPU.py @@ -1,7 +1,7 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
import argparse -import os.path as osp import os +import os.path as osp import sys import warnings @@ -10,7 +10,8 @@ import gradio as gr warnings.filterwarnings('ignore') # Model -sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) +sys.path.insert( + 0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) import wan from wan.configs import WAN_CONFIGS from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander diff --git a/gradio/t2v_1.3B_singleGPU.py b/gradio/t2v_1.3B_singleGPU.py index 987634b..31316ba 100644 --- a/gradio/t2v_1.3B_singleGPU.py +++ b/gradio/t2v_1.3B_singleGPU.py @@ -1,7 +1,7 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import argparse -import os.path as osp import os +import os.path as osp import sys import warnings @@ -10,7 +10,8 @@ import gradio as gr warnings.filterwarnings('ignore') # Model -sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) +sys.path.insert( + 0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) import wan from wan.configs import WAN_CONFIGS from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander diff --git a/gradio/t2v_14B_singleGPU.py b/gradio/t2v_14B_singleGPU.py index 37c11ae..8bba789 100644 --- a/gradio/t2v_14B_singleGPU.py +++ b/gradio/t2v_14B_singleGPU.py @@ -1,7 +1,7 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import argparse -import os.path as osp import os +import os.path as osp import sys import warnings @@ -10,7 +10,8 @@ import gradio as gr warnings.filterwarnings('ignore') # Model -sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) +sys.path.insert( + 0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) import wan from wan.configs import WAN_CONFIGS from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander diff --git a/gradio/vace.py b/gradio/vace.py index 75f780a..d3d5206 100644 --- a/gradio/vace.py +++ b/gradio/vace.py @@ -2,36 +2,48 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import argparse +import datetime import os import sys -import datetime + import imageio import numpy as np import torch + import gradio as gr -sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2])) +sys.path.insert( + 0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2])) import wan from wan import WanVace, WanVaceMP -from wan.configs import WAN_CONFIGS, SIZE_CONFIGS +from wan.configs import SIZE_CONFIGS, WAN_CONFIGS class FixedSizeQueue: + def __init__(self, max_size): self.max_size = max_size self.queue = [] + def add(self, item): self.queue.insert(0, item) if len(self.queue) > self.max_size: self.queue.pop() + def get(self): return self.queue + def __repr__(self): return str(self.queue) class VACEInference: - def __init__(self, cfg, skip_load=False, gallery_share=True, gallery_share_limit=5): + + def __init__(self, + cfg, + skip_load=False, + gallery_share=True, + gallery_share_limit=5): self.cfg = cfg self.save_dir = cfg.save_dir self.gallery_share = gallery_share @@ -53,9 +65,7 @@ class VACEInference: checkpoint_dir=cfg.ckpt_dir, use_usp=True, ulysses_size=cfg.ulysses_size, - ring_size=cfg.ring_size - ) - + ring_size=cfg.ring_size) def create_ui(self, *args, **kwargs): gr.Markdown(""" @@ -80,30 +90,33 @@ class VACEInference: with gr.Row(variant='panel', equal_height=True): with gr.Column(scale=1, min_width=0): with gr.Row(equal_height=True): - self.src_ref_image_1 = gr.Image(label='src_ref_image_1', - height=200, - interactive=True, - type='filepath', - image_mode='RGB', - sources=['upload'], - elem_id="src_ref_image_1", - format='png') - self.src_ref_image_2 = gr.Image(label='src_ref_image_2', - height=200, - interactive=True, - type='filepath', - image_mode='RGB', - sources=['upload'], - elem_id="src_ref_image_2", - format='png') - self.src_ref_image_3 = gr.Image(label='src_ref_image_3', - height=200, - interactive=True, - type='filepath', - image_mode='RGB', - sources=['upload'], - elem_id="src_ref_image_3", - format='png') + self.src_ref_image_1 = gr.Image( + label='src_ref_image_1', + height=200, + interactive=True, + type='filepath', + image_mode='RGB', + sources=['upload'], + elem_id="src_ref_image_1", + format='png') + self.src_ref_image_2 = gr.Image( + label='src_ref_image_2', + height=200, + interactive=True, + type='filepath', + image_mode='RGB', + sources=['upload'], + elem_id="src_ref_image_2", + format='png') + self.src_ref_image_3 = gr.Image( + label='src_ref_image_3', + height=200, + interactive=True, + type='filepath', + image_mode='RGB', + sources=['upload'], + elem_id="src_ref_image_3", + format='png') with gr.Row(variant='panel', equal_height=True): with gr.Column(scale=1): self.prompt = gr.Textbox( @@ -158,10 +171,8 @@ class VACEInference: step=0.5, value=5.0, interactive=True) - self.infer_seed = gr.Slider(minimum=-1, - maximum=10000000, - value=2025, - label="Seed") + self.infer_seed = gr.Slider( + minimum=-1, maximum=10000000, value=2025, label="Seed") # with gr.Accordion(label="Usable without source video", open=False): with gr.Row(equal_height=True): @@ -176,13 +187,9 @@ class VACEInference: value=1280, interactive=True) self.frame_rate = gr.Textbox( - label='frame_rate', - value=16, - interactive=True) + label='frame_rate', value=16, interactive=True) self.num_frames = gr.Textbox( - label='num_frames', - value=81, - interactive=True) + label='num_frames', value=81, interactive=True) # with gr.Row(equal_height=True): with gr.Column(scale=5): @@ -201,17 +208,22 @@ class VACEInference: allow_preview=True, 
preview=True) - - def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, src_ref_image_2, src_ref_image_3, prompt, negative_prompt, shift_scale, sample_steps, context_scale, guide_scale, infer_seed, output_height, output_width, frame_rate, num_frames): - output_height, output_width, frame_rate, num_frames = int(output_height), int(output_width), int(frame_rate), int(num_frames) - src_ref_images = [x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3] if - x is not None] - src_video, src_mask, src_ref_images = self.pipe.prepare_source([src_video], - [src_mask], - [src_ref_images], - num_frames=num_frames, - image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"], - device=self.pipe.device) + def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, + src_ref_image_2, src_ref_image_3, prompt, negative_prompt, + shift_scale, sample_steps, context_scale, guide_scale, + infer_seed, output_height, output_width, frame_rate, + num_frames): + output_height, output_width, frame_rate, num_frames = int( + output_height), int(output_width), int(frame_rate), int(num_frames) + src_ref_images = [ + x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3] + if x is not None + ] + src_video, src_mask, src_ref_images = self.pipe.prepare_source( + [src_video], [src_mask], [src_ref_images], + num_frames=num_frames, + image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"], + device=self.pipe.device) video = self.pipe.generate( prompt, src_video, @@ -228,10 +240,17 @@ class VACEInference: name = '{0:%Y%m%d%-H%M%S}'.format(datetime.datetime.now()) video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4') - video_frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8) + video_frames = ( + torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * + 255).cpu().numpy().astype(np.uint8) try: - writer = imageio.get_writer(video_path, fps=frame_rate, codec='libx264', quality=8, macro_block_size=1) + writer = imageio.get_writer( + video_path, + fps=frame_rate, + codec='libx264', + quality=8, + macro_block_size=1) for frame in video_frames: writer.append_data(frame) writer.close() @@ -246,25 +265,57 @@ class VACEInference: return [video_path] def set_callbacks(self, **kwargs): - self.gen_inputs = [self.output_gallery, self.src_video, self.src_mask, self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3, self.prompt, self.negative_prompt, self.shift_scale, self.sample_steps, self.context_scale, self.guide_scale, self.infer_seed, self.output_height, self.output_width, self.frame_rate, self.num_frames] + self.gen_inputs = [ + self.output_gallery, self.src_video, self.src_mask, + self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3, + self.prompt, self.negative_prompt, self.shift_scale, + self.sample_steps, self.context_scale, self.guide_scale, + self.infer_seed, self.output_height, self.output_width, + self.frame_rate, self.num_frames + ] self.gen_outputs = [self.output_gallery] - self.generate_button.click(self.generate, - inputs=self.gen_inputs, - outputs=self.gen_outputs, - queue=True) - self.refresh_button.click(lambda x: self.gallery_share_data.get() if self.gallery_share else x, inputs=[self.output_gallery], outputs=[self.output_gallery]) + self.generate_button.click( + self.generate, + inputs=self.gen_inputs, + outputs=self.gen_outputs, + queue=True) + self.refresh_button.click( + lambda x: self.gallery_share_data.get() + if self.gallery_share else x, + 
inputs=[self.output_gallery], + outputs=[self.output_gallery]) if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Argparser for VACE-WAN Demo:\n') - parser.add_argument('--server_port', dest='server_port', help='', type=int, default=7860) - parser.add_argument('--server_name', dest='server_name', help='', default='0.0.0.0') + parser = argparse.ArgumentParser( + description='Argparser for VACE-WAN Demo:\n') + parser.add_argument( + '--server_port', dest='server_port', help='', type=int, default=7860) + parser.add_argument( + '--server_name', dest='server_name', help='', default='0.0.0.0') parser.add_argument('--root_path', dest='root_path', help='', default=None) parser.add_argument('--save_dir', dest='save_dir', help='', default='cache') - parser.add_argument("--mp", action="store_true", help="Use Multi-GPUs",) - parser.add_argument("--model_name", type=str, default="vace-14B", choices=list(WAN_CONFIGS.keys()), help="The model name to run.") - parser.add_argument("--ulysses_size", type=int, default=1, help="The size of the ulysses parallelism in DiT.") - parser.add_argument("--ring_size", type=int, default=1, help="The size of the ring attention parallelism in DiT.") + parser.add_argument( + "--mp", + action="store_true", + help="Use Multi-GPUs", + ) + parser.add_argument( + "--model_name", + type=str, + default="vace-14B", + choices=list(WAN_CONFIGS.keys()), + help="The model name to run.") + parser.add_argument( + "--ulysses_size", + type=int, + default=1, + help="The size of the ulysses parallelism in DiT.") + parser.add_argument( + "--ring_size", + type=int, + default=1, + help="The size of the ring attention parallelism in DiT.") parser.add_argument( "--ckpt_dir", type=str, @@ -284,12 +335,15 @@ if __name__ == '__main__': os.makedirs(args.save_dir, exist_ok=True) with gr.Blocks() as demo: - infer_gr = VACEInference(args, skip_load=False, gallery_share=True, gallery_share_limit=5) + infer_gr = VACEInference( + args, skip_load=False, gallery_share=True, gallery_share_limit=5) infer_gr.create_ui() infer_gr.set_callbacks() allowed_paths = [args.save_dir] - demo.queue(status_update_rate=1).launch(server_name=args.server_name, - server_port=args.server_port, - root_path=args.root_path, - allowed_paths=allowed_paths, - show_error=True, debug=True) + demo.queue(status_update_rate=1).launch( + server_name=args.server_name, + server_port=args.server_port, + root_path=args.root_path, + allowed_paths=allowed_paths, + show_error=True, + debug=True) diff --git a/wan/__init__.py b/wan/__init__.py index 45d555d..afed024 100644 --- a/wan/__init__.py +++ b/wan/__init__.py @@ -1,5 +1,5 @@ from . 
import configs, distributed, modules +from .first_last_frame2video import WanFLF2V from .image2video import WanI2V from .text2video import WanT2V -from .first_last_frame2video import WanFLF2V from .vace import WanVace, WanVaceMP diff --git a/wan/distributed/fsdp.py b/wan/distributed/fsdp.py index 18ba2f3..6bb496d 100644 --- a/wan/distributed/fsdp.py +++ b/wan/distributed/fsdp.py @@ -8,6 +8,7 @@ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy from torch.distributed.utils import _free_storage + def shard_model( model, device_id, @@ -32,6 +33,7 @@ def shard_model( sync_module_states=sync_module_states) return model + def free_model(model): for m in model.modules(): if isinstance(m, FSDP): diff --git a/wan/distributed/xdit_context_parallel.py b/wan/distributed/xdit_context_parallel.py index e0be6c7..4718577 100644 --- a/wan/distributed/xdit_context_parallel.py +++ b/wan/distributed/xdit_context_parallel.py @@ -1,9 +1,11 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import torch import torch.cuda.amp as amp -from xfuser.core.distributed import (get_sequence_parallel_rank, - get_sequence_parallel_world_size, - get_sp_group) +from xfuser.core.distributed import ( + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group, +) from xfuser.core.long_ctx_attention import xFuserLongContextAttention from ..modules.model import sinusoidal_embedding_1d @@ -63,19 +65,13 @@ def rope_apply(x, grid_sizes, freqs): return torch.stack(output).float() -def usp_dit_forward_vace( - self, - x, - vace_context, - seq_len, - kwargs -): +def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs): # embeddings c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] c = [u.flatten(2).transpose(1, 2) for u in c] c = torch.cat([ - torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], - dim=1) for u in c + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) + for u in c ]) # arguments diff --git a/wan/first_last_frame2video.py b/wan/first_last_frame2video.py index 4f300ca..232950f 100644 --- a/wan/first_last_frame2video.py +++ b/wan/first_last_frame2video.py @@ -21,8 +21,11 @@ from .modules.clip import CLIPModel from .modules.model import WanModel from .modules.t5 import T5EncoderModel from .modules.vae import WanVAE -from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, retrieve_timesteps) +from .utils.fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler @@ -103,11 +106,12 @@ class WanFLF2V: init_on_cpu = False if use_usp: - from xfuser.core.distributed import \ - get_sequence_parallel_world_size + from xfuser.core.distributed import get_sequence_parallel_world_size - from .distributed.xdit_context_parallel import (usp_attn_forward, - usp_dit_forward) + from .distributed.xdit_context_parallel import ( + usp_attn_forward, + usp_dit_forward, + ) for block in self.model.blocks: block.self_attn.forward = types.MethodType( usp_attn_forward, block.self_attn) @@ -181,8 +185,10 @@ class WanFLF2V: """ first_frame_size = first_frame.size last_frame_size = last_frame.size - first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(self.device) - last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(self.device) + first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to( + self.device) + 
last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to( + self.device) F = frame_num first_frame_h, first_frame_w = first_frame.shape[1:] @@ -199,8 +205,7 @@ class WanFLF2V: # 1. resize last_frame_resize_ratio = max( first_frame_size[0] / last_frame_size[0], - first_frame_size[1] / last_frame_size[1] - ) + first_frame_size[1] / last_frame_size[1]) last_frame_size = [ round(last_frame_size[0] * last_frame_resize_ratio), round(last_frame_size[1] * last_frame_resize_ratio), @@ -216,8 +221,7 @@ class WanFLF2V: seed_g = torch.Generator(device=self.device) seed_g.manual_seed(seed) noise = torch.randn( - 16, - (F - 1) // 4 + 1, + 16, (F - 1) // 4 + 1, lat_h, lat_w, dtype=torch.float32, @@ -225,8 +229,11 @@ class WanFLF2V: device=self.device) msk = torch.ones(1, 81, lat_h, lat_w, device=self.device) - msk[:, 1: -1] = 0 - msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk[:, 1:-1] = 0 + msk = torch.concat([ + torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] + ], + dim=1) msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) msk = msk.transpose(1, 2)[0] @@ -247,7 +254,8 @@ class WanFLF2V: context_null = [t.to(self.device) for t in context_null] self.clip.model.to(self.device) - clip_context = self.clip.visual([first_frame[:, None, :, :], last_frame[:, None, :, :]]) + clip_context = self.clip.visual( + [first_frame[:, None, :, :], last_frame[:, None, :, :]]) if offload_model: self.clip.model.cpu() @@ -256,15 +264,14 @@ class WanFLF2V: torch.nn.functional.interpolate( first_frame[None].cpu(), size=(first_frame_h, first_frame_w), - mode='bicubic' - ).transpose(0, 1), + mode='bicubic').transpose(0, 1), torch.zeros(3, F - 2, first_frame_h, first_frame_w), torch.nn.functional.interpolate( last_frame[None].cpu(), size=(first_frame_h, first_frame_w), - mode='bicubic' - ).transpose(0, 1), - ], dim=1).to(self.device) + mode='bicubic').transpose(0, 1), + ], + dim=1).to(self.device) ])[0] y = torch.concat([msk, y]) diff --git a/wan/image2video.py b/wan/image2video.py index 5004f46..6882c53 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -21,8 +21,11 @@ from .modules.clip import CLIPModel from .modules.model import WanModel from .modules.t5 import T5EncoderModel from .modules.vae import WanVAE -from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, retrieve_timesteps) +from .utils.fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler @@ -103,11 +106,12 @@ class WanI2V: init_on_cpu = False if use_usp: - from xfuser.core.distributed import \ - get_sequence_parallel_world_size + from xfuser.core.distributed import get_sequence_parallel_world_size - from .distributed.xdit_context_parallel import (usp_attn_forward, - usp_dit_forward) + from .distributed.xdit_context_parallel import ( + usp_attn_forward, + usp_dit_forward, + ) for block in self.model.blocks: block.self_attn.forward = types.MethodType( usp_attn_forward, block.self_attn) @@ -196,8 +200,7 @@ class WanI2V: seed_g = torch.Generator(device=self.device) seed_g.manual_seed(seed) noise = torch.randn( - 16, - (F - 1) // 4 + 1, + 16, (F - 1) // 4 + 1, lat_h, lat_w, dtype=torch.float32, diff --git a/wan/modules/model.py b/wan/modules/model.py index b94474a..a5425da 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -273,7 +273,7 @@ class WanAttentionBlock(nn.Module): nn.Linear(ffn_dim, dim)) # modulation - self.modulation 
= nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5) + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) def forward( self, @@ -332,7 +332,7 @@ class Head(nn.Module): self.head = nn.Linear(dim, out_dim) # modulation - self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim ** 0.5) + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) def forward(self, x, e): r""" @@ -357,7 +357,8 @@ class MLPProj(torch.nn.Module): torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), torch.nn.LayerNorm(out_dim)) if flf_pos_emb: # NOTE: we only use this for `flf2v` - self.emb_pos = nn.Parameter(torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280)) + self.emb_pos = nn.Parameter( + torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280)) def forward(self, image_embeds): if hasattr(self, 'emb_pos'): diff --git a/wan/modules/vace_model.py b/wan/modules/vace_model.py index 60178a9..a12d1dd 100644 --- a/wan/modules/vace_model.py +++ b/wan/modules/vace_model.py @@ -3,23 +3,24 @@ import torch import torch.cuda.amp as amp import torch.nn as nn from diffusers.configuration_utils import register_to_config -from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d + +from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d class VaceWanAttentionBlock(WanAttentionBlock): - def __init__( - self, - cross_attn_type, - dim, - ffn_dim, - num_heads, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=False, - eps=1e-6, - block_id=0 - ): - super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) + + def __init__(self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=0): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, + qk_norm, cross_attn_norm, eps) self.block_id = block_id if block_id == 0: self.before_proj = nn.Linear(self.dim, self.dim) @@ -39,19 +40,19 @@ class VaceWanAttentionBlock(WanAttentionBlock): class BaseWanAttentionBlock(WanAttentionBlock): - def __init__( - self, - cross_attn_type, - dim, - ffn_dim, - num_heads, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=False, - eps=1e-6, - block_id=None - ): - super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) + + def __init__(self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=None): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, + qk_norm, cross_attn_norm, eps) self.block_id = block_id def forward(self, x, hints, context_scale=1.0, **kwargs): @@ -62,6 +63,7 @@ class BaseWanAttentionBlock(WanAttentionBlock): class VaceWanModel(WanModel): + @register_to_config def __init__(self, vace_layers=None, @@ -81,42 +83,57 @@ class VaceWanModel(WanModel): qk_norm=True, cross_attn_norm=True, eps=1e-6): - super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim, - num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps) + super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, + freq_dim, text_dim, out_dim, num_heads, num_layers, + window_size, qk_norm, cross_attn_norm, eps) - self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers + self.vace_layers = [i for i in range(0, self.num_layers, 2) + ] if vace_layers is None else vace_layers self.vace_in_dim = 
self.in_dim if vace_in_dim is None else vace_in_dim assert 0 in self.vace_layers - self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)} + self.vace_layers_mapping = { + i: n for n, i in enumerate(self.vace_layers) + } # blocks self.blocks = nn.ModuleList([ - BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, - self.cross_attn_norm, self.eps, - block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None) + BaseWanAttentionBlock( + 't2v_cross_attn', + self.dim, + self.ffn_dim, + self.num_heads, + self.window_size, + self.qk_norm, + self.cross_attn_norm, + self.eps, + block_id=self.vace_layers_mapping[i] + if i in self.vace_layers else None) for i in range(self.num_layers) ]) # vace blocks self.vace_blocks = nn.ModuleList([ - VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, - self.cross_attn_norm, self.eps, block_id=i) - for i in self.vace_layers + VaceWanAttentionBlock( + 't2v_cross_attn', + self.dim, + self.ffn_dim, + self.num_heads, + self.window_size, + self.qk_norm, + self.cross_attn_norm, + self.eps, + block_id=i) for i in self.vace_layers ]) # vace patch embeddings self.vace_patch_embedding = nn.Conv3d( - self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size - ) + self.vace_in_dim, + self.dim, + kernel_size=self.patch_size, + stride=self.patch_size) - def forward_vace( - self, - x, - vace_context, - seq_len, - kwargs - ): + def forward_vace(self, x, vace_context, seq_len, kwargs): # embeddings c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] c = [u.flatten(2).transpose(1, 2) for u in c] @@ -230,4 +247,4 @@ class VaceWanModel(WanModel): # unpatchify x = self.unpatchify(x, grid_sizes) - return [u.float() for u in x] \ No newline at end of file + return [u.float() for u in x] diff --git a/wan/text2video.py b/wan/text2video.py index 2400545..c518b61 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -18,8 +18,11 @@ from .distributed.fsdp import shard_model from .modules.model import WanModel from .modules.t5 import T5EncoderModel from .modules.vae import WanVAE -from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, retrieve_timesteps) +from .utils.fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler @@ -85,11 +88,12 @@ class WanT2V: self.model.eval().requires_grad_(False) if use_usp: - from xfuser.core.distributed import \ - get_sequence_parallel_world_size + from xfuser.core.distributed import get_sequence_parallel_world_size - from .distributed.xdit_context_parallel import (usp_attn_forward, - usp_dit_forward) + from .distributed.xdit_context_parallel import ( + usp_attn_forward, + usp_dit_forward, + ) for block in self.model.blocks: block.self_attn.forward = types.MethodType( usp_attn_forward, block.self_attn) diff --git a/wan/utils/__init__.py b/wan/utils/__init__.py index ba3fe7d..2e9b33d 100644 --- a/wan/utils/__init__.py +++ b/wan/utils/__init__.py @@ -1,5 +1,8 @@ -from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, - retrieve_timesteps) +from .fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) from .fm_solvers_unipc import FlowUniPCMultistepScheduler from .vace_processor import VaceVideoProcessor diff --git a/wan/utils/fm_solvers.py 
b/wan/utils/fm_solvers.py index c908969..17bef85 100644 --- a/wan/utils/fm_solvers.py +++ b/wan/utils/fm_solvers.py @@ -9,9 +9,11 @@ from typing import List, Optional, Tuple, Union import numpy as np import torch from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, - SchedulerMixin, - SchedulerOutput) +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) from diffusers.utils import deprecate, is_scipy_available from diffusers.utils.torch_utils import randn_tensor diff --git a/wan/utils/fm_solvers_unipc.py b/wan/utils/fm_solvers_unipc.py index 57321ba..fb502f2 100644 --- a/wan/utils/fm_solvers_unipc.py +++ b/wan/utils/fm_solvers_unipc.py @@ -8,9 +8,11 @@ from typing import List, Optional, Tuple, Union import numpy as np import torch from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, - SchedulerMixin, - SchedulerOutput) +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) from diffusers.utils import deprecate, is_scipy_available if is_scipy_available(): diff --git a/wan/utils/prompt_extend.py b/wan/utils/prompt_extend.py index 5e3a216..3eda6d5 100644 --- a/wan/utils/prompt_extend.py +++ b/wan/utils/prompt_extend.py @@ -7,7 +7,7 @@ import sys import tempfile from dataclasses import dataclass from http import HTTPStatus -from typing import Optional, Union, List +from typing import List, Optional, Union import dashscope import torch @@ -96,7 +96,6 @@ VL_EN_SYS_PROMPT = \ '''4. In the style of an American drama promotional poster, Walter White sits in a metal folding chair wearing a yellow protective suit, with the words "Breaking Bad" written in sans-serif English above him, surrounded by piles of dollar bills and blue plastic storage boxes. He wears glasses, staring forward, dressed in a yellow jumpsuit, with his hands resting on his knees, exuding a calm and confident demeanor. The background shows an abandoned, dim factory with light filtering through the windows. There’s a noticeable grainy texture. A medium shot with a straight-on close-up of the character.\n''' \ '''Directly output the rewritten English text.''' - VL_ZH_SYS_PROMPT_FOR_MULTI_IMAGES = """你是一位Prompt优化师,旨在参考用户输入的图像的细节内容,把用户输入的Prompt改写为优质Prompt,使其更完整、更具表现力,同时不改变原意。你需要综合用户输入的照片内容和输入的Prompt进行改写,严格参考示例的格式进行改写 任务要求: 1. 
用户会输入两张图片,第一张是视频的第一帧,第二张时视频的最后一帧,你需要综合两个照片的内容进行优化改写 @@ -198,8 +197,8 @@ class PromptExpander: if system_prompt is None: system_prompt = self.decide_system_prompt( tar_lang=tar_lang, - multi_images_input=isinstance(image, (list, tuple)) and len(image) > 1 - ) + multi_images_input=isinstance(image, (list, tuple)) and + len(image) > 1) if seed < 0: seed = random.randint(0, sys.maxsize) if image is not None and self.is_vl: @@ -289,7 +288,8 @@ class DashScopePromptExpander(PromptExpander): def extend_with_img(self, prompt, system_prompt, - image: Union[List[Image.Image], List[str], Image.Image, str] = None, + image: Union[List[Image.Image], List[str], Image.Image, + str] = None, seed=-1, *args, **kwargs): @@ -308,13 +308,15 @@ class DashScopePromptExpander(PromptExpander): _image.save(f.name) image_path = f"file://{f.name}" return image_path + if not isinstance(image, (list, tuple)): image = [image] image_path_list = [ensure_image(_image) for _image in image] - role_content = [ - {"text": prompt}, - *[{"image": image_path} for image_path in image_path_list] - ] + role_content = [{ + "text": prompt + }, *[{ + "image": image_path + } for image_path in image_path_list]] system_content = [{"text": system_prompt}] prompt = f"{prompt}" messages = [ @@ -393,8 +395,11 @@ class QwenPromptExpander(PromptExpander): if self.is_vl: # default: Load the model on the available device(s) - from transformers import (AutoProcessor, AutoTokenizer, - Qwen2_5_VLForConditionalGeneration) + from transformers import ( + AutoProcessor, + AutoTokenizer, + Qwen2_5_VLForConditionalGeneration, + ) try: from .qwen_vl_utils import process_vision_info except: @@ -459,7 +464,8 @@ class QwenPromptExpander(PromptExpander): def extend_with_img(self, prompt, system_prompt, - image: Union[List[Image.Image], List[str], Image.Image, str] = None, + image: Union[List[Image.Image], List[str], Image.Image, + str] = None, seed=-1, *args, **kwargs): @@ -468,26 +474,19 @@ class QwenPromptExpander(PromptExpander): if not isinstance(image, (list, tuple)): image = [image] - system_content = [{ + system_content = [{"type": "text", "text": system_prompt}] + role_content = [{ "type": "text", - "text": system_prompt - }] - role_content = [ - { - "type": "text", - "text": prompt - }, - *[ - {"image": image_path} for image_path in image - ] - ] + "text": prompt + }, *[{ + "image": image_path + } for image_path in image]] messages = [{ 'role': 'system', 'content': system_content, }, { - "role": - "user", + "role": "user", "content": role_content, }] @@ -611,25 +610,38 @@ if __name__ == "__main__": print("VL qwen vl en result -> en", qwen_result.prompt) # , qwen_result.system_prompt) # test multi images - image = ["./examples/flf2v_input_first_frame.png", "./examples/flf2v_input_last_frame.png"] + image = [ + "./examples/flf2v_input_first_frame.png", + "./examples/flf2v_input_last_frame.png" + ] prompt = "无人机拍摄,镜头快速推进,然后拉远至全景俯瞰,展示一个宁静美丽的海港。海港内停满了游艇,水面清澈透蓝。周围是起伏的山丘和错落有致的建筑,整体景色宁静而美丽。" - en_prompt = ("Shot from a drone perspective, the camera rapidly zooms in before pulling back to reveal a panoramic " - "aerial view of a serene and picturesque harbor. The tranquil bay is dotted with numerous yachts " - "resting on crystal-clear blue waters. 
Surrounding the harbor are rolling hills and well-spaced " - "architectural structures, combining to create a tranquil and breathtaking coastal landscape.") + en_prompt = ( + "Shot from a drone perspective, the camera rapidly zooms in before pulling back to reveal a panoramic " + "aerial view of a serene and picturesque harbor. The tranquil bay is dotted with numerous yachts " + "resting on crystal-clear blue waters. Surrounding the harbor are rolling hills and well-spaced " + "architectural structures, combining to create a tranquil and breathtaking coastal landscape." + ) - dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True) - dashscope_result = dashscope_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed) + dashscope_prompt_expander = DashScopePromptExpander( + model_name=ds_model_name, is_vl=True) + dashscope_result = dashscope_prompt_expander( + prompt, tar_lang="zh", image=image, seed=seed) print("VL dashscope result -> zh", dashscope_result.prompt) - dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True) - dashscope_result = dashscope_prompt_expander(en_prompt, tar_lang="zh", image=image, seed=seed) + dashscope_prompt_expander = DashScopePromptExpander( + model_name=ds_model_name, is_vl=True) + dashscope_result = dashscope_prompt_expander( + en_prompt, tar_lang="zh", image=image, seed=seed) print("VL dashscope en result -> zh", dashscope_result.prompt) - qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0) - qwen_result = qwen_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed) + qwen_prompt_expander = QwenPromptExpander( + model_name=qwen_model_name, is_vl=True, device=0) + qwen_result = qwen_prompt_expander( + prompt, tar_lang="zh", image=image, seed=seed) print("VL qwen result -> zh", qwen_result.prompt) - qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0) - qwen_result = qwen_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed) + qwen_prompt_expander = QwenPromptExpander( + model_name=qwen_model_name, is_vl=True, device=0) + qwen_result = qwen_prompt_expander( + prompt, tar_lang="zh", image=image, seed=seed) print("VL qwen en result -> zh", qwen_result.prompt) diff --git a/wan/utils/vace_processor.py b/wan/utils/vace_processor.py index 5f7224f..5f47fd6 100644 --- a/wan/utils/vace_processor.py +++ b/wan/utils/vace_processor.py @@ -1,12 +1,13 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
import numpy as np -from PIL import Image import torch import torch.nn.functional as F import torchvision.transforms.functional as TF +from PIL import Image class VaceImageProcessor(object): + def __init__(self, downsample=None, seq_len=None): self.downsample = downsample self.seq_len = seq_len @@ -16,9 +17,10 @@ class VaceImageProcessor(object): if image.mode == 'P': image = image.convert(f'{cvt_type}A') if image.mode == f'{cvt_type}A': - bg = Image.new(cvt_type, - size=(image.width, image.height), - color=(255, 255, 255)) + bg = Image.new( + cvt_type, + size=(image.width, image.height), + color=(255, 255, 255)) bg.paste(image, (0, 0), mask=image) image = bg else: @@ -41,10 +43,8 @@ class VaceImageProcessor(object): if iw != ow or ih != oh: # resize scale = max(ow / iw, oh / ih) - img = img.resize( - (round(scale * iw), round(scale * ih)), - resample=Image.Resampling.LANCZOS - ) + img = img.resize((round(scale * iw), round(scale * ih)), + resample=Image.Resampling.LANCZOS) assert img.width >= ow and img.height >= oh # center crop @@ -66,7 +66,11 @@ class VaceImageProcessor(object): def load_image_pair(self, data_key, data_key2, **kwargs): return self.load_image_batch(data_key, data_key2, **kwargs) - def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs): + def load_image_batch(self, + *data_key_batch, + normalize=True, + seq_len=None, + **kwargs): seq_len = self.seq_len if seq_len is None else seq_len imgs = [] for data_key in data_key_batch: @@ -85,7 +89,9 @@ class VaceImageProcessor(object): class VaceVideoProcessor(object): - def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs): + + def __init__(self, downsample, min_area, max_area, min_fps, max_fps, + zero_start, seq_len, keep_last, **kwargs): self.downsample = downsample self.min_area = min_area self.max_area = max_area @@ -130,8 +136,7 @@ class VaceVideoProcessor(object): video, size=(round(scale * ih), round(scale * iw)), mode='bicubic', - antialias=True - ) + antialias=True) assert video.size(3) >= ow and video.size(2) >= oh # center crop @@ -146,7 +151,8 @@ class VaceVideoProcessor(object): def _video_preprocess(self, video, oh, ow): return self.resize_crop(video, oh, ow) - def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng): + def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, + rng): target_fps = min(fps, self.max_fps) duration = frame_timestamps[-1].mean() x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box @@ -154,11 +160,10 @@ class VaceVideoProcessor(object): ratio = h / w df, dh, dw = self.downsample - area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw)) - of = min( - (int(duration * target_fps) - 1) // df + 1, - int(self.seq_len / area_z) - ) + area_z = min(self.seq_len, self.max_area / (dh * dw), + (h // dh) * (w // dw)) + of = min((int(duration * target_fps) - 1) // df + 1, + int(self.seq_len / area_z)) # deduce target shape of the [latent video] target_area_z = min(area_z, int(self.seq_len / of)) @@ -170,26 +175,27 @@ class VaceVideoProcessor(object): # sample frame ids target_duration = of / target_fps - begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration) + begin = 0. 
if self.zero_start else rng.uniform( + 0, duration - target_duration) timestamps = np.linspace(begin, begin + target_duration, of) - frame_ids = np.argmax(np.logical_and( - timestamps[:, None] >= frame_timestamps[None, :, 0], - timestamps[:, None] < frame_timestamps[None, :, 1] - ), axis=1).tolist() + frame_ids = np.argmax( + np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0], + timestamps[:, None] < frame_timestamps[None, :, 1]), + axis=1).tolist() return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps - def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng): + def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, + crop_box, rng): duration = frame_timestamps[-1].mean() x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box h, w = y2 - y1, x2 - x1 ratio = h / w df, dh, dw = self.downsample - area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw)) - of = min( - (len(frame_timestamps) - 1) // df + 1, - int(self.seq_len / area_z) - ) + area_z = min(self.seq_len, self.max_area / (dh * dw), + (h // dh) * (w // dw)) + of = min((len(frame_timestamps) - 1) // df + 1, + int(self.seq_len / area_z)) # deduce target shape of the [latent video] target_area_z = min(area_z, int(self.seq_len / of)) @@ -203,27 +209,39 @@ class VaceVideoProcessor(object): target_duration = duration target_fps = of / target_duration timestamps = np.linspace(0., target_duration, of) - frame_ids = np.argmax(np.logical_and( - timestamps[:, None] >= frame_timestamps[None, :, 0], - timestamps[:, None] <= frame_timestamps[None, :, 1] - ), axis=1).tolist() + frame_ids = np.argmax( + np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0], + timestamps[:, None] <= frame_timestamps[None, :, 1]), + axis=1).tolist() # print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids)) return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps - def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng): if self.keep_last: - return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng) + return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, + w, crop_box, rng) else: - return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng) + return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, + crop_box, rng) def load_video(self, data_key, crop_box=None, seed=2024, **kwargs): - return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs) + return self.load_video_batch( + data_key, crop_box=crop_box, seed=seed, **kwargs) - def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs): - return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs) + def load_video_pair(self, + data_key, + data_key2, + crop_box=None, + seed=2024, + **kwargs): + return self.load_video_batch( + data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs) - def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs): + def load_video_batch(self, + *data_key_batch, + crop_box=None, + seed=2024, + **kwargs): rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000) # read video import decord @@ -235,36 +253,53 @@ class VaceVideoProcessor(object): fps = readers[0].get_avg_fps() length = min([len(r) for r in readers]) - frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)] + frame_timestamps = [ + readers[0].get_frame_timestamp(i) for i in 
range(length) + ] frame_timestamps = np.array(frame_timestamps, dtype=np.float32) h, w = readers[0].next().shape[:2] - frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng) + frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox( + fps, frame_timestamps, h, w, crop_box, rng) # preprocess video - videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers] + videos = [ + reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] + for reader in readers + ] videos = [self._video_preprocess(video, oh, ow) for video in videos] return *videos, frame_ids, (oh, ow), fps # return videos if len(videos) > 1 else videos[0] -def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device): +def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, + device): for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): if sub_src_video is None and sub_src_mask is None: - src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) - src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device) + src_video[i] = torch.zeros( + (3, num_frames, image_size[0], image_size[1]), device=device) + src_mask[i] = torch.ones( + (1, num_frames, image_size[0], image_size[1]), device=device) for i, ref_images in enumerate(src_ref_images): if ref_images is not None: for j, ref_img in enumerate(ref_images): if ref_img is not None and ref_img.shape[-2:] != image_size: canvas_height, canvas_width = image_size ref_height, ref_width = ref_img.shape[-2:] - white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] - scale = min(canvas_height / ref_height, canvas_width / ref_width) + white_canvas = torch.ones( + (3, 1, canvas_height, canvas_width), + device=device) # [-1, 1] + scale = min(canvas_height / ref_height, + canvas_width / ref_width) new_height = int(ref_height * scale) new_width = int(ref_width * scale) - resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) + resized_image = F.interpolate( + ref_img.squeeze(1).unsqueeze(0), + size=(new_height, new_width), + mode='bilinear', + align_corners=False).squeeze(0).unsqueeze(1) top = (canvas_height - new_height) // 2 left = (canvas_width - new_width) // 2 - white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image + white_canvas[:, :, top:top + new_height, + left:left + new_width] = resized_image src_ref_images[i][j] = white_canvas return src_video, src_mask, src_ref_images diff --git a/wan/vace.py b/wan/vace.py index d792e9b..8a4f744 100644 --- a/wan/vace.py +++ b/wan/vace.py @@ -1,32 +1,41 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
-import os -import sys import gc -import math -import time -import random -import types import logging +import math +import os +import random +import sys +import time import traceback +import types from contextlib import contextmanager from functools import partial -from PIL import Image -import torchvision.transforms.functional as TF import torch -import torch.nn.functional as F import torch.cuda.amp as amp import torch.distributed as dist import torch.multiprocessing as mp +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from PIL import Image from tqdm import tqdm -from .text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler) from .modules.vace_model import VaceWanModel +from .text2video import ( + FlowDPMSolverMultistepScheduler, + FlowUniPCMultistepScheduler, + T5EncoderModel, + WanT2V, + WanVAE, + get_sampling_sigmas, + retrieve_timesteps, + shard_model, +) from .utils.vace_processor import VaceVideoProcessor class WanVace(WanT2V): + def __init__( self, config, @@ -87,12 +96,13 @@ class WanVace(WanT2V): self.model.eval().requires_grad_(False) if use_usp: - from xfuser.core.distributed import \ - get_sequence_parallel_world_size + from xfuser.core.distributed import get_sequence_parallel_world_size - from .distributed.xdit_context_parallel import (usp_attn_forward, - usp_dit_forward, - usp_dit_forward_vace) + from .distributed.xdit_context_parallel import ( + usp_attn_forward, + usp_dit_forward, + usp_dit_forward_vace, + ) for block in self.model.blocks: block.self_attn.forward = types.MethodType( usp_attn_forward, block.self_attn) @@ -100,7 +110,8 @@ class WanVace(WanT2V): block.self_attn.forward = types.MethodType( usp_attn_forward, block.self_attn) self.model.forward = types.MethodType(usp_dit_forward, self.model) - self.model.forward_vace = types.MethodType(usp_dit_forward_vace, self.model) + self.model.forward_vace = types.MethodType(usp_dit_forward_vace, + self.model) self.sp_size = get_sequence_parallel_world_size() else: self.sp_size = 1 @@ -114,14 +125,16 @@ class WanVace(WanT2V): self.sample_neg_prompt = config.sample_neg_prompt - self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]), - min_area=720*1280, - max_area=720*1280, - min_fps=config.sample_fps, - max_fps=config.sample_fps, - zero_start=True, - seq_len=75600, - keep_last=True) + self.vid_proc = VaceVideoProcessor( + downsample=tuple( + [x * y for x, y in zip(config.vae_stride, self.patch_size)]), + min_area=720 * 1280, + max_area=720 * 1280, + min_fps=config.sample_fps, + max_fps=config.sample_fps, + zero_start=True, + seq_len=75600, + keep_last=True) def vace_encode_frames(self, frames, ref_images, masks=None, vae=None): vae = self.vae if vae is None else vae @@ -138,7 +151,9 @@ class WanVace(WanT2V): reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] inactive = vae.encode(inactive) reactive = vae.encode(reactive) - latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + latents = [ + torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive) + ] cat_latents = [] for latent, refs in zip(latents, ref_images): @@ -147,7 +162,10 @@ class WanVace(WanT2V): ref_latent = vae.encode(refs) else: ref_latent = vae.encode(refs) - ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + ref_latent = [ + torch.cat((u, torch.zeros_like(u)), dim=0) + for u in ref_latent + ] 
assert all([x.shape[1] == 1 for x in ref_latent]) latent = torch.cat([*ref_latent, latent], dim=1) cat_latents.append(latent) @@ -169,16 +187,17 @@ class WanVace(WanT2V): # reshape mask = mask[0, :, :, :] - mask = mask.view( - depth, height, vae_stride[1], width, vae_stride[1] - ) # depth, height, 8, width, 8 + mask = mask.view(depth, height, vae_stride[1], width, + vae_stride[1]) # depth, height, 8, width, 8 mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width - mask = mask.reshape( - vae_stride[1] * vae_stride[2], depth, height, width - ) # 8*8, depth, height, width + mask = mask.reshape(vae_stride[1] * vae_stride[2], depth, height, + width) # 8*8, depth, height, width # interpolation - mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) + mask = F.interpolate( + mask.unsqueeze(0), + size=(new_depth, height, width), + mode='nearest-exact').squeeze(0) if refs is not None: length = len(refs) @@ -190,27 +209,35 @@ class WanVace(WanT2V): def vace_latent(self, z, m): return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] - def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device): + def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, + image_size, device): area = image_size[0] * image_size[1] self.vid_proc.set_area(area) - if area == 720*1280: + if area == 720 * 1280: self.vid_proc.set_seq_len(75600) - elif area == 480*832: + elif area == 480 * 832: self.vid_proc.set_seq_len(32760) else: - raise NotImplementedError(f'image_size {image_size} is not supported') + raise NotImplementedError( + f'image_size {image_size} is not supported') image_size = (image_size[1], image_size[0]) image_sizes = [] - for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): + for i, (sub_src_video, + sub_src_mask) in enumerate(zip(src_video, src_mask)): if sub_src_mask is not None and sub_src_video is not None: - src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask) + src_video[i], src_mask[ + i], _, _, _ = self.vid_proc.load_video_pair( + sub_src_video, sub_src_mask) src_video[i] = src_video[i].to(device) src_mask[i] = src_mask[i].to(device) - src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1) + src_mask[i] = torch.clamp( + (src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1) image_sizes.append(src_video[i].shape[2:]) elif sub_src_video is None: - src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) + src_video[i] = torch.zeros( + (3, num_frames, image_size[0], image_size[1]), + device=device) src_mask[i] = torch.ones_like(src_video[i], device=device) image_sizes.append(image_size) else: @@ -225,18 +252,27 @@ class WanVace(WanT2V): for j, ref_img in enumerate(ref_images): if ref_img is not None: ref_img = Image.open(ref_img).convert("RGB") - ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_( + 0.5).unsqueeze(1) if ref_img.shape[-2:] != image_size: canvas_height, canvas_width = image_size ref_height, ref_width = ref_img.shape[-2:] - white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] - scale = min(canvas_height / ref_height, canvas_width / ref_width) + white_canvas = torch.ones( + (3, 1, canvas_height, canvas_width), + device=device) # [-1, 1] + scale = min(canvas_height / ref_height, + canvas_width / ref_width) new_height = int(ref_height * scale) new_width = int(ref_width 
* scale) - resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) + resized_image = F.interpolate( + ref_img.squeeze(1).unsqueeze(0), + size=(new_height, new_width), + mode='bilinear', + align_corners=False).squeeze(0).unsqueeze(1) top = (canvas_height - new_height) // 2 left = (canvas_width - new_width) // 2 - white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image + white_canvas[:, :, top:top + new_height, + left:left + new_width] = resized_image ref_img = white_canvas src_ref_images[i][j] = ref_img.to(device) return src_video, src_mask, src_ref_images @@ -256,8 +292,6 @@ class WanVace(WanT2V): return vae.decode(trimed_zs) - - def generate(self, input_prompt, input_frames, @@ -335,7 +369,8 @@ class WanVace(WanT2V): context_null = [t.to(self.device) for t in context_null] # vace context encode - z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks) + z0 = self.vace_encode_frames( + input_frames, input_ref_images, masks=input_masks) m0 = self.vace_encode_masks(input_masks, input_ref_images) z = self.vace_latent(z0, m0) @@ -399,9 +434,17 @@ class WanVace(WanT2V): self.model.to(self.device) noise_pred_cond = self.model( - latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[0] + latent_model_input, + t=timestep, + vace_context=z, + vace_context_scale=context_scale, + **arg_c)[0] noise_pred_uncond = self.model( - latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,**arg_null)[0] + latent_model_input, + t=timestep, + vace_context=z, + vace_context_scale=context_scale, + **arg_null)[0] noise_pred = noise_pred_uncond + guide_scale * ( noise_pred_cond - noise_pred_uncond) @@ -433,14 +476,13 @@ class WanVace(WanT2V): class WanVaceMP(WanVace): - def __init__( - self, - config, - checkpoint_dir, - use_usp=False, - ulysses_size=None, - ring_size=None - ): + + def __init__(self, + config, + checkpoint_dir, + use_usp=False, + ulysses_size=None, + ring_size=None): self.config = config self.checkpoint_dir = checkpoint_dir self.use_usp = use_usp @@ -457,7 +499,8 @@ class WanVaceMP(WanVace): self.device = 'cpu' if torch.cuda.is_available() else 'cpu' self.vid_proc = VaceVideoProcessor( - downsample=tuple([x * y for x, y in zip(config.vae_stride, config.patch_size)]), + downsample=tuple( + [x * y for x, y in zip(config.vae_stride, config.patch_size)]), min_area=480 * 832, max_area=480 * 832, min_fps=self.config.sample_fps, @@ -466,20 +509,30 @@ class WanVaceMP(WanVace): seq_len=32760, keep_last=True) - def dynamic_load(self): if hasattr(self, 'inference_pids') and self.inference_pids is not None: return - gpu_infer = os.environ.get('LOCAL_WORLD_SIZE') or torch.cuda.device_count() + gpu_infer = os.environ.get( + 'LOCAL_WORLD_SIZE') or torch.cuda.device_count() pmi_rank = int(os.environ['RANK']) pmi_world_size = int(os.environ['WORLD_SIZE']) - in_q_list = [torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)] + in_q_list = [ + torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer) + ] out_q = torch.multiprocessing.Manager().Queue() - initialized_events = [torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)] - context = mp.spawn(self.mp_worker, nprocs=gpu_infer, args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, self), join=False) + initialized_events = [ + torch.multiprocessing.Manager().Event() for _ in range(gpu_infer) + ] + 
context = mp.spawn( + self.mp_worker, + nprocs=gpu_infer, + args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, + initialized_events, self), + join=False) all_initialized = False while not all_initialized: - all_initialized = all(event.is_set() for event in initialized_events) + all_initialized = all( + event.is_set() for event in initialized_events) if not all_initialized: time.sleep(0.1) print('Inference model is initialized', flush=True) @@ -495,12 +548,19 @@ class WanVaceMP(WanVace): if isinstance(data, torch.Tensor): data = data.to(device) elif isinstance(data, list): - data = [self.transfer_data_to_cuda(subdata, device) for subdata in data] + data = [ + self.transfer_data_to_cuda(subdata, device) + for subdata in data + ] elif isinstance(data, dict): - data = {key: self.transfer_data_to_cuda(val, device) for key, val in data.items()} + data = { + key: self.transfer_data_to_cuda(val, device) + for key, val in data.items() + } return data - def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, work_env): + def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list, + out_q, initialized_events, work_env): try: world_size = pmi_world_size * gpu_infer rank = pmi_rank * gpu_infer + gpu @@ -511,19 +571,19 @@ class WanVaceMP(WanVace): backend='nccl', init_method='env://', rank=rank, - world_size=world_size - ) + world_size=world_size) - from xfuser.core.distributed import (initialize_model_parallel, - init_distributed_environment) + from xfuser.core.distributed import ( + init_distributed_environment, + initialize_model_parallel, + ) init_distributed_environment( rank=dist.get_rank(), world_size=dist.get_world_size()) initialize_model_parallel( sequence_parallel_degree=dist.get_world_size(), ring_degree=self.ring_size or 1, - ulysses_degree=self.ulysses_size or 1 - ) + ulysses_degree=self.ulysses_size or 1) num_train_timesteps = self.config.num_train_timesteps param_dtype = self.config.param_dtype @@ -532,14 +592,17 @@ class WanVaceMP(WanVace): text_len=self.config.text_len, dtype=self.config.t5_dtype, device=torch.device('cpu'), - checkpoint_path=os.path.join(self.checkpoint_dir, self.config.t5_checkpoint), - tokenizer_path=os.path.join(self.checkpoint_dir, self.config.t5_tokenizer), + checkpoint_path=os.path.join(self.checkpoint_dir, + self.config.t5_checkpoint), + tokenizer_path=os.path.join(self.checkpoint_dir, + self.config.t5_tokenizer), shard_fn=shard_fn if True else None) text_encoder.model.to(gpu) vae_stride = self.config.vae_stride patch_size = self.config.patch_size vae = WanVAE( - vae_pth=os.path.join(self.checkpoint_dir, self.config.vae_checkpoint), + vae_pth=os.path.join(self.checkpoint_dir, + self.config.vae_checkpoint), device=gpu) logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}") model = VaceWanModel.from_pretrained(self.checkpoint_dir) @@ -547,9 +610,12 @@ class WanVaceMP(WanVace): if self.use_usp: from xfuser.core.distributed import get_sequence_parallel_world_size - from .distributed.xdit_context_parallel import (usp_attn_forward, - usp_dit_forward, - usp_dit_forward_vace) + + from .distributed.xdit_context_parallel import ( + usp_attn_forward, + usp_dit_forward, + usp_dit_forward_vace, + ) for block in model.blocks: block.self_attn.forward = types.MethodType( usp_attn_forward, block.self_attn) @@ -557,7 +623,8 @@ class WanVaceMP(WanVace): block.self_attn.forward = types.MethodType( usp_attn_forward, block.self_attn) model.forward = types.MethodType(usp_dit_forward, model) - 
model.forward_vace = types.MethodType(usp_dit_forward_vace, model) + model.forward_vace = types.MethodType(usp_dit_forward_vace, + model) sp_size = get_sequence_parallel_world_size() else: sp_size = 1 @@ -577,7 +644,8 @@ class WanVaceMP(WanVace): shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model = item input_frames = self.transfer_data_to_cuda(input_frames, gpu) input_masks = self.transfer_data_to_cuda(input_masks, gpu) - input_ref_images = self.transfer_data_to_cuda(input_ref_images, gpu) + input_ref_images = self.transfer_data_to_cuda( + input_ref_images, gpu) if n_prompt == "": n_prompt = sample_neg_prompt @@ -589,8 +657,10 @@ class WanVaceMP(WanVace): context_null = text_encoder([n_prompt], gpu) # vace context encode - z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, vae=vae) - m0 = self.vace_encode_masks(input_masks, input_ref_images, vae_stride=vae_stride) + z0 = self.vace_encode_frames( + input_frames, input_ref_images, masks=input_masks, vae=vae) + m0 = self.vace_encode_masks( + input_masks, input_ref_images, vae_stride=vae_stride) z = self.vace_latent(z0, m0) target_shape = list(z0[0].shape) @@ -616,7 +686,8 @@ class WanVaceMP(WanVace): no_sync = getattr(model, 'no_sync', noop_no_sync) # evaluation mode - with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync(): + with amp.autocast( + dtype=param_dtype), torch.no_grad(), no_sync(): if sample_solver == 'unipc': sample_scheduler = FlowUniPCMultistepScheduler( @@ -631,7 +702,8 @@ class WanVaceMP(WanVace): num_train_timesteps=num_train_timesteps, shift=1, use_dynamic_shifting=False) - sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + sampling_sigmas = get_sampling_sigmas( + sampling_steps, shift) timesteps, _ = retrieve_timesteps( sample_scheduler, device=gpu, @@ -653,14 +725,20 @@ class WanVaceMP(WanVace): model.to(gpu) noise_pred_cond = model( - latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[ - 0] + latent_model_input, + t=timestep, + vace_context=z, + vace_context_scale=context_scale, + **arg_c)[0] noise_pred_uncond = model( - latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, + latent_model_input, + t=timestep, + vace_context=z, + vace_context_scale=context_scale, **arg_null)[0] noise_pred = noise_pred_uncond + guide_scale * ( - noise_pred_cond - noise_pred_uncond) + noise_pred_cond - noise_pred_uncond) temp_x0 = sample_scheduler.step( noise_pred.unsqueeze(0), @@ -673,7 +751,8 @@ class WanVaceMP(WanVace): torch.cuda.empty_cache() x0 = latents if rank == 0: - videos = self.decode_latent(x0, input_ref_images, vae=vae) + videos = self.decode_latent( + x0, input_ref_images, vae=vae) del noise, latents del sample_scheduler @@ -691,8 +770,6 @@ class WanVaceMP(WanVace): print(trace_info, flush=True) print(e, flush=True) - - def generate(self, input_prompt, input_frames, @@ -709,8 +786,10 @@ class WanVaceMP(WanVace): seed=-1, offload_model=True): - input_data = (input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale, - shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model) + input_data = (input_prompt, input_frames, input_masks, input_ref_images, + size, frame_num, context_scale, shift, sample_solver, + sampling_steps, guide_scale, n_prompt, seed, + offload_model) for in_q in self.in_q_list: in_q.put(input_data) value_output = self.out_q.get()