mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-17 21:07:41 +00:00
Compare commits
4 Commits
d3a0e7b077
...
035581ad65
Author | SHA1 | Date | |
---|---|---|---|
|
035581ad65 | ||
|
e5a741309d | ||
|
76e9427657 | ||
|
76bceb2fe5 |
393
.style.yapf
Normal file
393
.style.yapf
Normal file
@ -0,0 +1,393 @@
|
|||||||
|
[style]
|
||||||
|
# Align closing bracket with visual indentation.
|
||||||
|
align_closing_bracket_with_visual_indent=False
|
||||||
|
|
||||||
|
# Allow dictionary keys to exist on multiple lines. For example:
|
||||||
|
#
|
||||||
|
# x = {
|
||||||
|
# ('this is the first element of a tuple',
|
||||||
|
# 'this is the second element of a tuple'):
|
||||||
|
# value,
|
||||||
|
# }
|
||||||
|
allow_multiline_dictionary_keys=False
|
||||||
|
|
||||||
|
# Allow lambdas to be formatted on more than one line.
|
||||||
|
allow_multiline_lambdas=False
|
||||||
|
|
||||||
|
# Allow splitting before a default / named assignment in an argument list.
|
||||||
|
allow_split_before_default_or_named_assigns=False
|
||||||
|
|
||||||
|
# Allow splits before the dictionary value.
|
||||||
|
allow_split_before_dict_value=True
|
||||||
|
|
||||||
|
# Let spacing indicate operator precedence. For example:
|
||||||
|
#
|
||||||
|
# a = 1 * 2 + 3 / 4
|
||||||
|
# b = 1 / 2 - 3 * 4
|
||||||
|
# c = (1 + 2) * (3 - 4)
|
||||||
|
# d = (1 - 2) / (3 + 4)
|
||||||
|
# e = 1 * 2 - 3
|
||||||
|
# f = 1 + 2 + 3 + 4
|
||||||
|
#
|
||||||
|
# will be formatted as follows to indicate precedence:
|
||||||
|
#
|
||||||
|
# a = 1*2 + 3/4
|
||||||
|
# b = 1/2 - 3*4
|
||||||
|
# c = (1+2) * (3-4)
|
||||||
|
# d = (1-2) / (3+4)
|
||||||
|
# e = 1*2 - 3
|
||||||
|
# f = 1 + 2 + 3 + 4
|
||||||
|
#
|
||||||
|
arithmetic_precedence_indication=False
|
||||||
|
|
||||||
|
# Number of blank lines surrounding top-level function and class
|
||||||
|
# definitions.
|
||||||
|
blank_lines_around_top_level_definition=2
|
||||||
|
|
||||||
|
# Insert a blank line before a class-level docstring.
|
||||||
|
blank_line_before_class_docstring=False
|
||||||
|
|
||||||
|
# Insert a blank line before a module docstring.
|
||||||
|
blank_line_before_module_docstring=False
|
||||||
|
|
||||||
|
# Insert a blank line before a 'def' or 'class' immediately nested
|
||||||
|
# within another 'def' or 'class'. For example:
|
||||||
|
#
|
||||||
|
# class Foo:
|
||||||
|
# # <------ this blank line
|
||||||
|
# def method():
|
||||||
|
# ...
|
||||||
|
blank_line_before_nested_class_or_def=True
|
||||||
|
|
||||||
|
# Do not split consecutive brackets. Only relevant when
|
||||||
|
# dedent_closing_brackets is set. For example:
|
||||||
|
#
|
||||||
|
# call_func_that_takes_a_dict(
|
||||||
|
# {
|
||||||
|
# 'key1': 'value1',
|
||||||
|
# 'key2': 'value2',
|
||||||
|
# }
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# would reformat to:
|
||||||
|
#
|
||||||
|
# call_func_that_takes_a_dict({
|
||||||
|
# 'key1': 'value1',
|
||||||
|
# 'key2': 'value2',
|
||||||
|
# })
|
||||||
|
coalesce_brackets=False
|
||||||
|
|
||||||
|
# The column limit.
|
||||||
|
column_limit=80
|
||||||
|
|
||||||
|
# The style for continuation alignment. Possible values are:
|
||||||
|
#
|
||||||
|
# - SPACE: Use spaces for continuation alignment. This is default behavior.
|
||||||
|
# - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns
|
||||||
|
# (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs or
|
||||||
|
# CONTINUATION_INDENT_WIDTH spaces) for continuation alignment.
|
||||||
|
# - VALIGN-RIGHT: Vertically align continuation lines to multiple of
|
||||||
|
# INDENT_WIDTH columns. Slightly right (one tab or a few spaces) if
|
||||||
|
# cannot vertically align continuation lines with indent characters.
|
||||||
|
continuation_align_style=SPACE
|
||||||
|
|
||||||
|
# Indent width used for line continuations.
|
||||||
|
continuation_indent_width=4
|
||||||
|
|
||||||
|
# Put closing brackets on a separate line, dedented, if the bracketed
|
||||||
|
# expression can't fit in a single line. Applies to all kinds of brackets,
|
||||||
|
# including function definitions and calls. For example:
|
||||||
|
#
|
||||||
|
# config = {
|
||||||
|
# 'key1': 'value1',
|
||||||
|
# 'key2': 'value2',
|
||||||
|
# } # <--- this bracket is dedented and on a separate line
|
||||||
|
#
|
||||||
|
# time_series = self.remote_client.query_entity_counters(
|
||||||
|
# entity='dev3246.region1',
|
||||||
|
# key='dns.query_latency_tcp',
|
||||||
|
# transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
|
||||||
|
# start_ts=now()-timedelta(days=3),
|
||||||
|
# end_ts=now(),
|
||||||
|
# ) # <--- this bracket is dedented and on a separate line
|
||||||
|
dedent_closing_brackets=False
|
||||||
|
|
||||||
|
# Disable the heuristic which places each list element on a separate line
|
||||||
|
# if the list is comma-terminated.
|
||||||
|
disable_ending_comma_heuristic=False
|
||||||
|
|
||||||
|
# Place each dictionary entry onto its own line.
|
||||||
|
each_dict_entry_on_separate_line=True
|
||||||
|
|
||||||
|
# Require multiline dictionary even if it would normally fit on one line.
|
||||||
|
# For example:
|
||||||
|
#
|
||||||
|
# config = {
|
||||||
|
# 'key1': 'value1'
|
||||||
|
# }
|
||||||
|
force_multiline_dict=False
|
||||||
|
|
||||||
|
# The regex for an i18n comment. The presence of this comment stops
|
||||||
|
# reformatting of that line, because the comments are required to be
|
||||||
|
# next to the string they translate.
|
||||||
|
i18n_comment=#\..*
|
||||||
|
|
||||||
|
# The i18n function call names. The presence of this function stops
|
||||||
|
# reformattting on that line, because the string it has cannot be moved
|
||||||
|
# away from the i18n comment.
|
||||||
|
i18n_function_call=N_, _
|
||||||
|
|
||||||
|
# Indent blank lines.
|
||||||
|
indent_blank_lines=False
|
||||||
|
|
||||||
|
# Put closing brackets on a separate line, indented, if the bracketed
|
||||||
|
# expression can't fit in a single line. Applies to all kinds of brackets,
|
||||||
|
# including function definitions and calls. For example:
|
||||||
|
#
|
||||||
|
# config = {
|
||||||
|
# 'key1': 'value1',
|
||||||
|
# 'key2': 'value2',
|
||||||
|
# } # <--- this bracket is indented and on a separate line
|
||||||
|
#
|
||||||
|
# time_series = self.remote_client.query_entity_counters(
|
||||||
|
# entity='dev3246.region1',
|
||||||
|
# key='dns.query_latency_tcp',
|
||||||
|
# transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
|
||||||
|
# start_ts=now()-timedelta(days=3),
|
||||||
|
# end_ts=now(),
|
||||||
|
# ) # <--- this bracket is indented and on a separate line
|
||||||
|
indent_closing_brackets=False
|
||||||
|
|
||||||
|
# Indent the dictionary value if it cannot fit on the same line as the
|
||||||
|
# dictionary key. For example:
|
||||||
|
#
|
||||||
|
# config = {
|
||||||
|
# 'key1':
|
||||||
|
# 'value1',
|
||||||
|
# 'key2': value1 +
|
||||||
|
# value2,
|
||||||
|
# }
|
||||||
|
indent_dictionary_value=True
|
||||||
|
|
||||||
|
# The number of columns to use for indentation.
|
||||||
|
indent_width=4
|
||||||
|
|
||||||
|
# Join short lines into one line. E.g., single line 'if' statements.
|
||||||
|
join_multiple_lines=False
|
||||||
|
|
||||||
|
# Do not include spaces around selected binary operators. For example:
|
||||||
|
#
|
||||||
|
# 1 + 2 * 3 - 4 / 5
|
||||||
|
#
|
||||||
|
# will be formatted as follows when configured with "*,/":
|
||||||
|
#
|
||||||
|
# 1 + 2*3 - 4/5
|
||||||
|
no_spaces_around_selected_binary_operators=
|
||||||
|
|
||||||
|
# Use spaces around default or named assigns.
|
||||||
|
spaces_around_default_or_named_assign=False
|
||||||
|
|
||||||
|
# Adds a space after the opening '{' and before the ending '}' dict delimiters.
|
||||||
|
#
|
||||||
|
# {1: 2}
|
||||||
|
#
|
||||||
|
# will be formatted as:
|
||||||
|
#
|
||||||
|
# { 1: 2 }
|
||||||
|
spaces_around_dict_delimiters=False
|
||||||
|
|
||||||
|
# Adds a space after the opening '[' and before the ending ']' list delimiters.
|
||||||
|
#
|
||||||
|
# [1, 2]
|
||||||
|
#
|
||||||
|
# will be formatted as:
|
||||||
|
#
|
||||||
|
# [ 1, 2 ]
|
||||||
|
spaces_around_list_delimiters=False
|
||||||
|
|
||||||
|
# Use spaces around the power operator.
|
||||||
|
spaces_around_power_operator=False
|
||||||
|
|
||||||
|
# Use spaces around the subscript / slice operator. For example:
|
||||||
|
#
|
||||||
|
# my_list[1 : 10 : 2]
|
||||||
|
spaces_around_subscript_colon=False
|
||||||
|
|
||||||
|
# Adds a space after the opening '(' and before the ending ')' tuple delimiters.
|
||||||
|
#
|
||||||
|
# (1, 2, 3)
|
||||||
|
#
|
||||||
|
# will be formatted as:
|
||||||
|
#
|
||||||
|
# ( 1, 2, 3 )
|
||||||
|
spaces_around_tuple_delimiters=False
|
||||||
|
|
||||||
|
# The number of spaces required before a trailing comment.
|
||||||
|
# This can be a single value (representing the number of spaces
|
||||||
|
# before each trailing comment) or list of values (representing
|
||||||
|
# alignment column values; trailing comments within a block will
|
||||||
|
# be aligned to the first column value that is greater than the maximum
|
||||||
|
# line length within the block). For example:
|
||||||
|
#
|
||||||
|
# With spaces_before_comment=5:
|
||||||
|
#
|
||||||
|
# 1 + 1 # Adding values
|
||||||
|
#
|
||||||
|
# will be formatted as:
|
||||||
|
#
|
||||||
|
# 1 + 1 # Adding values <-- 5 spaces between the end of the statement and comment
|
||||||
|
#
|
||||||
|
# With spaces_before_comment=15, 20:
|
||||||
|
#
|
||||||
|
# 1 + 1 # Adding values
|
||||||
|
# two + two # More adding
|
||||||
|
#
|
||||||
|
# longer_statement # This is a longer statement
|
||||||
|
# short # This is a shorter statement
|
||||||
|
#
|
||||||
|
# a_very_long_statement_that_extends_beyond_the_final_column # Comment
|
||||||
|
# short # This is a shorter statement
|
||||||
|
#
|
||||||
|
# will be formatted as:
|
||||||
|
#
|
||||||
|
# 1 + 1 # Adding values <-- end of line comments in block aligned to col 15
|
||||||
|
# two + two # More adding
|
||||||
|
#
|
||||||
|
# longer_statement # This is a longer statement <-- end of line comments in block aligned to col 20
|
||||||
|
# short # This is a shorter statement
|
||||||
|
#
|
||||||
|
# a_very_long_statement_that_extends_beyond_the_final_column # Comment <-- the end of line comments are aligned based on the line length
|
||||||
|
# short # This is a shorter statement
|
||||||
|
#
|
||||||
|
spaces_before_comment=2
|
||||||
|
|
||||||
|
# Insert a space between the ending comma and closing bracket of a list,
|
||||||
|
# etc.
|
||||||
|
space_between_ending_comma_and_closing_bracket=False
|
||||||
|
|
||||||
|
# Use spaces inside brackets, braces, and parentheses. For example:
|
||||||
|
#
|
||||||
|
# method_call( 1 )
|
||||||
|
# my_dict[ 3 ][ 1 ][ get_index( *args, **kwargs ) ]
|
||||||
|
# my_set = { 1, 2, 3 }
|
||||||
|
space_inside_brackets=False
|
||||||
|
|
||||||
|
# Split before arguments
|
||||||
|
split_all_comma_separated_values=False
|
||||||
|
|
||||||
|
# Split before arguments, but do not split all subexpressions recursively
|
||||||
|
# (unless needed).
|
||||||
|
split_all_top_level_comma_separated_values=False
|
||||||
|
|
||||||
|
# Split before arguments if the argument list is terminated by a
|
||||||
|
# comma.
|
||||||
|
split_arguments_when_comma_terminated=False
|
||||||
|
|
||||||
|
# Set to True to prefer splitting before '+', '-', '*', '/', '//', or '@'
|
||||||
|
# rather than after.
|
||||||
|
split_before_arithmetic_operator=False
|
||||||
|
|
||||||
|
# Set to True to prefer splitting before '&', '|' or '^' rather than
|
||||||
|
# after.
|
||||||
|
split_before_bitwise_operator=False
|
||||||
|
|
||||||
|
# Split before the closing bracket if a list or dict literal doesn't fit on
|
||||||
|
# a single line.
|
||||||
|
split_before_closing_bracket=True
|
||||||
|
|
||||||
|
# Split before a dictionary or set generator (comp_for). For example, note
|
||||||
|
# the split before the 'for':
|
||||||
|
#
|
||||||
|
# foo = {
|
||||||
|
# variable: 'Hello world, have a nice day!'
|
||||||
|
# for variable in bar if variable != 42
|
||||||
|
# }
|
||||||
|
split_before_dict_set_generator=False
|
||||||
|
|
||||||
|
# Split before the '.' if we need to split a longer expression:
|
||||||
|
#
|
||||||
|
# foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d))
|
||||||
|
#
|
||||||
|
# would reformat to something like:
|
||||||
|
#
|
||||||
|
# foo = ('This is a really long string: {}, {}, {}, {}'
|
||||||
|
# .format(a, b, c, d))
|
||||||
|
split_before_dot=False
|
||||||
|
|
||||||
|
# Split after the opening paren which surrounds an expression if it doesn't
|
||||||
|
# fit on a single line.
|
||||||
|
split_before_expression_after_opening_paren=True
|
||||||
|
|
||||||
|
# If an argument / parameter list is going to be split, then split before
|
||||||
|
# the first argument.
|
||||||
|
split_before_first_argument=False
|
||||||
|
|
||||||
|
# Set to True to prefer splitting before 'and' or 'or' rather than
|
||||||
|
# after.
|
||||||
|
split_before_logical_operator=False
|
||||||
|
|
||||||
|
# Split named assignments onto individual lines.
|
||||||
|
split_before_named_assigns=True
|
||||||
|
|
||||||
|
# Set to True to split list comprehensions and generators that have
|
||||||
|
# non-trivial expressions and multiple clauses before each of these
|
||||||
|
# clauses. For example:
|
||||||
|
#
|
||||||
|
# result = [
|
||||||
|
# a_long_var + 100 for a_long_var in xrange(1000)
|
||||||
|
# if a_long_var % 10]
|
||||||
|
#
|
||||||
|
# would reformat to something like:
|
||||||
|
#
|
||||||
|
# result = [
|
||||||
|
# a_long_var + 100
|
||||||
|
# for a_long_var in xrange(1000)
|
||||||
|
# if a_long_var % 10]
|
||||||
|
split_complex_comprehension=True
|
||||||
|
|
||||||
|
# The penalty for splitting right after the opening bracket.
|
||||||
|
split_penalty_after_opening_bracket=300
|
||||||
|
|
||||||
|
# The penalty for splitting the line after a unary operator.
|
||||||
|
split_penalty_after_unary_operator=10000
|
||||||
|
|
||||||
|
# The penalty of splitting the line around the '+', '-', '*', '/', '//',
|
||||||
|
# ``%``, and '@' operators.
|
||||||
|
split_penalty_arithmetic_operator=300
|
||||||
|
|
||||||
|
# The penalty for splitting right before an if expression.
|
||||||
|
split_penalty_before_if_expr=0
|
||||||
|
|
||||||
|
# The penalty of splitting the line around the '&', '|', and '^'
|
||||||
|
# operators.
|
||||||
|
split_penalty_bitwise_operator=300
|
||||||
|
|
||||||
|
# The penalty for splitting a list comprehension or generator
|
||||||
|
# expression.
|
||||||
|
split_penalty_comprehension=2100
|
||||||
|
|
||||||
|
# The penalty for characters over the column limit.
|
||||||
|
split_penalty_excess_character=7000
|
||||||
|
|
||||||
|
# The penalty incurred by adding a line split to the unwrapped line. The
|
||||||
|
# more line splits added the higher the penalty.
|
||||||
|
split_penalty_for_added_line_split=30
|
||||||
|
|
||||||
|
# The penalty of splitting a list of "import as" names. For example:
|
||||||
|
#
|
||||||
|
# from a_very_long_or_indented_module_name_yada_yad import (long_argument_1,
|
||||||
|
# long_argument_2,
|
||||||
|
# long_argument_3)
|
||||||
|
#
|
||||||
|
# would reformat to something like:
|
||||||
|
#
|
||||||
|
# from a_very_long_or_indented_module_name_yada_yad import (
|
||||||
|
# long_argument_1, long_argument_2, long_argument_3)
|
||||||
|
split_penalty_import_names=0
|
||||||
|
|
||||||
|
# The penalty of splitting the line around the 'and' and 'or'
|
||||||
|
# operators.
|
||||||
|
split_penalty_logical_operator=300
|
||||||
|
|
||||||
|
# Use the Tab character for indentation.
|
||||||
|
use_tabs=False
|
5
Makefile
Normal file
5
Makefile
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
.PHONY: format
|
||||||
|
|
||||||
|
format:
|
||||||
|
isort generate.py gradio wan
|
||||||
|
yapf -i -r *.py generate.py gradio wan
|
@ -643,7 +643,7 @@ If you find our work helpful, please cite us.
|
|||||||
```
|
```
|
||||||
@article{wan2025,
|
@article{wan2025,
|
||||||
title={Wan: Open and Advanced Large-Scale Video Generative Models},
|
title={Wan: Open and Advanced Large-Scale Video Generative Models},
|
||||||
author={Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
|
author={Team Wan and Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
|
||||||
journal = {arXiv preprint arXiv:2503.20314},
|
journal = {arXiv preprint arXiv:2503.20314},
|
||||||
year={2025}
|
year={2025}
|
||||||
}
|
}
|
||||||
|
81
generate.py
81
generate.py
@ -1,28 +1,33 @@
|
|||||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
import argparse
|
import argparse
|
||||||
from datetime import datetime
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
warnings.filterwarnings('ignore')
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
import torch, random
|
import random
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import wan
|
import wan
|
||||||
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
|
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
|
||||||
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
||||||
from wan.utils.utils import cache_video, cache_image, str2bool
|
from wan.utils.utils import cache_image, cache_video, str2bool
|
||||||
|
|
||||||
|
|
||||||
EXAMPLE_PROMPT = {
|
EXAMPLE_PROMPT = {
|
||||||
"t2v-1.3B": {
|
"t2v-1.3B": {
|
||||||
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
|
"prompt":
|
||||||
|
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
|
||||||
},
|
},
|
||||||
"t2v-14B": {
|
"t2v-14B": {
|
||||||
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
|
"prompt":
|
||||||
|
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
|
||||||
},
|
},
|
||||||
"t2i-14B": {
|
"t2i-14B": {
|
||||||
"prompt": "一个朴素端庄的美人",
|
"prompt": "一个朴素端庄的美人",
|
||||||
@ -34,20 +39,24 @@ EXAMPLE_PROMPT = {
|
|||||||
"examples/i2v_input.JPG",
|
"examples/i2v_input.JPG",
|
||||||
},
|
},
|
||||||
"flf2v-14B": {
|
"flf2v-14B": {
|
||||||
"prompt":
|
"prompt":
|
||||||
"CG动画风格,一只蓝色的小鸟从地面起飞,煽动翅膀。小鸟羽毛细腻,胸前有独特的花纹,背景是蓝天白云,阳光明媚。镜跟随小鸟向上移动,展现出小鸟飞翔的姿态和天空的广阔。近景,仰视视角。",
|
"CG动画风格,一只蓝色的小鸟从地面起飞,煽动翅膀。小鸟羽毛细腻,胸前有独特的花纹,背景是蓝天白云,阳光明媚。镜跟随小鸟向上移动,展现出小鸟飞翔的姿态和天空的广阔。近景,仰视视角。",
|
||||||
"first_frame":
|
"first_frame":
|
||||||
"examples/flf2v_input_first_frame.png",
|
"examples/flf2v_input_first_frame.png",
|
||||||
"last_frame":
|
"last_frame":
|
||||||
"examples/flf2v_input_last_frame.png",
|
"examples/flf2v_input_last_frame.png",
|
||||||
},
|
},
|
||||||
"vace-1.3B": {
|
"vace-1.3B": {
|
||||||
"src_ref_images": 'examples/girl.png,examples/snake.png',
|
"src_ref_images":
|
||||||
"prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
|
'examples/girl.png,examples/snake.png',
|
||||||
|
"prompt":
|
||||||
|
"在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
|
||||||
},
|
},
|
||||||
"vace-14B": {
|
"vace-14B": {
|
||||||
"src_ref_images": 'examples/girl.png,examples/snake.png',
|
"src_ref_images":
|
||||||
"prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
|
'examples/girl.png,examples/snake.png',
|
||||||
|
"prompt":
|
||||||
|
"在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,7 +73,6 @@ def _validate_args(args):
|
|||||||
if "i2v" in args.task:
|
if "i2v" in args.task:
|
||||||
args.sample_steps = 40
|
args.sample_steps = 40
|
||||||
|
|
||||||
|
|
||||||
if args.sample_shift is None:
|
if args.sample_shift is None:
|
||||||
args.sample_shift = 5.0
|
args.sample_shift = 5.0
|
||||||
if "i2v" in args.task and args.size in ["832*480", "480*832"]:
|
if "i2v" in args.task and args.size in ["832*480", "480*832"]:
|
||||||
@ -72,7 +80,6 @@ def _validate_args(args):
|
|||||||
elif "flf2v" in args.task or "vace" in args.task:
|
elif "flf2v" in args.task or "vace" in args.task:
|
||||||
args.sample_shift = 16
|
args.sample_shift = 16
|
||||||
|
|
||||||
|
|
||||||
# The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
|
# The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
|
||||||
if args.frame_num is None:
|
if args.frame_num is None:
|
||||||
args.frame_num = 1 if "t2i" in args.task else 81
|
args.frame_num = 1 if "t2i" in args.task else 81
|
||||||
@ -167,7 +174,8 @@ def _parse_args():
|
|||||||
"--src_ref_images",
|
"--src_ref_images",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="The file list of the source reference images. Separated by ','. Default None.")
|
help="The file list of the source reference images. Separated by ','. Default None."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prompt",
|
"--prompt",
|
||||||
type=str,
|
type=str,
|
||||||
@ -209,12 +217,14 @@ def _parse_args():
|
|||||||
"--first_frame",
|
"--first_frame",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="[first-last frame to video] The image (first frame) to generate the video from.")
|
help="[first-last frame to video] The image (first frame) to generate the video from."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--last_frame",
|
"--last_frame",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="[first-last frame to video] The image (last frame) to generate the video from.")
|
help="[first-last frame to video] The image (last frame) to generate the video from."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sample_solver",
|
"--sample_solver",
|
||||||
type=str,
|
type=str,
|
||||||
@ -281,8 +291,10 @@ def generate(args):
|
|||||||
|
|
||||||
if args.ulysses_size > 1 or args.ring_size > 1:
|
if args.ulysses_size > 1 or args.ring_size > 1:
|
||||||
assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size."
|
assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size."
|
||||||
from xfuser.core.distributed import (initialize_model_parallel,
|
from xfuser.core.distributed import (
|
||||||
init_distributed_environment)
|
init_distributed_environment,
|
||||||
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
init_distributed_environment(
|
init_distributed_environment(
|
||||||
rank=dist.get_rank(), world_size=dist.get_world_size())
|
rank=dist.get_rank(), world_size=dist.get_world_size())
|
||||||
|
|
||||||
@ -295,11 +307,12 @@ def generate(args):
|
|||||||
if args.use_prompt_extend:
|
if args.use_prompt_extend:
|
||||||
if args.prompt_extend_method == "dashscope":
|
if args.prompt_extend_method == "dashscope":
|
||||||
prompt_expander = DashScopePromptExpander(
|
prompt_expander = DashScopePromptExpander(
|
||||||
model_name=args.prompt_extend_model, is_vl="i2v" in args.task or "flf2v" in args.task)
|
model_name=args.prompt_extend_model,
|
||||||
|
is_vl="i2v" in args.task or "flf2v" in args.task)
|
||||||
elif args.prompt_extend_method == "local_qwen":
|
elif args.prompt_extend_method == "local_qwen":
|
||||||
prompt_expander = QwenPromptExpander(
|
prompt_expander = QwenPromptExpander(
|
||||||
model_name=args.prompt_extend_model,
|
model_name=args.prompt_extend_model,
|
||||||
is_vl="i2v" in args.task,
|
is_vl="i2v" in args.task or "flf2v" in args.task,
|
||||||
device=rank)
|
device=rank)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@ -482,21 +495,22 @@ def generate(args):
|
|||||||
sampling_steps=args.sample_steps,
|
sampling_steps=args.sample_steps,
|
||||||
guide_scale=args.sample_guide_scale,
|
guide_scale=args.sample_guide_scale,
|
||||||
seed=args.base_seed,
|
seed=args.base_seed,
|
||||||
offload_model=args.offload_model
|
offload_model=args.offload_model)
|
||||||
)
|
|
||||||
elif "vace" in args.task:
|
elif "vace" in args.task:
|
||||||
if args.prompt is None:
|
if args.prompt is None:
|
||||||
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
||||||
args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None)
|
args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None)
|
||||||
args.src_mask = EXAMPLE_PROMPT[args.task].get("src_mask", None)
|
args.src_mask = EXAMPLE_PROMPT[args.task].get("src_mask", None)
|
||||||
args.src_ref_images = EXAMPLE_PROMPT[args.task].get("src_ref_images", None)
|
args.src_ref_images = EXAMPLE_PROMPT[args.task].get(
|
||||||
|
"src_ref_images", None)
|
||||||
|
|
||||||
logging.info(f"Input prompt: {args.prompt}")
|
logging.info(f"Input prompt: {args.prompt}")
|
||||||
if args.use_prompt_extend and args.use_prompt_extend != 'plain':
|
if args.use_prompt_extend and args.use_prompt_extend != 'plain':
|
||||||
logging.info("Extending prompt ...")
|
logging.info("Extending prompt ...")
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
prompt = prompt_expander.forward(args.prompt)
|
prompt = prompt_expander.forward(args.prompt)
|
||||||
logging.info(f"Prompt extended from '{args.prompt}' to '{prompt}'")
|
logging.info(
|
||||||
|
f"Prompt extended from '{args.prompt}' to '{prompt}'")
|
||||||
input_prompt = [prompt]
|
input_prompt = [prompt]
|
||||||
else:
|
else:
|
||||||
input_prompt = [None]
|
input_prompt = [None]
|
||||||
@ -517,10 +531,11 @@ def generate(args):
|
|||||||
t5_cpu=args.t5_cpu,
|
t5_cpu=args.t5_cpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
src_video, src_mask, src_ref_images = wan_vace.prepare_source([args.src_video],
|
src_video, src_mask, src_ref_images = wan_vace.prepare_source(
|
||||||
[args.src_mask],
|
[args.src_video], [args.src_mask], [
|
||||||
[None if args.src_ref_images is None else args.src_ref_images.split(',')],
|
None if args.src_ref_images is None else
|
||||||
args.frame_num, SIZE_CONFIGS[args.size], device)
|
args.src_ref_images.split(',')
|
||||||
|
], args.frame_num, SIZE_CONFIGS[args.size], device)
|
||||||
|
|
||||||
logging.info(f"Generating video...")
|
logging.info(f"Generating video...")
|
||||||
video = wan_vace.generate(
|
video = wan_vace.generate(
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
import gc
|
||||||
import os.path as osp
|
|
||||||
import os
|
import os
|
||||||
|
import os.path as osp
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@ -11,7 +11,8 @@ import gradio as gr
|
|||||||
warnings.filterwarnings('ignore')
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
|
sys.path.insert(
|
||||||
|
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
|
||||||
import wan
|
import wan
|
||||||
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
|
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
|
||||||
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
||||||
@ -69,13 +70,13 @@ def prompt_enc(prompt, img_first, img_last, tar_lang):
|
|||||||
return prompt_output.prompt
|
return prompt_output.prompt
|
||||||
|
|
||||||
|
|
||||||
def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps,
|
def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
|
||||||
guide_scale, shift_scale, seed, n_prompt):
|
resolution, sd_steps, guide_scale, shift_scale, seed,
|
||||||
|
n_prompt):
|
||||||
|
|
||||||
if resolution == '------':
|
if resolution == '------':
|
||||||
print(
|
print(
|
||||||
'Please specify the resolution ckpt dir or specify the resolution'
|
'Please specify the resolution ckpt dir or specify the resolution')
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -94,9 +95,7 @@ def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, re
|
|||||||
offload_model=True)
|
offload_model=True)
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
print(
|
print('Sorry, currently only 720P is supported.')
|
||||||
'Sorry, currently only 720P is supported.'
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
cache_video(
|
cache_video(
|
||||||
@ -191,14 +190,17 @@ def gradio_interface():
|
|||||||
|
|
||||||
run_p_button.click(
|
run_p_button.click(
|
||||||
fn=prompt_enc,
|
fn=prompt_enc,
|
||||||
inputs=[flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, tar_lang],
|
inputs=[
|
||||||
|
flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
|
||||||
|
tar_lang
|
||||||
|
],
|
||||||
outputs=[flf2vid_prompt])
|
outputs=[flf2vid_prompt])
|
||||||
|
|
||||||
run_flf2v_button.click(
|
run_flf2v_button.click(
|
||||||
fn=flf2v_generation,
|
fn=flf2v_generation,
|
||||||
inputs=[
|
inputs=[
|
||||||
flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps,
|
flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
|
||||||
guide_scale, shift_scale, seed, n_prompt
|
resolution, sd_steps, guide_scale, shift_scale, seed, n_prompt
|
||||||
],
|
],
|
||||||
outputs=[result_gallery],
|
outputs=[result_gallery],
|
||||||
)
|
)
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
import gc
|
||||||
import os.path as osp
|
|
||||||
import os
|
import os
|
||||||
|
import os.path as osp
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@ -11,7 +11,8 @@ import gradio as gr
|
|||||||
warnings.filterwarnings('ignore')
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
|
sys.path.insert(
|
||||||
|
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
|
||||||
import wan
|
import wan
|
||||||
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
|
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
|
||||||
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
import argparse
|
import argparse
|
||||||
import os.path as osp
|
|
||||||
import os
|
import os
|
||||||
|
import os.path as osp
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@ -10,7 +10,8 @@ import gradio as gr
|
|||||||
warnings.filterwarnings('ignore')
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
|
sys.path.insert(
|
||||||
|
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
|
||||||
import wan
|
import wan
|
||||||
from wan.configs import WAN_CONFIGS
|
from wan.configs import WAN_CONFIGS
|
||||||
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
import argparse
|
import argparse
|
||||||
import os.path as osp
|
|
||||||
import os
|
import os
|
||||||
|
import os.path as osp
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@ -10,7 +10,8 @@ import gradio as gr
|
|||||||
warnings.filterwarnings('ignore')
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
|
sys.path.insert(
|
||||||
|
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
|
||||||
import wan
|
import wan
|
||||||
from wan.configs import WAN_CONFIGS
|
from wan.configs import WAN_CONFIGS
|
||||||
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
import argparse
|
import argparse
|
||||||
import os.path as osp
|
|
||||||
import os
|
import os
|
||||||
|
import os.path as osp
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@ -10,7 +10,8 @@ import gradio as gr
|
|||||||
warnings.filterwarnings('ignore')
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
|
sys.path.insert(
|
||||||
|
0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
|
||||||
import wan
|
import wan
|
||||||
from wan.configs import WAN_CONFIGS
|
from wan.configs import WAN_CONFIGS
|
||||||
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
||||||
|
200
gradio/vace.py
200
gradio/vace.py
@ -2,36 +2,48 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import datetime
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import datetime
|
|
||||||
import imageio
|
import imageio
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
|
sys.path.insert(
|
||||||
|
0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
|
||||||
import wan
|
import wan
|
||||||
from wan import WanVace, WanVaceMP
|
from wan import WanVace, WanVaceMP
|
||||||
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS
|
from wan.configs import SIZE_CONFIGS, WAN_CONFIGS
|
||||||
|
|
||||||
|
|
||||||
class FixedSizeQueue:
|
class FixedSizeQueue:
|
||||||
|
|
||||||
def __init__(self, max_size):
|
def __init__(self, max_size):
|
||||||
self.max_size = max_size
|
self.max_size = max_size
|
||||||
self.queue = []
|
self.queue = []
|
||||||
|
|
||||||
def add(self, item):
|
def add(self, item):
|
||||||
self.queue.insert(0, item)
|
self.queue.insert(0, item)
|
||||||
if len(self.queue) > self.max_size:
|
if len(self.queue) > self.max_size:
|
||||||
self.queue.pop()
|
self.queue.pop()
|
||||||
|
|
||||||
def get(self):
|
def get(self):
|
||||||
return self.queue
|
return self.queue
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return str(self.queue)
|
return str(self.queue)
|
||||||
|
|
||||||
|
|
||||||
class VACEInference:
|
class VACEInference:
|
||||||
def __init__(self, cfg, skip_load=False, gallery_share=True, gallery_share_limit=5):
|
|
||||||
|
def __init__(self,
|
||||||
|
cfg,
|
||||||
|
skip_load=False,
|
||||||
|
gallery_share=True,
|
||||||
|
gallery_share_limit=5):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.save_dir = cfg.save_dir
|
self.save_dir = cfg.save_dir
|
||||||
self.gallery_share = gallery_share
|
self.gallery_share = gallery_share
|
||||||
@ -53,9 +65,7 @@ class VACEInference:
|
|||||||
checkpoint_dir=cfg.ckpt_dir,
|
checkpoint_dir=cfg.ckpt_dir,
|
||||||
use_usp=True,
|
use_usp=True,
|
||||||
ulysses_size=cfg.ulysses_size,
|
ulysses_size=cfg.ulysses_size,
|
||||||
ring_size=cfg.ring_size
|
ring_size=cfg.ring_size)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_ui(self, *args, **kwargs):
|
def create_ui(self, *args, **kwargs):
|
||||||
gr.Markdown("""
|
gr.Markdown("""
|
||||||
@ -80,30 +90,33 @@ class VACEInference:
|
|||||||
with gr.Row(variant='panel', equal_height=True):
|
with gr.Row(variant='panel', equal_height=True):
|
||||||
with gr.Column(scale=1, min_width=0):
|
with gr.Column(scale=1, min_width=0):
|
||||||
with gr.Row(equal_height=True):
|
with gr.Row(equal_height=True):
|
||||||
self.src_ref_image_1 = gr.Image(label='src_ref_image_1',
|
self.src_ref_image_1 = gr.Image(
|
||||||
height=200,
|
label='src_ref_image_1',
|
||||||
interactive=True,
|
height=200,
|
||||||
type='filepath',
|
interactive=True,
|
||||||
image_mode='RGB',
|
type='filepath',
|
||||||
sources=['upload'],
|
image_mode='RGB',
|
||||||
elem_id="src_ref_image_1",
|
sources=['upload'],
|
||||||
format='png')
|
elem_id="src_ref_image_1",
|
||||||
self.src_ref_image_2 = gr.Image(label='src_ref_image_2',
|
format='png')
|
||||||
height=200,
|
self.src_ref_image_2 = gr.Image(
|
||||||
interactive=True,
|
label='src_ref_image_2',
|
||||||
type='filepath',
|
height=200,
|
||||||
image_mode='RGB',
|
interactive=True,
|
||||||
sources=['upload'],
|
type='filepath',
|
||||||
elem_id="src_ref_image_2",
|
image_mode='RGB',
|
||||||
format='png')
|
sources=['upload'],
|
||||||
self.src_ref_image_3 = gr.Image(label='src_ref_image_3',
|
elem_id="src_ref_image_2",
|
||||||
height=200,
|
format='png')
|
||||||
interactive=True,
|
self.src_ref_image_3 = gr.Image(
|
||||||
type='filepath',
|
label='src_ref_image_3',
|
||||||
image_mode='RGB',
|
height=200,
|
||||||
sources=['upload'],
|
interactive=True,
|
||||||
elem_id="src_ref_image_3",
|
type='filepath',
|
||||||
format='png')
|
image_mode='RGB',
|
||||||
|
sources=['upload'],
|
||||||
|
elem_id="src_ref_image_3",
|
||||||
|
format='png')
|
||||||
with gr.Row(variant='panel', equal_height=True):
|
with gr.Row(variant='panel', equal_height=True):
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
self.prompt = gr.Textbox(
|
self.prompt = gr.Textbox(
|
||||||
@ -158,10 +171,8 @@ class VACEInference:
|
|||||||
step=0.5,
|
step=0.5,
|
||||||
value=5.0,
|
value=5.0,
|
||||||
interactive=True)
|
interactive=True)
|
||||||
self.infer_seed = gr.Slider(minimum=-1,
|
self.infer_seed = gr.Slider(
|
||||||
maximum=10000000,
|
minimum=-1, maximum=10000000, value=2025, label="Seed")
|
||||||
value=2025,
|
|
||||||
label="Seed")
|
|
||||||
#
|
#
|
||||||
with gr.Accordion(label="Usable without source video", open=False):
|
with gr.Accordion(label="Usable without source video", open=False):
|
||||||
with gr.Row(equal_height=True):
|
with gr.Row(equal_height=True):
|
||||||
@ -176,13 +187,9 @@ class VACEInference:
|
|||||||
value=1280,
|
value=1280,
|
||||||
interactive=True)
|
interactive=True)
|
||||||
self.frame_rate = gr.Textbox(
|
self.frame_rate = gr.Textbox(
|
||||||
label='frame_rate',
|
label='frame_rate', value=16, interactive=True)
|
||||||
value=16,
|
|
||||||
interactive=True)
|
|
||||||
self.num_frames = gr.Textbox(
|
self.num_frames = gr.Textbox(
|
||||||
label='num_frames',
|
label='num_frames', value=81, interactive=True)
|
||||||
value=81,
|
|
||||||
interactive=True)
|
|
||||||
#
|
#
|
||||||
with gr.Row(equal_height=True):
|
with gr.Row(equal_height=True):
|
||||||
with gr.Column(scale=5):
|
with gr.Column(scale=5):
|
||||||
@ -201,17 +208,22 @@ class VACEInference:
|
|||||||
allow_preview=True,
|
allow_preview=True,
|
||||||
preview=True)
|
preview=True)
|
||||||
|
|
||||||
|
def generate(self, output_gallery, src_video, src_mask, src_ref_image_1,
|
||||||
def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, src_ref_image_2, src_ref_image_3, prompt, negative_prompt, shift_scale, sample_steps, context_scale, guide_scale, infer_seed, output_height, output_width, frame_rate, num_frames):
|
src_ref_image_2, src_ref_image_3, prompt, negative_prompt,
|
||||||
output_height, output_width, frame_rate, num_frames = int(output_height), int(output_width), int(frame_rate), int(num_frames)
|
shift_scale, sample_steps, context_scale, guide_scale,
|
||||||
src_ref_images = [x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3] if
|
infer_seed, output_height, output_width, frame_rate,
|
||||||
x is not None]
|
num_frames):
|
||||||
src_video, src_mask, src_ref_images = self.pipe.prepare_source([src_video],
|
output_height, output_width, frame_rate, num_frames = int(
|
||||||
[src_mask],
|
output_height), int(output_width), int(frame_rate), int(num_frames)
|
||||||
[src_ref_images],
|
src_ref_images = [
|
||||||
num_frames=num_frames,
|
x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3]
|
||||||
image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"],
|
if x is not None
|
||||||
device=self.pipe.device)
|
]
|
||||||
|
src_video, src_mask, src_ref_images = self.pipe.prepare_source(
|
||||||
|
[src_video], [src_mask], [src_ref_images],
|
||||||
|
num_frames=num_frames,
|
||||||
|
image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"],
|
||||||
|
device=self.pipe.device)
|
||||||
video = self.pipe.generate(
|
video = self.pipe.generate(
|
||||||
prompt,
|
prompt,
|
||||||
src_video,
|
src_video,
|
||||||
@ -228,10 +240,17 @@ class VACEInference:
|
|||||||
|
|
||||||
name = '{0:%Y%m%d%-H%M%S}'.format(datetime.datetime.now())
|
name = '{0:%Y%m%d%-H%M%S}'.format(datetime.datetime.now())
|
||||||
video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
|
video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
|
||||||
video_frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)
|
video_frames = (
|
||||||
|
torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) *
|
||||||
|
255).cpu().numpy().astype(np.uint8)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
writer = imageio.get_writer(video_path, fps=frame_rate, codec='libx264', quality=8, macro_block_size=1)
|
writer = imageio.get_writer(
|
||||||
|
video_path,
|
||||||
|
fps=frame_rate,
|
||||||
|
codec='libx264',
|
||||||
|
quality=8,
|
||||||
|
macro_block_size=1)
|
||||||
for frame in video_frames:
|
for frame in video_frames:
|
||||||
writer.append_data(frame)
|
writer.append_data(frame)
|
||||||
writer.close()
|
writer.close()
|
||||||
@ -246,25 +265,57 @@ class VACEInference:
|
|||||||
return [video_path]
|
return [video_path]
|
||||||
|
|
||||||
def set_callbacks(self, **kwargs):
|
def set_callbacks(self, **kwargs):
|
||||||
self.gen_inputs = [self.output_gallery, self.src_video, self.src_mask, self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3, self.prompt, self.negative_prompt, self.shift_scale, self.sample_steps, self.context_scale, self.guide_scale, self.infer_seed, self.output_height, self.output_width, self.frame_rate, self.num_frames]
|
self.gen_inputs = [
|
||||||
|
self.output_gallery, self.src_video, self.src_mask,
|
||||||
|
self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3,
|
||||||
|
self.prompt, self.negative_prompt, self.shift_scale,
|
||||||
|
self.sample_steps, self.context_scale, self.guide_scale,
|
||||||
|
self.infer_seed, self.output_height, self.output_width,
|
||||||
|
self.frame_rate, self.num_frames
|
||||||
|
]
|
||||||
self.gen_outputs = [self.output_gallery]
|
self.gen_outputs = [self.output_gallery]
|
||||||
self.generate_button.click(self.generate,
|
self.generate_button.click(
|
||||||
inputs=self.gen_inputs,
|
self.generate,
|
||||||
outputs=self.gen_outputs,
|
inputs=self.gen_inputs,
|
||||||
queue=True)
|
outputs=self.gen_outputs,
|
||||||
self.refresh_button.click(lambda x: self.gallery_share_data.get() if self.gallery_share else x, inputs=[self.output_gallery], outputs=[self.output_gallery])
|
queue=True)
|
||||||
|
self.refresh_button.click(
|
||||||
|
lambda x: self.gallery_share_data.get()
|
||||||
|
if self.gallery_share else x,
|
||||||
|
inputs=[self.output_gallery],
|
||||||
|
outputs=[self.output_gallery])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(description='Argparser for VACE-WAN Demo:\n')
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument('--server_port', dest='server_port', help='', type=int, default=7860)
|
description='Argparser for VACE-WAN Demo:\n')
|
||||||
parser.add_argument('--server_name', dest='server_name', help='', default='0.0.0.0')
|
parser.add_argument(
|
||||||
|
'--server_port', dest='server_port', help='', type=int, default=7860)
|
||||||
|
parser.add_argument(
|
||||||
|
'--server_name', dest='server_name', help='', default='0.0.0.0')
|
||||||
parser.add_argument('--root_path', dest='root_path', help='', default=None)
|
parser.add_argument('--root_path', dest='root_path', help='', default=None)
|
||||||
parser.add_argument('--save_dir', dest='save_dir', help='', default='cache')
|
parser.add_argument('--save_dir', dest='save_dir', help='', default='cache')
|
||||||
parser.add_argument("--mp", action="store_true", help="Use Multi-GPUs",)
|
parser.add_argument(
|
||||||
parser.add_argument("--model_name", type=str, default="vace-14B", choices=list(WAN_CONFIGS.keys()), help="The model name to run.")
|
"--mp",
|
||||||
parser.add_argument("--ulysses_size", type=int, default=1, help="The size of the ulysses parallelism in DiT.")
|
action="store_true",
|
||||||
parser.add_argument("--ring_size", type=int, default=1, help="The size of the ring attention parallelism in DiT.")
|
help="Use Multi-GPUs",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name",
|
||||||
|
type=str,
|
||||||
|
default="vace-14B",
|
||||||
|
choices=list(WAN_CONFIGS.keys()),
|
||||||
|
help="The model name to run.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--ulysses_size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="The size of the ulysses parallelism in DiT.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--ring_size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="The size of the ring attention parallelism in DiT.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ckpt_dir",
|
"--ckpt_dir",
|
||||||
type=str,
|
type=str,
|
||||||
@ -284,12 +335,15 @@ if __name__ == '__main__':
|
|||||||
os.makedirs(args.save_dir, exist_ok=True)
|
os.makedirs(args.save_dir, exist_ok=True)
|
||||||
|
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
infer_gr = VACEInference(args, skip_load=False, gallery_share=True, gallery_share_limit=5)
|
infer_gr = VACEInference(
|
||||||
|
args, skip_load=False, gallery_share=True, gallery_share_limit=5)
|
||||||
infer_gr.create_ui()
|
infer_gr.create_ui()
|
||||||
infer_gr.set_callbacks()
|
infer_gr.set_callbacks()
|
||||||
allowed_paths = [args.save_dir]
|
allowed_paths = [args.save_dir]
|
||||||
demo.queue(status_update_rate=1).launch(server_name=args.server_name,
|
demo.queue(status_update_rate=1).launch(
|
||||||
server_port=args.server_port,
|
server_name=args.server_name,
|
||||||
root_path=args.root_path,
|
server_port=args.server_port,
|
||||||
allowed_paths=allowed_paths,
|
root_path=args.root_path,
|
||||||
show_error=True, debug=True)
|
allowed_paths=allowed_paths,
|
||||||
|
show_error=True,
|
||||||
|
debug=True)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from . import configs, distributed, modules
|
from . import configs, distributed, modules
|
||||||
|
from .first_last_frame2video import WanFLF2V
|
||||||
from .image2video import WanI2V
|
from .image2video import WanI2V
|
||||||
from .text2video import WanT2V
|
from .text2video import WanT2V
|
||||||
from .first_last_frame2video import WanFLF2V
|
|
||||||
from .vace import WanVace, WanVaceMP
|
from .vace import WanVace, WanVaceMP
|
||||||
|
@ -8,6 +8,7 @@ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
|
|||||||
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
|
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
|
||||||
from torch.distributed.utils import _free_storage
|
from torch.distributed.utils import _free_storage
|
||||||
|
|
||||||
|
|
||||||
def shard_model(
|
def shard_model(
|
||||||
model,
|
model,
|
||||||
device_id,
|
device_id,
|
||||||
@ -32,6 +33,7 @@ def shard_model(
|
|||||||
sync_module_states=sync_module_states)
|
sync_module_states=sync_module_states)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def free_model(model):
|
def free_model(model):
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if isinstance(m, FSDP):
|
if isinstance(m, FSDP):
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda.amp as amp
|
import torch.cuda.amp as amp
|
||||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
from xfuser.core.distributed import (
|
||||||
get_sequence_parallel_world_size,
|
get_sequence_parallel_rank,
|
||||||
get_sp_group)
|
get_sequence_parallel_world_size,
|
||||||
|
get_sp_group,
|
||||||
|
)
|
||||||
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
||||||
|
|
||||||
from ..modules.model import sinusoidal_embedding_1d
|
from ..modules.model import sinusoidal_embedding_1d
|
||||||
@ -63,19 +65,13 @@ def rope_apply(x, grid_sizes, freqs):
|
|||||||
return torch.stack(output).float()
|
return torch.stack(output).float()
|
||||||
|
|
||||||
|
|
||||||
def usp_dit_forward_vace(
|
def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs):
|
||||||
self,
|
|
||||||
x,
|
|
||||||
vace_context,
|
|
||||||
seq_len,
|
|
||||||
kwargs
|
|
||||||
):
|
|
||||||
# embeddings
|
# embeddings
|
||||||
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||||
c = [u.flatten(2).transpose(1, 2) for u in c]
|
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||||
c = torch.cat([
|
c = torch.cat([
|
||||||
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
|
||||||
dim=1) for u in c
|
for u in c
|
||||||
])
|
])
|
||||||
|
|
||||||
# arguments
|
# arguments
|
||||||
|
@ -21,8 +21,11 @@ from .modules.clip import CLIPModel
|
|||||||
from .modules.model import WanModel
|
from .modules.model import WanModel
|
||||||
from .modules.t5 import T5EncoderModel
|
from .modules.t5 import T5EncoderModel
|
||||||
from .modules.vae import WanVAE
|
from .modules.vae import WanVAE
|
||||||
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
from .utils.fm_solvers import (
|
||||||
get_sampling_sigmas, retrieve_timesteps)
|
FlowDPMSolverMultistepScheduler,
|
||||||
|
get_sampling_sigmas,
|
||||||
|
retrieve_timesteps,
|
||||||
|
)
|
||||||
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||||
|
|
||||||
|
|
||||||
@ -103,11 +106,12 @@ class WanFLF2V:
|
|||||||
init_on_cpu = False
|
init_on_cpu = False
|
||||||
|
|
||||||
if use_usp:
|
if use_usp:
|
||||||
from xfuser.core.distributed import \
|
from xfuser.core.distributed import get_sequence_parallel_world_size
|
||||||
get_sequence_parallel_world_size
|
|
||||||
|
|
||||||
from .distributed.xdit_context_parallel import (usp_attn_forward,
|
from .distributed.xdit_context_parallel import (
|
||||||
usp_dit_forward)
|
usp_attn_forward,
|
||||||
|
usp_dit_forward,
|
||||||
|
)
|
||||||
for block in self.model.blocks:
|
for block in self.model.blocks:
|
||||||
block.self_attn.forward = types.MethodType(
|
block.self_attn.forward = types.MethodType(
|
||||||
usp_attn_forward, block.self_attn)
|
usp_attn_forward, block.self_attn)
|
||||||
@ -181,8 +185,10 @@ class WanFLF2V:
|
|||||||
"""
|
"""
|
||||||
first_frame_size = first_frame.size
|
first_frame_size = first_frame.size
|
||||||
last_frame_size = last_frame.size
|
last_frame_size = last_frame.size
|
||||||
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(self.device)
|
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
|
||||||
last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(self.device)
|
self.device)
|
||||||
|
last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(
|
||||||
|
self.device)
|
||||||
|
|
||||||
F = frame_num
|
F = frame_num
|
||||||
first_frame_h, first_frame_w = first_frame.shape[1:]
|
first_frame_h, first_frame_w = first_frame.shape[1:]
|
||||||
@ -199,8 +205,7 @@ class WanFLF2V:
|
|||||||
# 1. resize
|
# 1. resize
|
||||||
last_frame_resize_ratio = max(
|
last_frame_resize_ratio = max(
|
||||||
first_frame_size[0] / last_frame_size[0],
|
first_frame_size[0] / last_frame_size[0],
|
||||||
first_frame_size[1] / last_frame_size[1]
|
first_frame_size[1] / last_frame_size[1])
|
||||||
)
|
|
||||||
last_frame_size = [
|
last_frame_size = [
|
||||||
round(last_frame_size[0] * last_frame_resize_ratio),
|
round(last_frame_size[0] * last_frame_resize_ratio),
|
||||||
round(last_frame_size[1] * last_frame_resize_ratio),
|
round(last_frame_size[1] * last_frame_resize_ratio),
|
||||||
@ -216,8 +221,7 @@ class WanFLF2V:
|
|||||||
seed_g = torch.Generator(device=self.device)
|
seed_g = torch.Generator(device=self.device)
|
||||||
seed_g.manual_seed(seed)
|
seed_g.manual_seed(seed)
|
||||||
noise = torch.randn(
|
noise = torch.randn(
|
||||||
16,
|
16, (F - 1) // 4 + 1,
|
||||||
(F - 1) // 4 + 1,
|
|
||||||
lat_h,
|
lat_h,
|
||||||
lat_w,
|
lat_w,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
@ -225,8 +229,11 @@ class WanFLF2V:
|
|||||||
device=self.device)
|
device=self.device)
|
||||||
|
|
||||||
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
|
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
|
||||||
msk[:, 1: -1] = 0
|
msk[:, 1:-1] = 0
|
||||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
msk = torch.concat([
|
||||||
|
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
|
||||||
|
],
|
||||||
|
dim=1)
|
||||||
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
|
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
|
||||||
msk = msk.transpose(1, 2)[0]
|
msk = msk.transpose(1, 2)[0]
|
||||||
|
|
||||||
@ -247,7 +254,8 @@ class WanFLF2V:
|
|||||||
context_null = [t.to(self.device) for t in context_null]
|
context_null = [t.to(self.device) for t in context_null]
|
||||||
|
|
||||||
self.clip.model.to(self.device)
|
self.clip.model.to(self.device)
|
||||||
clip_context = self.clip.visual([first_frame[:, None, :, :], last_frame[:, None, :, :]])
|
clip_context = self.clip.visual(
|
||||||
|
[first_frame[:, None, :, :], last_frame[:, None, :, :]])
|
||||||
if offload_model:
|
if offload_model:
|
||||||
self.clip.model.cpu()
|
self.clip.model.cpu()
|
||||||
|
|
||||||
@ -256,15 +264,14 @@ class WanFLF2V:
|
|||||||
torch.nn.functional.interpolate(
|
torch.nn.functional.interpolate(
|
||||||
first_frame[None].cpu(),
|
first_frame[None].cpu(),
|
||||||
size=(first_frame_h, first_frame_w),
|
size=(first_frame_h, first_frame_w),
|
||||||
mode='bicubic'
|
mode='bicubic').transpose(0, 1),
|
||||||
).transpose(0, 1),
|
|
||||||
torch.zeros(3, F - 2, first_frame_h, first_frame_w),
|
torch.zeros(3, F - 2, first_frame_h, first_frame_w),
|
||||||
torch.nn.functional.interpolate(
|
torch.nn.functional.interpolate(
|
||||||
last_frame[None].cpu(),
|
last_frame[None].cpu(),
|
||||||
size=(first_frame_h, first_frame_w),
|
size=(first_frame_h, first_frame_w),
|
||||||
mode='bicubic'
|
mode='bicubic').transpose(0, 1),
|
||||||
).transpose(0, 1),
|
],
|
||||||
], dim=1).to(self.device)
|
dim=1).to(self.device)
|
||||||
])[0]
|
])[0]
|
||||||
y = torch.concat([msk, y])
|
y = torch.concat([msk, y])
|
||||||
|
|
||||||
|
@ -21,8 +21,11 @@ from .modules.clip import CLIPModel
|
|||||||
from .modules.model import WanModel
|
from .modules.model import WanModel
|
||||||
from .modules.t5 import T5EncoderModel
|
from .modules.t5 import T5EncoderModel
|
||||||
from .modules.vae import WanVAE
|
from .modules.vae import WanVAE
|
||||||
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
from .utils.fm_solvers import (
|
||||||
get_sampling_sigmas, retrieve_timesteps)
|
FlowDPMSolverMultistepScheduler,
|
||||||
|
get_sampling_sigmas,
|
||||||
|
retrieve_timesteps,
|
||||||
|
)
|
||||||
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||||
|
|
||||||
|
|
||||||
@ -103,11 +106,12 @@ class WanI2V:
|
|||||||
init_on_cpu = False
|
init_on_cpu = False
|
||||||
|
|
||||||
if use_usp:
|
if use_usp:
|
||||||
from xfuser.core.distributed import \
|
from xfuser.core.distributed import get_sequence_parallel_world_size
|
||||||
get_sequence_parallel_world_size
|
|
||||||
|
|
||||||
from .distributed.xdit_context_parallel import (usp_attn_forward,
|
from .distributed.xdit_context_parallel import (
|
||||||
usp_dit_forward)
|
usp_attn_forward,
|
||||||
|
usp_dit_forward,
|
||||||
|
)
|
||||||
for block in self.model.blocks:
|
for block in self.model.blocks:
|
||||||
block.self_attn.forward = types.MethodType(
|
block.self_attn.forward = types.MethodType(
|
||||||
usp_attn_forward, block.self_attn)
|
usp_attn_forward, block.self_attn)
|
||||||
@ -196,8 +200,7 @@ class WanI2V:
|
|||||||
seed_g = torch.Generator(device=self.device)
|
seed_g = torch.Generator(device=self.device)
|
||||||
seed_g.manual_seed(seed)
|
seed_g.manual_seed(seed)
|
||||||
noise = torch.randn(
|
noise = torch.randn(
|
||||||
16,
|
16, (F - 1) // 4 + 1,
|
||||||
(F - 1) // 4 + 1,
|
|
||||||
lat_h,
|
lat_h,
|
||||||
lat_w,
|
lat_w,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
|
@ -273,7 +273,7 @@ class WanAttentionBlock(nn.Module):
|
|||||||
nn.Linear(ffn_dim, dim))
|
nn.Linear(ffn_dim, dim))
|
||||||
|
|
||||||
# modulation
|
# modulation
|
||||||
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5)
|
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -332,7 +332,7 @@ class Head(nn.Module):
|
|||||||
self.head = nn.Linear(dim, out_dim)
|
self.head = nn.Linear(dim, out_dim)
|
||||||
|
|
||||||
# modulation
|
# modulation
|
||||||
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim ** 0.5)
|
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
||||||
|
|
||||||
def forward(self, x, e):
|
def forward(self, x, e):
|
||||||
r"""
|
r"""
|
||||||
@ -357,7 +357,8 @@ class MLPProj(torch.nn.Module):
|
|||||||
torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
|
torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
|
||||||
torch.nn.LayerNorm(out_dim))
|
torch.nn.LayerNorm(out_dim))
|
||||||
if flf_pos_emb: # NOTE: we only use this for `flf2v`
|
if flf_pos_emb: # NOTE: we only use this for `flf2v`
|
||||||
self.emb_pos = nn.Parameter(torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))
|
self.emb_pos = nn.Parameter(
|
||||||
|
torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))
|
||||||
|
|
||||||
def forward(self, image_embeds):
|
def forward(self, image_embeds):
|
||||||
if hasattr(self, 'emb_pos'):
|
if hasattr(self, 'emb_pos'):
|
||||||
|
@ -3,23 +3,24 @@ import torch
|
|||||||
import torch.cuda.amp as amp
|
import torch.cuda.amp as amp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from diffusers.configuration_utils import register_to_config
|
from diffusers.configuration_utils import register_to_config
|
||||||
from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
|
|
||||||
|
from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d
|
||||||
|
|
||||||
|
|
||||||
class VaceWanAttentionBlock(WanAttentionBlock):
|
class VaceWanAttentionBlock(WanAttentionBlock):
|
||||||
def __init__(
|
|
||||||
self,
|
def __init__(self,
|
||||||
cross_attn_type,
|
cross_attn_type,
|
||||||
dim,
|
dim,
|
||||||
ffn_dim,
|
ffn_dim,
|
||||||
num_heads,
|
num_heads,
|
||||||
window_size=(-1, -1),
|
window_size=(-1, -1),
|
||||||
qk_norm=True,
|
qk_norm=True,
|
||||||
cross_attn_norm=False,
|
cross_attn_norm=False,
|
||||||
eps=1e-6,
|
eps=1e-6,
|
||||||
block_id=0
|
block_id=0):
|
||||||
):
|
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size,
|
||||||
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
|
qk_norm, cross_attn_norm, eps)
|
||||||
self.block_id = block_id
|
self.block_id = block_id
|
||||||
if block_id == 0:
|
if block_id == 0:
|
||||||
self.before_proj = nn.Linear(self.dim, self.dim)
|
self.before_proj = nn.Linear(self.dim, self.dim)
|
||||||
@ -39,19 +40,19 @@ class VaceWanAttentionBlock(WanAttentionBlock):
|
|||||||
|
|
||||||
|
|
||||||
class BaseWanAttentionBlock(WanAttentionBlock):
|
class BaseWanAttentionBlock(WanAttentionBlock):
|
||||||
def __init__(
|
|
||||||
self,
|
def __init__(self,
|
||||||
cross_attn_type,
|
cross_attn_type,
|
||||||
dim,
|
dim,
|
||||||
ffn_dim,
|
ffn_dim,
|
||||||
num_heads,
|
num_heads,
|
||||||
window_size=(-1, -1),
|
window_size=(-1, -1),
|
||||||
qk_norm=True,
|
qk_norm=True,
|
||||||
cross_attn_norm=False,
|
cross_attn_norm=False,
|
||||||
eps=1e-6,
|
eps=1e-6,
|
||||||
block_id=None
|
block_id=None):
|
||||||
):
|
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size,
|
||||||
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
|
qk_norm, cross_attn_norm, eps)
|
||||||
self.block_id = block_id
|
self.block_id = block_id
|
||||||
|
|
||||||
def forward(self, x, hints, context_scale=1.0, **kwargs):
|
def forward(self, x, hints, context_scale=1.0, **kwargs):
|
||||||
@ -62,6 +63,7 @@ class BaseWanAttentionBlock(WanAttentionBlock):
|
|||||||
|
|
||||||
|
|
||||||
class VaceWanModel(WanModel):
|
class VaceWanModel(WanModel):
|
||||||
|
|
||||||
@register_to_config
|
@register_to_config
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
vace_layers=None,
|
vace_layers=None,
|
||||||
@ -81,42 +83,57 @@ class VaceWanModel(WanModel):
|
|||||||
qk_norm=True,
|
qk_norm=True,
|
||||||
cross_attn_norm=True,
|
cross_attn_norm=True,
|
||||||
eps=1e-6):
|
eps=1e-6):
|
||||||
super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
|
super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim,
|
||||||
num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
|
freq_dim, text_dim, out_dim, num_heads, num_layers,
|
||||||
|
window_size, qk_norm, cross_attn_norm, eps)
|
||||||
|
|
||||||
self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
|
self.vace_layers = [i for i in range(0, self.num_layers, 2)
|
||||||
|
] if vace_layers is None else vace_layers
|
||||||
self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
|
self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
|
||||||
|
|
||||||
assert 0 in self.vace_layers
|
assert 0 in self.vace_layers
|
||||||
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
|
self.vace_layers_mapping = {
|
||||||
|
i: n for n, i in enumerate(self.vace_layers)
|
||||||
|
}
|
||||||
|
|
||||||
# blocks
|
# blocks
|
||||||
self.blocks = nn.ModuleList([
|
self.blocks = nn.ModuleList([
|
||||||
BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
|
BaseWanAttentionBlock(
|
||||||
self.cross_attn_norm, self.eps,
|
't2v_cross_attn',
|
||||||
block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
|
self.dim,
|
||||||
|
self.ffn_dim,
|
||||||
|
self.num_heads,
|
||||||
|
self.window_size,
|
||||||
|
self.qk_norm,
|
||||||
|
self.cross_attn_norm,
|
||||||
|
self.eps,
|
||||||
|
block_id=self.vace_layers_mapping[i]
|
||||||
|
if i in self.vace_layers else None)
|
||||||
for i in range(self.num_layers)
|
for i in range(self.num_layers)
|
||||||
])
|
])
|
||||||
|
|
||||||
# vace blocks
|
# vace blocks
|
||||||
self.vace_blocks = nn.ModuleList([
|
self.vace_blocks = nn.ModuleList([
|
||||||
VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
|
VaceWanAttentionBlock(
|
||||||
self.cross_attn_norm, self.eps, block_id=i)
|
't2v_cross_attn',
|
||||||
for i in self.vace_layers
|
self.dim,
|
||||||
|
self.ffn_dim,
|
||||||
|
self.num_heads,
|
||||||
|
self.window_size,
|
||||||
|
self.qk_norm,
|
||||||
|
self.cross_attn_norm,
|
||||||
|
self.eps,
|
||||||
|
block_id=i) for i in self.vace_layers
|
||||||
])
|
])
|
||||||
|
|
||||||
# vace patch embeddings
|
# vace patch embeddings
|
||||||
self.vace_patch_embedding = nn.Conv3d(
|
self.vace_patch_embedding = nn.Conv3d(
|
||||||
self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
|
self.vace_in_dim,
|
||||||
)
|
self.dim,
|
||||||
|
kernel_size=self.patch_size,
|
||||||
|
stride=self.patch_size)
|
||||||
|
|
||||||
def forward_vace(
|
def forward_vace(self, x, vace_context, seq_len, kwargs):
|
||||||
self,
|
|
||||||
x,
|
|
||||||
vace_context,
|
|
||||||
seq_len,
|
|
||||||
kwargs
|
|
||||||
):
|
|
||||||
# embeddings
|
# embeddings
|
||||||
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||||
c = [u.flatten(2).transpose(1, 2) for u in c]
|
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||||
|
@ -18,8 +18,11 @@ from .distributed.fsdp import shard_model
|
|||||||
from .modules.model import WanModel
|
from .modules.model import WanModel
|
||||||
from .modules.t5 import T5EncoderModel
|
from .modules.t5 import T5EncoderModel
|
||||||
from .modules.vae import WanVAE
|
from .modules.vae import WanVAE
|
||||||
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
from .utils.fm_solvers import (
|
||||||
get_sampling_sigmas, retrieve_timesteps)
|
FlowDPMSolverMultistepScheduler,
|
||||||
|
get_sampling_sigmas,
|
||||||
|
retrieve_timesteps,
|
||||||
|
)
|
||||||
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||||
|
|
||||||
|
|
||||||
@ -85,11 +88,12 @@ class WanT2V:
|
|||||||
self.model.eval().requires_grad_(False)
|
self.model.eval().requires_grad_(False)
|
||||||
|
|
||||||
if use_usp:
|
if use_usp:
|
||||||
from xfuser.core.distributed import \
|
from xfuser.core.distributed import get_sequence_parallel_world_size
|
||||||
get_sequence_parallel_world_size
|
|
||||||
|
|
||||||
from .distributed.xdit_context_parallel import (usp_attn_forward,
|
from .distributed.xdit_context_parallel import (
|
||||||
usp_dit_forward)
|
usp_attn_forward,
|
||||||
|
usp_dit_forward,
|
||||||
|
)
|
||||||
for block in self.model.blocks:
|
for block in self.model.blocks:
|
||||||
block.self_attn.forward = types.MethodType(
|
block.self_attn.forward = types.MethodType(
|
||||||
usp_attn_forward, block.self_attn)
|
usp_attn_forward, block.self_attn)
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
|
from .fm_solvers import (
|
||||||
retrieve_timesteps)
|
FlowDPMSolverMultistepScheduler,
|
||||||
|
get_sampling_sigmas,
|
||||||
|
retrieve_timesteps,
|
||||||
|
)
|
||||||
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
|
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||||
from .vace_processor import VaceVideoProcessor
|
from .vace_processor import VaceVideoProcessor
|
||||||
|
|
||||||
|
@ -9,9 +9,11 @@ from typing import List, Optional, Tuple, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||||
from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
|
from diffusers.schedulers.scheduling_utils import (
|
||||||
SchedulerMixin,
|
KarrasDiffusionSchedulers,
|
||||||
SchedulerOutput)
|
SchedulerMixin,
|
||||||
|
SchedulerOutput,
|
||||||
|
)
|
||||||
from diffusers.utils import deprecate, is_scipy_available
|
from diffusers.utils import deprecate, is_scipy_available
|
||||||
from diffusers.utils.torch_utils import randn_tensor
|
from diffusers.utils.torch_utils import randn_tensor
|
||||||
|
|
||||||
|
@ -8,9 +8,11 @@ from typing import List, Optional, Tuple, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||||
from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
|
from diffusers.schedulers.scheduling_utils import (
|
||||||
SchedulerMixin,
|
KarrasDiffusionSchedulers,
|
||||||
SchedulerOutput)
|
SchedulerMixin,
|
||||||
|
SchedulerOutput,
|
||||||
|
)
|
||||||
from diffusers.utils import deprecate, is_scipy_available
|
from diffusers.utils import deprecate, is_scipy_available
|
||||||
|
|
||||||
if is_scipy_available():
|
if is_scipy_available():
|
||||||
|
@ -7,7 +7,7 @@ import sys
|
|||||||
import tempfile
|
import tempfile
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Optional, Union, List
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import dashscope
|
import dashscope
|
||||||
import torch
|
import torch
|
||||||
@ -96,7 +96,6 @@ VL_EN_SYS_PROMPT = \
|
|||||||
'''4. In the style of an American drama promotional poster, Walter White sits in a metal folding chair wearing a yellow protective suit, with the words "Breaking Bad" written in sans-serif English above him, surrounded by piles of dollar bills and blue plastic storage boxes. He wears glasses, staring forward, dressed in a yellow jumpsuit, with his hands resting on his knees, exuding a calm and confident demeanor. The background shows an abandoned, dim factory with light filtering through the windows. There’s a noticeable grainy texture. A medium shot with a straight-on close-up of the character.\n''' \
|
'''4. In the style of an American drama promotional poster, Walter White sits in a metal folding chair wearing a yellow protective suit, with the words "Breaking Bad" written in sans-serif English above him, surrounded by piles of dollar bills and blue plastic storage boxes. He wears glasses, staring forward, dressed in a yellow jumpsuit, with his hands resting on his knees, exuding a calm and confident demeanor. The background shows an abandoned, dim factory with light filtering through the windows. There’s a noticeable grainy texture. A medium shot with a straight-on close-up of the character.\n''' \
|
||||||
'''Directly output the rewritten English text.'''
|
'''Directly output the rewritten English text.'''
|
||||||
|
|
||||||
|
|
||||||
VL_ZH_SYS_PROMPT_FOR_MULTI_IMAGES = """你是一位Prompt优化师,旨在参考用户输入的图像的细节内容,把用户输入的Prompt改写为优质Prompt,使其更完整、更具表现力,同时不改变原意。你需要综合用户输入的照片内容和输入的Prompt进行改写,严格参考示例的格式进行改写
|
VL_ZH_SYS_PROMPT_FOR_MULTI_IMAGES = """你是一位Prompt优化师,旨在参考用户输入的图像的细节内容,把用户输入的Prompt改写为优质Prompt,使其更完整、更具表现力,同时不改变原意。你需要综合用户输入的照片内容和输入的Prompt进行改写,严格参考示例的格式进行改写
|
||||||
任务要求:
|
任务要求:
|
||||||
1. 用户会输入两张图片,第一张是视频的第一帧,第二张时视频的最后一帧,你需要综合两个照片的内容进行优化改写
|
1. 用户会输入两张图片,第一张是视频的第一帧,第二张时视频的最后一帧,你需要综合两个照片的内容进行优化改写
|
||||||
@ -198,8 +197,8 @@ class PromptExpander:
|
|||||||
if system_prompt is None:
|
if system_prompt is None:
|
||||||
system_prompt = self.decide_system_prompt(
|
system_prompt = self.decide_system_prompt(
|
||||||
tar_lang=tar_lang,
|
tar_lang=tar_lang,
|
||||||
multi_images_input=isinstance(image, (list, tuple)) and len(image) > 1
|
multi_images_input=isinstance(image, (list, tuple)) and
|
||||||
)
|
len(image) > 1)
|
||||||
if seed < 0:
|
if seed < 0:
|
||||||
seed = random.randint(0, sys.maxsize)
|
seed = random.randint(0, sys.maxsize)
|
||||||
if image is not None and self.is_vl:
|
if image is not None and self.is_vl:
|
||||||
@ -289,7 +288,8 @@ class DashScopePromptExpander(PromptExpander):
|
|||||||
def extend_with_img(self,
|
def extend_with_img(self,
|
||||||
prompt,
|
prompt,
|
||||||
system_prompt,
|
system_prompt,
|
||||||
image: Union[List[Image.Image], List[str], Image.Image, str] = None,
|
image: Union[List[Image.Image], List[str], Image.Image,
|
||||||
|
str] = None,
|
||||||
seed=-1,
|
seed=-1,
|
||||||
*args,
|
*args,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
@ -308,13 +308,15 @@ class DashScopePromptExpander(PromptExpander):
|
|||||||
_image.save(f.name)
|
_image.save(f.name)
|
||||||
image_path = f"file://{f.name}"
|
image_path = f"file://{f.name}"
|
||||||
return image_path
|
return image_path
|
||||||
|
|
||||||
if not isinstance(image, (list, tuple)):
|
if not isinstance(image, (list, tuple)):
|
||||||
image = [image]
|
image = [image]
|
||||||
image_path_list = [ensure_image(_image) for _image in image]
|
image_path_list = [ensure_image(_image) for _image in image]
|
||||||
role_content = [
|
role_content = [{
|
||||||
{"text": prompt},
|
"text": prompt
|
||||||
*[{"image": image_path} for image_path in image_path_list]
|
}, *[{
|
||||||
]
|
"image": image_path
|
||||||
|
} for image_path in image_path_list]]
|
||||||
system_content = [{"text": system_prompt}]
|
system_content = [{"text": system_prompt}]
|
||||||
prompt = f"{prompt}"
|
prompt = f"{prompt}"
|
||||||
messages = [
|
messages = [
|
||||||
@ -393,8 +395,11 @@ class QwenPromptExpander(PromptExpander):
|
|||||||
|
|
||||||
if self.is_vl:
|
if self.is_vl:
|
||||||
# default: Load the model on the available device(s)
|
# default: Load the model on the available device(s)
|
||||||
from transformers import (AutoProcessor, AutoTokenizer,
|
from transformers import (
|
||||||
Qwen2_5_VLForConditionalGeneration)
|
AutoProcessor,
|
||||||
|
AutoTokenizer,
|
||||||
|
Qwen2_5_VLForConditionalGeneration,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
from .qwen_vl_utils import process_vision_info
|
from .qwen_vl_utils import process_vision_info
|
||||||
except:
|
except:
|
||||||
@ -459,7 +464,8 @@ class QwenPromptExpander(PromptExpander):
|
|||||||
def extend_with_img(self,
|
def extend_with_img(self,
|
||||||
prompt,
|
prompt,
|
||||||
system_prompt,
|
system_prompt,
|
||||||
image: Union[List[Image.Image], List[str], Image.Image, str] = None,
|
image: Union[List[Image.Image], List[str], Image.Image,
|
||||||
|
str] = None,
|
||||||
seed=-1,
|
seed=-1,
|
||||||
*args,
|
*args,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
@ -468,26 +474,19 @@ class QwenPromptExpander(PromptExpander):
|
|||||||
if not isinstance(image, (list, tuple)):
|
if not isinstance(image, (list, tuple)):
|
||||||
image = [image]
|
image = [image]
|
||||||
|
|
||||||
system_content = [{
|
system_content = [{"type": "text", "text": system_prompt}]
|
||||||
|
role_content = [{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": system_prompt
|
"text": prompt
|
||||||
}]
|
}, *[{
|
||||||
role_content = [
|
"image": image_path
|
||||||
{
|
} for image_path in image]]
|
||||||
"type": "text",
|
|
||||||
"text": prompt
|
|
||||||
},
|
|
||||||
*[
|
|
||||||
{"image": image_path} for image_path in image
|
|
||||||
]
|
|
||||||
]
|
|
||||||
|
|
||||||
messages = [{
|
messages = [{
|
||||||
'role': 'system',
|
'role': 'system',
|
||||||
'content': system_content,
|
'content': system_content,
|
||||||
}, {
|
}, {
|
||||||
"role":
|
"role": "user",
|
||||||
"user",
|
|
||||||
"content": role_content,
|
"content": role_content,
|
||||||
}]
|
}]
|
||||||
|
|
||||||
@ -611,25 +610,38 @@ if __name__ == "__main__":
|
|||||||
print("VL qwen vl en result -> en",
|
print("VL qwen vl en result -> en",
|
||||||
qwen_result.prompt) # , qwen_result.system_prompt)
|
qwen_result.prompt) # , qwen_result.system_prompt)
|
||||||
# test multi images
|
# test multi images
|
||||||
image = ["./examples/flf2v_input_first_frame.png", "./examples/flf2v_input_last_frame.png"]
|
image = [
|
||||||
|
"./examples/flf2v_input_first_frame.png",
|
||||||
|
"./examples/flf2v_input_last_frame.png"
|
||||||
|
]
|
||||||
prompt = "无人机拍摄,镜头快速推进,然后拉远至全景俯瞰,展示一个宁静美丽的海港。海港内停满了游艇,水面清澈透蓝。周围是起伏的山丘和错落有致的建筑,整体景色宁静而美丽。"
|
prompt = "无人机拍摄,镜头快速推进,然后拉远至全景俯瞰,展示一个宁静美丽的海港。海港内停满了游艇,水面清澈透蓝。周围是起伏的山丘和错落有致的建筑,整体景色宁静而美丽。"
|
||||||
en_prompt = ("Shot from a drone perspective, the camera rapidly zooms in before pulling back to reveal a panoramic "
|
en_prompt = (
|
||||||
"aerial view of a serene and picturesque harbor. The tranquil bay is dotted with numerous yachts "
|
"Shot from a drone perspective, the camera rapidly zooms in before pulling back to reveal a panoramic "
|
||||||
"resting on crystal-clear blue waters. Surrounding the harbor are rolling hills and well-spaced "
|
"aerial view of a serene and picturesque harbor. The tranquil bay is dotted with numerous yachts "
|
||||||
"architectural structures, combining to create a tranquil and breathtaking coastal landscape.")
|
"resting on crystal-clear blue waters. Surrounding the harbor are rolling hills and well-spaced "
|
||||||
|
"architectural structures, combining to create a tranquil and breathtaking coastal landscape."
|
||||||
|
)
|
||||||
|
|
||||||
dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True)
|
dashscope_prompt_expander = DashScopePromptExpander(
|
||||||
dashscope_result = dashscope_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
|
model_name=ds_model_name, is_vl=True)
|
||||||
|
dashscope_result = dashscope_prompt_expander(
|
||||||
|
prompt, tar_lang="zh", image=image, seed=seed)
|
||||||
print("VL dashscope result -> zh", dashscope_result.prompt)
|
print("VL dashscope result -> zh", dashscope_result.prompt)
|
||||||
|
|
||||||
dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True)
|
dashscope_prompt_expander = DashScopePromptExpander(
|
||||||
dashscope_result = dashscope_prompt_expander(en_prompt, tar_lang="zh", image=image, seed=seed)
|
model_name=ds_model_name, is_vl=True)
|
||||||
|
dashscope_result = dashscope_prompt_expander(
|
||||||
|
en_prompt, tar_lang="zh", image=image, seed=seed)
|
||||||
print("VL dashscope en result -> zh", dashscope_result.prompt)
|
print("VL dashscope en result -> zh", dashscope_result.prompt)
|
||||||
|
|
||||||
qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0)
|
qwen_prompt_expander = QwenPromptExpander(
|
||||||
qwen_result = qwen_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
|
model_name=qwen_model_name, is_vl=True, device=0)
|
||||||
|
qwen_result = qwen_prompt_expander(
|
||||||
|
prompt, tar_lang="zh", image=image, seed=seed)
|
||||||
print("VL qwen result -> zh", qwen_result.prompt)
|
print("VL qwen result -> zh", qwen_result.prompt)
|
||||||
|
|
||||||
qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0)
|
qwen_prompt_expander = QwenPromptExpander(
|
||||||
qwen_result = qwen_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
|
model_name=qwen_model_name, is_vl=True, device=0)
|
||||||
|
qwen_result = qwen_prompt_expander(
|
||||||
|
prompt, tar_lang="zh", image=image, seed=seed)
|
||||||
print("VL qwen en result -> zh", qwen_result.prompt)
|
print("VL qwen en result -> zh", qwen_result.prompt)
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchvision.transforms.functional as TF
|
import torchvision.transforms.functional as TF
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
class VaceImageProcessor(object):
|
class VaceImageProcessor(object):
|
||||||
|
|
||||||
def __init__(self, downsample=None, seq_len=None):
|
def __init__(self, downsample=None, seq_len=None):
|
||||||
self.downsample = downsample
|
self.downsample = downsample
|
||||||
self.seq_len = seq_len
|
self.seq_len = seq_len
|
||||||
@ -16,9 +17,10 @@ class VaceImageProcessor(object):
|
|||||||
if image.mode == 'P':
|
if image.mode == 'P':
|
||||||
image = image.convert(f'{cvt_type}A')
|
image = image.convert(f'{cvt_type}A')
|
||||||
if image.mode == f'{cvt_type}A':
|
if image.mode == f'{cvt_type}A':
|
||||||
bg = Image.new(cvt_type,
|
bg = Image.new(
|
||||||
size=(image.width, image.height),
|
cvt_type,
|
||||||
color=(255, 255, 255))
|
size=(image.width, image.height),
|
||||||
|
color=(255, 255, 255))
|
||||||
bg.paste(image, (0, 0), mask=image)
|
bg.paste(image, (0, 0), mask=image)
|
||||||
image = bg
|
image = bg
|
||||||
else:
|
else:
|
||||||
@ -41,10 +43,8 @@ class VaceImageProcessor(object):
|
|||||||
if iw != ow or ih != oh:
|
if iw != ow or ih != oh:
|
||||||
# resize
|
# resize
|
||||||
scale = max(ow / iw, oh / ih)
|
scale = max(ow / iw, oh / ih)
|
||||||
img = img.resize(
|
img = img.resize((round(scale * iw), round(scale * ih)),
|
||||||
(round(scale * iw), round(scale * ih)),
|
resample=Image.Resampling.LANCZOS)
|
||||||
resample=Image.Resampling.LANCZOS
|
|
||||||
)
|
|
||||||
assert img.width >= ow and img.height >= oh
|
assert img.width >= ow and img.height >= oh
|
||||||
|
|
||||||
# center crop
|
# center crop
|
||||||
@ -66,7 +66,11 @@ class VaceImageProcessor(object):
|
|||||||
def load_image_pair(self, data_key, data_key2, **kwargs):
|
def load_image_pair(self, data_key, data_key2, **kwargs):
|
||||||
return self.load_image_batch(data_key, data_key2, **kwargs)
|
return self.load_image_batch(data_key, data_key2, **kwargs)
|
||||||
|
|
||||||
def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs):
|
def load_image_batch(self,
|
||||||
|
*data_key_batch,
|
||||||
|
normalize=True,
|
||||||
|
seq_len=None,
|
||||||
|
**kwargs):
|
||||||
seq_len = self.seq_len if seq_len is None else seq_len
|
seq_len = self.seq_len if seq_len is None else seq_len
|
||||||
imgs = []
|
imgs = []
|
||||||
for data_key in data_key_batch:
|
for data_key in data_key_batch:
|
||||||
@ -85,7 +89,9 @@ class VaceImageProcessor(object):
|
|||||||
|
|
||||||
|
|
||||||
class VaceVideoProcessor(object):
|
class VaceVideoProcessor(object):
|
||||||
def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs):
|
|
||||||
|
def __init__(self, downsample, min_area, max_area, min_fps, max_fps,
|
||||||
|
zero_start, seq_len, keep_last, **kwargs):
|
||||||
self.downsample = downsample
|
self.downsample = downsample
|
||||||
self.min_area = min_area
|
self.min_area = min_area
|
||||||
self.max_area = max_area
|
self.max_area = max_area
|
||||||
@ -130,8 +136,7 @@ class VaceVideoProcessor(object):
|
|||||||
video,
|
video,
|
||||||
size=(round(scale * ih), round(scale * iw)),
|
size=(round(scale * ih), round(scale * iw)),
|
||||||
mode='bicubic',
|
mode='bicubic',
|
||||||
antialias=True
|
antialias=True)
|
||||||
)
|
|
||||||
assert video.size(3) >= ow and video.size(2) >= oh
|
assert video.size(3) >= ow and video.size(2) >= oh
|
||||||
|
|
||||||
# center crop
|
# center crop
|
||||||
@ -146,7 +151,8 @@ class VaceVideoProcessor(object):
|
|||||||
def _video_preprocess(self, video, oh, ow):
|
def _video_preprocess(self, video, oh, ow):
|
||||||
return self.resize_crop(video, oh, ow)
|
return self.resize_crop(video, oh, ow)
|
||||||
|
|
||||||
def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng):
|
def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box,
|
||||||
|
rng):
|
||||||
target_fps = min(fps, self.max_fps)
|
target_fps = min(fps, self.max_fps)
|
||||||
duration = frame_timestamps[-1].mean()
|
duration = frame_timestamps[-1].mean()
|
||||||
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
|
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
|
||||||
@ -154,11 +160,10 @@ class VaceVideoProcessor(object):
|
|||||||
ratio = h / w
|
ratio = h / w
|
||||||
df, dh, dw = self.downsample
|
df, dh, dw = self.downsample
|
||||||
|
|
||||||
area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
|
area_z = min(self.seq_len, self.max_area / (dh * dw),
|
||||||
of = min(
|
(h // dh) * (w // dw))
|
||||||
(int(duration * target_fps) - 1) // df + 1,
|
of = min((int(duration * target_fps) - 1) // df + 1,
|
||||||
int(self.seq_len / area_z)
|
int(self.seq_len / area_z))
|
||||||
)
|
|
||||||
|
|
||||||
# deduce target shape of the [latent video]
|
# deduce target shape of the [latent video]
|
||||||
target_area_z = min(area_z, int(self.seq_len / of))
|
target_area_z = min(area_z, int(self.seq_len / of))
|
||||||
@ -170,26 +175,27 @@ class VaceVideoProcessor(object):
|
|||||||
|
|
||||||
# sample frame ids
|
# sample frame ids
|
||||||
target_duration = of / target_fps
|
target_duration = of / target_fps
|
||||||
begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration)
|
begin = 0. if self.zero_start else rng.uniform(
|
||||||
|
0, duration - target_duration)
|
||||||
timestamps = np.linspace(begin, begin + target_duration, of)
|
timestamps = np.linspace(begin, begin + target_duration, of)
|
||||||
frame_ids = np.argmax(np.logical_and(
|
frame_ids = np.argmax(
|
||||||
timestamps[:, None] >= frame_timestamps[None, :, 0],
|
np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
|
||||||
timestamps[:, None] < frame_timestamps[None, :, 1]
|
timestamps[:, None] < frame_timestamps[None, :, 1]),
|
||||||
), axis=1).tolist()
|
axis=1).tolist()
|
||||||
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
|
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
|
||||||
|
|
||||||
def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng):
|
def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w,
|
||||||
|
crop_box, rng):
|
||||||
duration = frame_timestamps[-1].mean()
|
duration = frame_timestamps[-1].mean()
|
||||||
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
|
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
|
||||||
h, w = y2 - y1, x2 - x1
|
h, w = y2 - y1, x2 - x1
|
||||||
ratio = h / w
|
ratio = h / w
|
||||||
df, dh, dw = self.downsample
|
df, dh, dw = self.downsample
|
||||||
|
|
||||||
area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
|
area_z = min(self.seq_len, self.max_area / (dh * dw),
|
||||||
of = min(
|
(h // dh) * (w // dw))
|
||||||
(len(frame_timestamps) - 1) // df + 1,
|
of = min((len(frame_timestamps) - 1) // df + 1,
|
||||||
int(self.seq_len / area_z)
|
int(self.seq_len / area_z))
|
||||||
)
|
|
||||||
|
|
||||||
# deduce target shape of the [latent video]
|
# deduce target shape of the [latent video]
|
||||||
target_area_z = min(area_z, int(self.seq_len / of))
|
target_area_z = min(area_z, int(self.seq_len / of))
|
||||||
@ -203,27 +209,39 @@ class VaceVideoProcessor(object):
|
|||||||
target_duration = duration
|
target_duration = duration
|
||||||
target_fps = of / target_duration
|
target_fps = of / target_duration
|
||||||
timestamps = np.linspace(0., target_duration, of)
|
timestamps = np.linspace(0., target_duration, of)
|
||||||
frame_ids = np.argmax(np.logical_and(
|
frame_ids = np.argmax(
|
||||||
timestamps[:, None] >= frame_timestamps[None, :, 0],
|
np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
|
||||||
timestamps[:, None] <= frame_timestamps[None, :, 1]
|
timestamps[:, None] <= frame_timestamps[None, :, 1]),
|
||||||
), axis=1).tolist()
|
axis=1).tolist()
|
||||||
# print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids))
|
# print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids))
|
||||||
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
|
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
|
||||||
|
|
||||||
|
|
||||||
def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
|
def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
|
||||||
if self.keep_last:
|
if self.keep_last:
|
||||||
return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng)
|
return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h,
|
||||||
|
w, crop_box, rng)
|
||||||
else:
|
else:
|
||||||
return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng)
|
return self._get_frameid_bbox_default(fps, frame_timestamps, h, w,
|
||||||
|
crop_box, rng)
|
||||||
|
|
||||||
def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
|
def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
|
||||||
return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
|
return self.load_video_batch(
|
||||||
|
data_key, crop_box=crop_box, seed=seed, **kwargs)
|
||||||
|
|
||||||
def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
|
def load_video_pair(self,
|
||||||
return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
|
data_key,
|
||||||
|
data_key2,
|
||||||
|
crop_box=None,
|
||||||
|
seed=2024,
|
||||||
|
**kwargs):
|
||||||
|
return self.load_video_batch(
|
||||||
|
data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
|
||||||
|
|
||||||
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs):
|
def load_video_batch(self,
|
||||||
|
*data_key_batch,
|
||||||
|
crop_box=None,
|
||||||
|
seed=2024,
|
||||||
|
**kwargs):
|
||||||
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
|
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
|
||||||
# read video
|
# read video
|
||||||
import decord
|
import decord
|
||||||
@ -235,36 +253,53 @@ class VaceVideoProcessor(object):
|
|||||||
|
|
||||||
fps = readers[0].get_avg_fps()
|
fps = readers[0].get_avg_fps()
|
||||||
length = min([len(r) for r in readers])
|
length = min([len(r) for r in readers])
|
||||||
frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
|
frame_timestamps = [
|
||||||
|
readers[0].get_frame_timestamp(i) for i in range(length)
|
||||||
|
]
|
||||||
frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
|
frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
|
||||||
h, w = readers[0].next().shape[:2]
|
h, w = readers[0].next().shape[:2]
|
||||||
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng)
|
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(
|
||||||
|
fps, frame_timestamps, h, w, crop_box, rng)
|
||||||
|
|
||||||
# preprocess video
|
# preprocess video
|
||||||
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
|
videos = [
|
||||||
|
reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :]
|
||||||
|
for reader in readers
|
||||||
|
]
|
||||||
videos = [self._video_preprocess(video, oh, ow) for video in videos]
|
videos = [self._video_preprocess(video, oh, ow) for video in videos]
|
||||||
return *videos, frame_ids, (oh, ow), fps
|
return *videos, frame_ids, (oh, ow), fps
|
||||||
# return videos if len(videos) > 1 else videos[0]
|
# return videos if len(videos) > 1 else videos[0]
|
||||||
|
|
||||||
|
|
||||||
def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
|
def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size,
|
||||||
|
device):
|
||||||
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
|
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
|
||||||
if sub_src_video is None and sub_src_mask is None:
|
if sub_src_video is None and sub_src_mask is None:
|
||||||
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
|
src_video[i] = torch.zeros(
|
||||||
src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device)
|
(3, num_frames, image_size[0], image_size[1]), device=device)
|
||||||
|
src_mask[i] = torch.ones(
|
||||||
|
(1, num_frames, image_size[0], image_size[1]), device=device)
|
||||||
for i, ref_images in enumerate(src_ref_images):
|
for i, ref_images in enumerate(src_ref_images):
|
||||||
if ref_images is not None:
|
if ref_images is not None:
|
||||||
for j, ref_img in enumerate(ref_images):
|
for j, ref_img in enumerate(ref_images):
|
||||||
if ref_img is not None and ref_img.shape[-2:] != image_size:
|
if ref_img is not None and ref_img.shape[-2:] != image_size:
|
||||||
canvas_height, canvas_width = image_size
|
canvas_height, canvas_width = image_size
|
||||||
ref_height, ref_width = ref_img.shape[-2:]
|
ref_height, ref_width = ref_img.shape[-2:]
|
||||||
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
|
white_canvas = torch.ones(
|
||||||
scale = min(canvas_height / ref_height, canvas_width / ref_width)
|
(3, 1, canvas_height, canvas_width),
|
||||||
|
device=device) # [-1, 1]
|
||||||
|
scale = min(canvas_height / ref_height,
|
||||||
|
canvas_width / ref_width)
|
||||||
new_height = int(ref_height * scale)
|
new_height = int(ref_height * scale)
|
||||||
new_width = int(ref_width * scale)
|
new_width = int(ref_width * scale)
|
||||||
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
|
resized_image = F.interpolate(
|
||||||
|
ref_img.squeeze(1).unsqueeze(0),
|
||||||
|
size=(new_height, new_width),
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=False).squeeze(0).unsqueeze(1)
|
||||||
top = (canvas_height - new_height) // 2
|
top = (canvas_height - new_height) // 2
|
||||||
left = (canvas_width - new_width) // 2
|
left = (canvas_width - new_width) // 2
|
||||||
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
|
white_canvas[:, :, top:top + new_height,
|
||||||
|
left:left + new_width] = resized_image
|
||||||
src_ref_images[i][j] = white_canvas
|
src_ref_images[i][j] = white_canvas
|
||||||
return src_video, src_mask, src_ref_images
|
return src_video, src_mask, src_ref_images
|
||||||
|
273
wan/vace.py
273
wan/vace.py
@ -1,32 +1,41 @@
|
|||||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import gc
|
import gc
|
||||||
import math
|
|
||||||
import time
|
|
||||||
import random
|
|
||||||
import types
|
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
import types
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
import torchvision.transforms.functional as TF
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.cuda.amp as amp
|
import torch.cuda.amp as amp
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchvision.transforms.functional as TF
|
||||||
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler,
|
|
||||||
get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler)
|
|
||||||
from .modules.vace_model import VaceWanModel
|
from .modules.vace_model import VaceWanModel
|
||||||
|
from .text2video import (
|
||||||
|
FlowDPMSolverMultistepScheduler,
|
||||||
|
FlowUniPCMultistepScheduler,
|
||||||
|
T5EncoderModel,
|
||||||
|
WanT2V,
|
||||||
|
WanVAE,
|
||||||
|
get_sampling_sigmas,
|
||||||
|
retrieve_timesteps,
|
||||||
|
shard_model,
|
||||||
|
)
|
||||||
from .utils.vace_processor import VaceVideoProcessor
|
from .utils.vace_processor import VaceVideoProcessor
|
||||||
|
|
||||||
|
|
||||||
class WanVace(WanT2V):
|
class WanVace(WanT2V):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
@ -87,12 +96,13 @@ class WanVace(WanT2V):
|
|||||||
self.model.eval().requires_grad_(False)
|
self.model.eval().requires_grad_(False)
|
||||||
|
|
||||||
if use_usp:
|
if use_usp:
|
||||||
from xfuser.core.distributed import \
|
from xfuser.core.distributed import get_sequence_parallel_world_size
|
||||||
get_sequence_parallel_world_size
|
|
||||||
|
|
||||||
from .distributed.xdit_context_parallel import (usp_attn_forward,
|
from .distributed.xdit_context_parallel import (
|
||||||
usp_dit_forward,
|
usp_attn_forward,
|
||||||
usp_dit_forward_vace)
|
usp_dit_forward,
|
||||||
|
usp_dit_forward_vace,
|
||||||
|
)
|
||||||
for block in self.model.blocks:
|
for block in self.model.blocks:
|
||||||
block.self_attn.forward = types.MethodType(
|
block.self_attn.forward = types.MethodType(
|
||||||
usp_attn_forward, block.self_attn)
|
usp_attn_forward, block.self_attn)
|
||||||
@ -100,7 +110,8 @@ class WanVace(WanT2V):
|
|||||||
block.self_attn.forward = types.MethodType(
|
block.self_attn.forward = types.MethodType(
|
||||||
usp_attn_forward, block.self_attn)
|
usp_attn_forward, block.self_attn)
|
||||||
self.model.forward = types.MethodType(usp_dit_forward, self.model)
|
self.model.forward = types.MethodType(usp_dit_forward, self.model)
|
||||||
self.model.forward_vace = types.MethodType(usp_dit_forward_vace, self.model)
|
self.model.forward_vace = types.MethodType(usp_dit_forward_vace,
|
||||||
|
self.model)
|
||||||
self.sp_size = get_sequence_parallel_world_size()
|
self.sp_size = get_sequence_parallel_world_size()
|
||||||
else:
|
else:
|
||||||
self.sp_size = 1
|
self.sp_size = 1
|
||||||
@ -114,14 +125,16 @@ class WanVace(WanT2V):
|
|||||||
|
|
||||||
self.sample_neg_prompt = config.sample_neg_prompt
|
self.sample_neg_prompt = config.sample_neg_prompt
|
||||||
|
|
||||||
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
|
self.vid_proc = VaceVideoProcessor(
|
||||||
min_area=720*1280,
|
downsample=tuple(
|
||||||
max_area=720*1280,
|
[x * y for x, y in zip(config.vae_stride, self.patch_size)]),
|
||||||
min_fps=config.sample_fps,
|
min_area=720 * 1280,
|
||||||
max_fps=config.sample_fps,
|
max_area=720 * 1280,
|
||||||
zero_start=True,
|
min_fps=config.sample_fps,
|
||||||
seq_len=75600,
|
max_fps=config.sample_fps,
|
||||||
keep_last=True)
|
zero_start=True,
|
||||||
|
seq_len=75600,
|
||||||
|
keep_last=True)
|
||||||
|
|
||||||
def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
|
def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
|
||||||
vae = self.vae if vae is None else vae
|
vae = self.vae if vae is None else vae
|
||||||
@ -138,7 +151,9 @@ class WanVace(WanT2V):
|
|||||||
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
|
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
|
||||||
inactive = vae.encode(inactive)
|
inactive = vae.encode(inactive)
|
||||||
reactive = vae.encode(reactive)
|
reactive = vae.encode(reactive)
|
||||||
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
|
latents = [
|
||||||
|
torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)
|
||||||
|
]
|
||||||
|
|
||||||
cat_latents = []
|
cat_latents = []
|
||||||
for latent, refs in zip(latents, ref_images):
|
for latent, refs in zip(latents, ref_images):
|
||||||
@ -147,7 +162,10 @@ class WanVace(WanT2V):
|
|||||||
ref_latent = vae.encode(refs)
|
ref_latent = vae.encode(refs)
|
||||||
else:
|
else:
|
||||||
ref_latent = vae.encode(refs)
|
ref_latent = vae.encode(refs)
|
||||||
ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
|
ref_latent = [
|
||||||
|
torch.cat((u, torch.zeros_like(u)), dim=0)
|
||||||
|
for u in ref_latent
|
||||||
|
]
|
||||||
assert all([x.shape[1] == 1 for x in ref_latent])
|
assert all([x.shape[1] == 1 for x in ref_latent])
|
||||||
latent = torch.cat([*ref_latent, latent], dim=1)
|
latent = torch.cat([*ref_latent, latent], dim=1)
|
||||||
cat_latents.append(latent)
|
cat_latents.append(latent)
|
||||||
@ -169,16 +187,17 @@ class WanVace(WanT2V):
|
|||||||
|
|
||||||
# reshape
|
# reshape
|
||||||
mask = mask[0, :, :, :]
|
mask = mask[0, :, :, :]
|
||||||
mask = mask.view(
|
mask = mask.view(depth, height, vae_stride[1], width,
|
||||||
depth, height, vae_stride[1], width, vae_stride[1]
|
vae_stride[1]) # depth, height, 8, width, 8
|
||||||
) # depth, height, 8, width, 8
|
|
||||||
mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
|
mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
|
||||||
mask = mask.reshape(
|
mask = mask.reshape(vae_stride[1] * vae_stride[2], depth, height,
|
||||||
vae_stride[1] * vae_stride[2], depth, height, width
|
width) # 8*8, depth, height, width
|
||||||
) # 8*8, depth, height, width
|
|
||||||
|
|
||||||
# interpolation
|
# interpolation
|
||||||
mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
|
mask = F.interpolate(
|
||||||
|
mask.unsqueeze(0),
|
||||||
|
size=(new_depth, height, width),
|
||||||
|
mode='nearest-exact').squeeze(0)
|
||||||
|
|
||||||
if refs is not None:
|
if refs is not None:
|
||||||
length = len(refs)
|
length = len(refs)
|
||||||
@ -190,27 +209,35 @@ class WanVace(WanT2V):
|
|||||||
def vace_latent(self, z, m):
|
def vace_latent(self, z, m):
|
||||||
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
|
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
|
||||||
|
|
||||||
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device):
|
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames,
|
||||||
|
image_size, device):
|
||||||
area = image_size[0] * image_size[1]
|
area = image_size[0] * image_size[1]
|
||||||
self.vid_proc.set_area(area)
|
self.vid_proc.set_area(area)
|
||||||
if area == 720*1280:
|
if area == 720 * 1280:
|
||||||
self.vid_proc.set_seq_len(75600)
|
self.vid_proc.set_seq_len(75600)
|
||||||
elif area == 480*832:
|
elif area == 480 * 832:
|
||||||
self.vid_proc.set_seq_len(32760)
|
self.vid_proc.set_seq_len(32760)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f'image_size {image_size} is not supported')
|
raise NotImplementedError(
|
||||||
|
f'image_size {image_size} is not supported')
|
||||||
|
|
||||||
image_size = (image_size[1], image_size[0])
|
image_size = (image_size[1], image_size[0])
|
||||||
image_sizes = []
|
image_sizes = []
|
||||||
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
|
for i, (sub_src_video,
|
||||||
|
sub_src_mask) in enumerate(zip(src_video, src_mask)):
|
||||||
if sub_src_mask is not None and sub_src_video is not None:
|
if sub_src_mask is not None and sub_src_video is not None:
|
||||||
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask)
|
src_video[i], src_mask[
|
||||||
|
i], _, _, _ = self.vid_proc.load_video_pair(
|
||||||
|
sub_src_video, sub_src_mask)
|
||||||
src_video[i] = src_video[i].to(device)
|
src_video[i] = src_video[i].to(device)
|
||||||
src_mask[i] = src_mask[i].to(device)
|
src_mask[i] = src_mask[i].to(device)
|
||||||
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
|
src_mask[i] = torch.clamp(
|
||||||
|
(src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
|
||||||
image_sizes.append(src_video[i].shape[2:])
|
image_sizes.append(src_video[i].shape[2:])
|
||||||
elif sub_src_video is None:
|
elif sub_src_video is None:
|
||||||
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
|
src_video[i] = torch.zeros(
|
||||||
|
(3, num_frames, image_size[0], image_size[1]),
|
||||||
|
device=device)
|
||||||
src_mask[i] = torch.ones_like(src_video[i], device=device)
|
src_mask[i] = torch.ones_like(src_video[i], device=device)
|
||||||
image_sizes.append(image_size)
|
image_sizes.append(image_size)
|
||||||
else:
|
else:
|
||||||
@ -225,18 +252,27 @@ class WanVace(WanT2V):
|
|||||||
for j, ref_img in enumerate(ref_images):
|
for j, ref_img in enumerate(ref_images):
|
||||||
if ref_img is not None:
|
if ref_img is not None:
|
||||||
ref_img = Image.open(ref_img).convert("RGB")
|
ref_img = Image.open(ref_img).convert("RGB")
|
||||||
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
|
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(
|
||||||
|
0.5).unsqueeze(1)
|
||||||
if ref_img.shape[-2:] != image_size:
|
if ref_img.shape[-2:] != image_size:
|
||||||
canvas_height, canvas_width = image_size
|
canvas_height, canvas_width = image_size
|
||||||
ref_height, ref_width = ref_img.shape[-2:]
|
ref_height, ref_width = ref_img.shape[-2:]
|
||||||
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
|
white_canvas = torch.ones(
|
||||||
scale = min(canvas_height / ref_height, canvas_width / ref_width)
|
(3, 1, canvas_height, canvas_width),
|
||||||
|
device=device) # [-1, 1]
|
||||||
|
scale = min(canvas_height / ref_height,
|
||||||
|
canvas_width / ref_width)
|
||||||
new_height = int(ref_height * scale)
|
new_height = int(ref_height * scale)
|
||||||
new_width = int(ref_width * scale)
|
new_width = int(ref_width * scale)
|
||||||
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
|
resized_image = F.interpolate(
|
||||||
|
ref_img.squeeze(1).unsqueeze(0),
|
||||||
|
size=(new_height, new_width),
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=False).squeeze(0).unsqueeze(1)
|
||||||
top = (canvas_height - new_height) // 2
|
top = (canvas_height - new_height) // 2
|
||||||
left = (canvas_width - new_width) // 2
|
left = (canvas_width - new_width) // 2
|
||||||
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
|
white_canvas[:, :, top:top + new_height,
|
||||||
|
left:left + new_width] = resized_image
|
||||||
ref_img = white_canvas
|
ref_img = white_canvas
|
||||||
src_ref_images[i][j] = ref_img.to(device)
|
src_ref_images[i][j] = ref_img.to(device)
|
||||||
return src_video, src_mask, src_ref_images
|
return src_video, src_mask, src_ref_images
|
||||||
@ -256,8 +292,6 @@ class WanVace(WanT2V):
|
|||||||
|
|
||||||
return vae.decode(trimed_zs)
|
return vae.decode(trimed_zs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def generate(self,
|
def generate(self,
|
||||||
input_prompt,
|
input_prompt,
|
||||||
input_frames,
|
input_frames,
|
||||||
@ -335,7 +369,8 @@ class WanVace(WanT2V):
|
|||||||
context_null = [t.to(self.device) for t in context_null]
|
context_null = [t.to(self.device) for t in context_null]
|
||||||
|
|
||||||
# vace context encode
|
# vace context encode
|
||||||
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks)
|
z0 = self.vace_encode_frames(
|
||||||
|
input_frames, input_ref_images, masks=input_masks)
|
||||||
m0 = self.vace_encode_masks(input_masks, input_ref_images)
|
m0 = self.vace_encode_masks(input_masks, input_ref_images)
|
||||||
z = self.vace_latent(z0, m0)
|
z = self.vace_latent(z0, m0)
|
||||||
|
|
||||||
@ -399,9 +434,17 @@ class WanVace(WanT2V):
|
|||||||
|
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
noise_pred_cond = self.model(
|
noise_pred_cond = self.model(
|
||||||
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[0]
|
latent_model_input,
|
||||||
|
t=timestep,
|
||||||
|
vace_context=z,
|
||||||
|
vace_context_scale=context_scale,
|
||||||
|
**arg_c)[0]
|
||||||
noise_pred_uncond = self.model(
|
noise_pred_uncond = self.model(
|
||||||
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,**arg_null)[0]
|
latent_model_input,
|
||||||
|
t=timestep,
|
||||||
|
vace_context=z,
|
||||||
|
vace_context_scale=context_scale,
|
||||||
|
**arg_null)[0]
|
||||||
|
|
||||||
noise_pred = noise_pred_uncond + guide_scale * (
|
noise_pred = noise_pred_uncond + guide_scale * (
|
||||||
noise_pred_cond - noise_pred_uncond)
|
noise_pred_cond - noise_pred_uncond)
|
||||||
@ -433,14 +476,13 @@ class WanVace(WanT2V):
|
|||||||
|
|
||||||
|
|
||||||
class WanVaceMP(WanVace):
|
class WanVaceMP(WanVace):
|
||||||
def __init__(
|
|
||||||
self,
|
def __init__(self,
|
||||||
config,
|
config,
|
||||||
checkpoint_dir,
|
checkpoint_dir,
|
||||||
use_usp=False,
|
use_usp=False,
|
||||||
ulysses_size=None,
|
ulysses_size=None,
|
||||||
ring_size=None
|
ring_size=None):
|
||||||
):
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.checkpoint_dir = checkpoint_dir
|
self.checkpoint_dir = checkpoint_dir
|
||||||
self.use_usp = use_usp
|
self.use_usp = use_usp
|
||||||
@ -457,7 +499,8 @@ class WanVaceMP(WanVace):
|
|||||||
|
|
||||||
self.device = 'cpu' if torch.cuda.is_available() else 'cpu'
|
self.device = 'cpu' if torch.cuda.is_available() else 'cpu'
|
||||||
self.vid_proc = VaceVideoProcessor(
|
self.vid_proc = VaceVideoProcessor(
|
||||||
downsample=tuple([x * y for x, y in zip(config.vae_stride, config.patch_size)]),
|
downsample=tuple(
|
||||||
|
[x * y for x, y in zip(config.vae_stride, config.patch_size)]),
|
||||||
min_area=480 * 832,
|
min_area=480 * 832,
|
||||||
max_area=480 * 832,
|
max_area=480 * 832,
|
||||||
min_fps=self.config.sample_fps,
|
min_fps=self.config.sample_fps,
|
||||||
@ -466,20 +509,30 @@ class WanVaceMP(WanVace):
|
|||||||
seq_len=32760,
|
seq_len=32760,
|
||||||
keep_last=True)
|
keep_last=True)
|
||||||
|
|
||||||
|
|
||||||
def dynamic_load(self):
|
def dynamic_load(self):
|
||||||
if hasattr(self, 'inference_pids') and self.inference_pids is not None:
|
if hasattr(self, 'inference_pids') and self.inference_pids is not None:
|
||||||
return
|
return
|
||||||
gpu_infer = os.environ.get('LOCAL_WORLD_SIZE') or torch.cuda.device_count()
|
gpu_infer = os.environ.get(
|
||||||
|
'LOCAL_WORLD_SIZE') or torch.cuda.device_count()
|
||||||
pmi_rank = int(os.environ['RANK'])
|
pmi_rank = int(os.environ['RANK'])
|
||||||
pmi_world_size = int(os.environ['WORLD_SIZE'])
|
pmi_world_size = int(os.environ['WORLD_SIZE'])
|
||||||
in_q_list = [torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)]
|
in_q_list = [
|
||||||
|
torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)
|
||||||
|
]
|
||||||
out_q = torch.multiprocessing.Manager().Queue()
|
out_q = torch.multiprocessing.Manager().Queue()
|
||||||
initialized_events = [torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)]
|
initialized_events = [
|
||||||
context = mp.spawn(self.mp_worker, nprocs=gpu_infer, args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, self), join=False)
|
torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)
|
||||||
|
]
|
||||||
|
context = mp.spawn(
|
||||||
|
self.mp_worker,
|
||||||
|
nprocs=gpu_infer,
|
||||||
|
args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q,
|
||||||
|
initialized_events, self),
|
||||||
|
join=False)
|
||||||
all_initialized = False
|
all_initialized = False
|
||||||
while not all_initialized:
|
while not all_initialized:
|
||||||
all_initialized = all(event.is_set() for event in initialized_events)
|
all_initialized = all(
|
||||||
|
event.is_set() for event in initialized_events)
|
||||||
if not all_initialized:
|
if not all_initialized:
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
print('Inference model is initialized', flush=True)
|
print('Inference model is initialized', flush=True)
|
||||||
@ -495,12 +548,19 @@ class WanVaceMP(WanVace):
|
|||||||
if isinstance(data, torch.Tensor):
|
if isinstance(data, torch.Tensor):
|
||||||
data = data.to(device)
|
data = data.to(device)
|
||||||
elif isinstance(data, list):
|
elif isinstance(data, list):
|
||||||
data = [self.transfer_data_to_cuda(subdata, device) for subdata in data]
|
data = [
|
||||||
|
self.transfer_data_to_cuda(subdata, device)
|
||||||
|
for subdata in data
|
||||||
|
]
|
||||||
elif isinstance(data, dict):
|
elif isinstance(data, dict):
|
||||||
data = {key: self.transfer_data_to_cuda(val, device) for key, val in data.items()}
|
data = {
|
||||||
|
key: self.transfer_data_to_cuda(val, device)
|
||||||
|
for key, val in data.items()
|
||||||
|
}
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, work_env):
|
def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list,
|
||||||
|
out_q, initialized_events, work_env):
|
||||||
try:
|
try:
|
||||||
world_size = pmi_world_size * gpu_infer
|
world_size = pmi_world_size * gpu_infer
|
||||||
rank = pmi_rank * gpu_infer + gpu
|
rank = pmi_rank * gpu_infer + gpu
|
||||||
@ -511,19 +571,19 @@ class WanVaceMP(WanVace):
|
|||||||
backend='nccl',
|
backend='nccl',
|
||||||
init_method='env://',
|
init_method='env://',
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size
|
world_size=world_size)
|
||||||
)
|
|
||||||
|
|
||||||
from xfuser.core.distributed import (initialize_model_parallel,
|
from xfuser.core.distributed import (
|
||||||
init_distributed_environment)
|
init_distributed_environment,
|
||||||
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
init_distributed_environment(
|
init_distributed_environment(
|
||||||
rank=dist.get_rank(), world_size=dist.get_world_size())
|
rank=dist.get_rank(), world_size=dist.get_world_size())
|
||||||
|
|
||||||
initialize_model_parallel(
|
initialize_model_parallel(
|
||||||
sequence_parallel_degree=dist.get_world_size(),
|
sequence_parallel_degree=dist.get_world_size(),
|
||||||
ring_degree=self.ring_size or 1,
|
ring_degree=self.ring_size or 1,
|
||||||
ulysses_degree=self.ulysses_size or 1
|
ulysses_degree=self.ulysses_size or 1)
|
||||||
)
|
|
||||||
|
|
||||||
num_train_timesteps = self.config.num_train_timesteps
|
num_train_timesteps = self.config.num_train_timesteps
|
||||||
param_dtype = self.config.param_dtype
|
param_dtype = self.config.param_dtype
|
||||||
@ -532,14 +592,17 @@ class WanVaceMP(WanVace):
|
|||||||
text_len=self.config.text_len,
|
text_len=self.config.text_len,
|
||||||
dtype=self.config.t5_dtype,
|
dtype=self.config.t5_dtype,
|
||||||
device=torch.device('cpu'),
|
device=torch.device('cpu'),
|
||||||
checkpoint_path=os.path.join(self.checkpoint_dir, self.config.t5_checkpoint),
|
checkpoint_path=os.path.join(self.checkpoint_dir,
|
||||||
tokenizer_path=os.path.join(self.checkpoint_dir, self.config.t5_tokenizer),
|
self.config.t5_checkpoint),
|
||||||
|
tokenizer_path=os.path.join(self.checkpoint_dir,
|
||||||
|
self.config.t5_tokenizer),
|
||||||
shard_fn=shard_fn if True else None)
|
shard_fn=shard_fn if True else None)
|
||||||
text_encoder.model.to(gpu)
|
text_encoder.model.to(gpu)
|
||||||
vae_stride = self.config.vae_stride
|
vae_stride = self.config.vae_stride
|
||||||
patch_size = self.config.patch_size
|
patch_size = self.config.patch_size
|
||||||
vae = WanVAE(
|
vae = WanVAE(
|
||||||
vae_pth=os.path.join(self.checkpoint_dir, self.config.vae_checkpoint),
|
vae_pth=os.path.join(self.checkpoint_dir,
|
||||||
|
self.config.vae_checkpoint),
|
||||||
device=gpu)
|
device=gpu)
|
||||||
logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}")
|
logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}")
|
||||||
model = VaceWanModel.from_pretrained(self.checkpoint_dir)
|
model = VaceWanModel.from_pretrained(self.checkpoint_dir)
|
||||||
@ -547,9 +610,12 @@ class WanVaceMP(WanVace):
|
|||||||
|
|
||||||
if self.use_usp:
|
if self.use_usp:
|
||||||
from xfuser.core.distributed import get_sequence_parallel_world_size
|
from xfuser.core.distributed import get_sequence_parallel_world_size
|
||||||
from .distributed.xdit_context_parallel import (usp_attn_forward,
|
|
||||||
usp_dit_forward,
|
from .distributed.xdit_context_parallel import (
|
||||||
usp_dit_forward_vace)
|
usp_attn_forward,
|
||||||
|
usp_dit_forward,
|
||||||
|
usp_dit_forward_vace,
|
||||||
|
)
|
||||||
for block in model.blocks:
|
for block in model.blocks:
|
||||||
block.self_attn.forward = types.MethodType(
|
block.self_attn.forward = types.MethodType(
|
||||||
usp_attn_forward, block.self_attn)
|
usp_attn_forward, block.self_attn)
|
||||||
@ -557,7 +623,8 @@ class WanVaceMP(WanVace):
|
|||||||
block.self_attn.forward = types.MethodType(
|
block.self_attn.forward = types.MethodType(
|
||||||
usp_attn_forward, block.self_attn)
|
usp_attn_forward, block.self_attn)
|
||||||
model.forward = types.MethodType(usp_dit_forward, model)
|
model.forward = types.MethodType(usp_dit_forward, model)
|
||||||
model.forward_vace = types.MethodType(usp_dit_forward_vace, model)
|
model.forward_vace = types.MethodType(usp_dit_forward_vace,
|
||||||
|
model)
|
||||||
sp_size = get_sequence_parallel_world_size()
|
sp_size = get_sequence_parallel_world_size()
|
||||||
else:
|
else:
|
||||||
sp_size = 1
|
sp_size = 1
|
||||||
@ -577,7 +644,8 @@ class WanVaceMP(WanVace):
|
|||||||
shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model = item
|
shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model = item
|
||||||
input_frames = self.transfer_data_to_cuda(input_frames, gpu)
|
input_frames = self.transfer_data_to_cuda(input_frames, gpu)
|
||||||
input_masks = self.transfer_data_to_cuda(input_masks, gpu)
|
input_masks = self.transfer_data_to_cuda(input_masks, gpu)
|
||||||
input_ref_images = self.transfer_data_to_cuda(input_ref_images, gpu)
|
input_ref_images = self.transfer_data_to_cuda(
|
||||||
|
input_ref_images, gpu)
|
||||||
|
|
||||||
if n_prompt == "":
|
if n_prompt == "":
|
||||||
n_prompt = sample_neg_prompt
|
n_prompt = sample_neg_prompt
|
||||||
@ -589,8 +657,10 @@ class WanVaceMP(WanVace):
|
|||||||
context_null = text_encoder([n_prompt], gpu)
|
context_null = text_encoder([n_prompt], gpu)
|
||||||
|
|
||||||
# vace context encode
|
# vace context encode
|
||||||
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, vae=vae)
|
z0 = self.vace_encode_frames(
|
||||||
m0 = self.vace_encode_masks(input_masks, input_ref_images, vae_stride=vae_stride)
|
input_frames, input_ref_images, masks=input_masks, vae=vae)
|
||||||
|
m0 = self.vace_encode_masks(
|
||||||
|
input_masks, input_ref_images, vae_stride=vae_stride)
|
||||||
z = self.vace_latent(z0, m0)
|
z = self.vace_latent(z0, m0)
|
||||||
|
|
||||||
target_shape = list(z0[0].shape)
|
target_shape = list(z0[0].shape)
|
||||||
@ -616,7 +686,8 @@ class WanVaceMP(WanVace):
|
|||||||
no_sync = getattr(model, 'no_sync', noop_no_sync)
|
no_sync = getattr(model, 'no_sync', noop_no_sync)
|
||||||
|
|
||||||
# evaluation mode
|
# evaluation mode
|
||||||
with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync():
|
with amp.autocast(
|
||||||
|
dtype=param_dtype), torch.no_grad(), no_sync():
|
||||||
|
|
||||||
if sample_solver == 'unipc':
|
if sample_solver == 'unipc':
|
||||||
sample_scheduler = FlowUniPCMultistepScheduler(
|
sample_scheduler = FlowUniPCMultistepScheduler(
|
||||||
@ -631,7 +702,8 @@ class WanVaceMP(WanVace):
|
|||||||
num_train_timesteps=num_train_timesteps,
|
num_train_timesteps=num_train_timesteps,
|
||||||
shift=1,
|
shift=1,
|
||||||
use_dynamic_shifting=False)
|
use_dynamic_shifting=False)
|
||||||
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
|
sampling_sigmas = get_sampling_sigmas(
|
||||||
|
sampling_steps, shift)
|
||||||
timesteps, _ = retrieve_timesteps(
|
timesteps, _ = retrieve_timesteps(
|
||||||
sample_scheduler,
|
sample_scheduler,
|
||||||
device=gpu,
|
device=gpu,
|
||||||
@ -653,14 +725,20 @@ class WanVaceMP(WanVace):
|
|||||||
|
|
||||||
model.to(gpu)
|
model.to(gpu)
|
||||||
noise_pred_cond = model(
|
noise_pred_cond = model(
|
||||||
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[
|
latent_model_input,
|
||||||
0]
|
t=timestep,
|
||||||
|
vace_context=z,
|
||||||
|
vace_context_scale=context_scale,
|
||||||
|
**arg_c)[0]
|
||||||
noise_pred_uncond = model(
|
noise_pred_uncond = model(
|
||||||
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,
|
latent_model_input,
|
||||||
|
t=timestep,
|
||||||
|
vace_context=z,
|
||||||
|
vace_context_scale=context_scale,
|
||||||
**arg_null)[0]
|
**arg_null)[0]
|
||||||
|
|
||||||
noise_pred = noise_pred_uncond + guide_scale * (
|
noise_pred = noise_pred_uncond + guide_scale * (
|
||||||
noise_pred_cond - noise_pred_uncond)
|
noise_pred_cond - noise_pred_uncond)
|
||||||
|
|
||||||
temp_x0 = sample_scheduler.step(
|
temp_x0 = sample_scheduler.step(
|
||||||
noise_pred.unsqueeze(0),
|
noise_pred.unsqueeze(0),
|
||||||
@ -673,7 +751,8 @@ class WanVaceMP(WanVace):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
x0 = latents
|
x0 = latents
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
videos = self.decode_latent(x0, input_ref_images, vae=vae)
|
videos = self.decode_latent(
|
||||||
|
x0, input_ref_images, vae=vae)
|
||||||
|
|
||||||
del noise, latents
|
del noise, latents
|
||||||
del sample_scheduler
|
del sample_scheduler
|
||||||
@ -691,8 +770,6 @@ class WanVaceMP(WanVace):
|
|||||||
print(trace_info, flush=True)
|
print(trace_info, flush=True)
|
||||||
print(e, flush=True)
|
print(e, flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def generate(self,
|
def generate(self,
|
||||||
input_prompt,
|
input_prompt,
|
||||||
input_frames,
|
input_frames,
|
||||||
@ -709,8 +786,10 @@ class WanVaceMP(WanVace):
|
|||||||
seed=-1,
|
seed=-1,
|
||||||
offload_model=True):
|
offload_model=True):
|
||||||
|
|
||||||
input_data = (input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale,
|
input_data = (input_prompt, input_frames, input_masks, input_ref_images,
|
||||||
shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model)
|
size, frame_num, context_scale, shift, sample_solver,
|
||||||
|
sampling_steps, guide_scale, n_prompt, seed,
|
||||||
|
offload_model)
|
||||||
for in_q in self.in_q_list:
|
for in_q in self.in_q_list:
|
||||||
in_q.put(input_data)
|
in_q.put(input_data)
|
||||||
value_output = self.out_q.get()
|
value_output = self.out_q.get()
|
||||||
|
Loading…
Reference in New Issue
Block a user