Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-06-17 21:07:41 +00:00

Compare commits: 5 commits, 55fdfe155d ... 1c12e4d958

Commits: 1c12e4d958, e5a741309d, 76e9427657, c6c5675a06, 0961b7b888
.style.yapf (new file, 393 lines)
@@ -0,0 +1,393 @@

[style]
# Align closing bracket with visual indentation.
align_closing_bracket_with_visual_indent=False

# Allow dictionary keys to exist on multiple lines. For example:
#
#   x = {
#       ('this is the first element of a tuple',
#        'this is the second element of a tuple'):
#            value,
#   }
allow_multiline_dictionary_keys=False

# Allow lambdas to be formatted on more than one line.
allow_multiline_lambdas=False

# Allow splitting before a default / named assignment in an argument list.
allow_split_before_default_or_named_assigns=False

# Allow splits before the dictionary value.
allow_split_before_dict_value=True

# Let spacing indicate operator precedence. For example:
#
#   a = 1 * 2 + 3 / 4
#   b = 1 / 2 - 3 * 4
#   c = (1 + 2) * (3 - 4)
#   d = (1 - 2) / (3 + 4)
#   e = 1 * 2 - 3
#   f = 1 + 2 + 3 + 4
#
# will be formatted as follows to indicate precedence:
#
#   a = 1*2 + 3/4
#   b = 1/2 - 3*4
#   c = (1+2) * (3-4)
#   d = (1-2) / (3+4)
#   e = 1*2 - 3
#   f = 1 + 2 + 3 + 4
#
arithmetic_precedence_indication=False

# Number of blank lines surrounding top-level function and class
# definitions.
blank_lines_around_top_level_definition=2

# Insert a blank line before a class-level docstring.
blank_line_before_class_docstring=False

# Insert a blank line before a module docstring.
blank_line_before_module_docstring=False

# Insert a blank line before a 'def' or 'class' immediately nested
# within another 'def' or 'class'. For example:
#
#   class Foo:
#                      # <------ this blank line
#     def method():
#       ...
blank_line_before_nested_class_or_def=True

# Do not split consecutive brackets. Only relevant when
# dedent_closing_brackets is set. For example:
#
#   call_func_that_takes_a_dict(
#       {
#           'key1': 'value1',
#           'key2': 'value2',
#       }
#   )
#
# would reformat to:
#
#   call_func_that_takes_a_dict({
#       'key1': 'value1',
#       'key2': 'value2',
#   })
coalesce_brackets=False

# The column limit.
column_limit=80

# The style for continuation alignment. Possible values are:
#
# - SPACE: Use spaces for continuation alignment. This is default behavior.
# - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns
#   (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs or
#   CONTINUATION_INDENT_WIDTH spaces) for continuation alignment.
# - VALIGN-RIGHT: Vertically align continuation lines to multiple of
#   INDENT_WIDTH columns. Slightly right (one tab or a few spaces) if
#   cannot vertically align continuation lines with indent characters.
continuation_align_style=SPACE

# Indent width used for line continuations.
continuation_indent_width=4

# Put closing brackets on a separate line, dedented, if the bracketed
# expression can't fit in a single line. Applies to all kinds of brackets,
# including function definitions and calls. For example:
#
#   config = {
#       'key1': 'value1',
#       'key2': 'value2',
#   }        # <--- this bracket is dedented and on a separate line
#
#   time_series = self.remote_client.query_entity_counters(
#       entity='dev3246.region1',
#       key='dns.query_latency_tcp',
#       transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
#       start_ts=now()-timedelta(days=3),
#       end_ts=now(),
#   )        # <--- this bracket is dedented and on a separate line
dedent_closing_brackets=False

# Disable the heuristic which places each list element on a separate line
# if the list is comma-terminated.
disable_ending_comma_heuristic=False

# Place each dictionary entry onto its own line.
each_dict_entry_on_separate_line=True

# Require multiline dictionary even if it would normally fit on one line.
# For example:
#
#   config = {
#       'key1': 'value1'
#   }
force_multiline_dict=False

# The regex for an i18n comment. The presence of this comment stops
# reformatting of that line, because the comments are required to be
# next to the string they translate.
i18n_comment=#\..*

# The i18n function call names. The presence of this function stops
# reformatting on that line, because the string it has cannot be moved
# away from the i18n comment.
i18n_function_call=N_, _

# Indent blank lines.
indent_blank_lines=False

# Put closing brackets on a separate line, indented, if the bracketed
# expression can't fit in a single line. Applies to all kinds of brackets,
# including function definitions and calls. For example:
#
#   config = {
#       'key1': 'value1',
#       'key2': 'value2',
#       }        # <--- this bracket is indented and on a separate line
#
#   time_series = self.remote_client.query_entity_counters(
#       entity='dev3246.region1',
#       key='dns.query_latency_tcp',
#       transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
#       start_ts=now()-timedelta(days=3),
#       end_ts=now(),
#       )        # <--- this bracket is indented and on a separate line
indent_closing_brackets=False

# Indent the dictionary value if it cannot fit on the same line as the
# dictionary key. For example:
#
#   config = {
#       'key1':
#           'value1',
#       'key2': value1 +
#               value2,
#   }
indent_dictionary_value=True

# The number of columns to use for indentation.
indent_width=4

# Join short lines into one line. E.g., single line 'if' statements.
join_multiple_lines=False

# Do not include spaces around selected binary operators. For example:
#
#   1 + 2 * 3 - 4 / 5
#
# will be formatted as follows when configured with "*,/":
#
#   1 + 2*3 - 4/5
no_spaces_around_selected_binary_operators=

# Use spaces around default or named assigns.
spaces_around_default_or_named_assign=False

# Adds a space after the opening '{' and before the ending '}' dict delimiters.
#
#   {1: 2}
#
# will be formatted as:
#
#   { 1: 2 }
spaces_around_dict_delimiters=False

# Adds a space after the opening '[' and before the ending ']' list delimiters.
#
#   [1, 2]
#
# will be formatted as:
#
#   [ 1, 2 ]
spaces_around_list_delimiters=False

# Use spaces around the power operator.
spaces_around_power_operator=False

# Use spaces around the subscript / slice operator. For example:
#
#   my_list[1 : 10 : 2]
spaces_around_subscript_colon=False

# Adds a space after the opening '(' and before the ending ')' tuple delimiters.
#
#   (1, 2, 3)
#
# will be formatted as:
#
#   ( 1, 2, 3 )
spaces_around_tuple_delimiters=False

# The number of spaces required before a trailing comment.
# This can be a single value (representing the number of spaces
# before each trailing comment) or list of values (representing
# alignment column values; trailing comments within a block will
# be aligned to the first column value that is greater than the maximum
# line length within the block). For example:
#
# With spaces_before_comment=5:
#
#   1 + 1 # Adding values
#
# will be formatted as:
#
#   1 + 1     # Adding values <-- 5 spaces between the end of the statement and comment
#
# With spaces_before_comment=15, 20:
#
#   1 + 1 # Adding values
#   two + two # More adding
#
#   longer_statement # This is a longer statement
#   short # This is a shorter statement
#
#   a_very_long_statement_that_extends_beyond_the_final_column # Comment
#   short # This is a shorter statement
#
# will be formatted as:
#
#   1 + 1          # Adding values <-- end of line comments in block aligned to col 15
#   two + two      # More adding
#
#   longer_statement    # This is a longer statement <-- end of line comments in block aligned to col 20
#   short               # This is a shorter statement
#
#   a_very_long_statement_that_extends_beyond_the_final_column  # Comment <-- the end of line comments are aligned based on the line length
#   short                                                       # This is a shorter statement
#
spaces_before_comment=2

# Insert a space between the ending comma and closing bracket of a list,
# etc.
space_between_ending_comma_and_closing_bracket=False

# Use spaces inside brackets, braces, and parentheses. For example:
#
#   method_call( 1 )
#   my_dict[ 3 ][ 1 ][ get_index( *args, **kwargs ) ]
#   my_set = { 1, 2, 3 }
space_inside_brackets=False

# Split before arguments
split_all_comma_separated_values=False

# Split before arguments, but do not split all subexpressions recursively
# (unless needed).
split_all_top_level_comma_separated_values=False

# Split before arguments if the argument list is terminated by a
# comma.
split_arguments_when_comma_terminated=False

# Set to True to prefer splitting before '+', '-', '*', '/', '//', or '@'
# rather than after.
split_before_arithmetic_operator=False

# Set to True to prefer splitting before '&', '|' or '^' rather than
# after.
split_before_bitwise_operator=False

# Split before the closing bracket if a list or dict literal doesn't fit on
# a single line.
split_before_closing_bracket=True

# Split before a dictionary or set generator (comp_for). For example, note
# the split before the 'for':
#
#   foo = {
#       variable: 'Hello world, have a nice day!'
#       for variable in bar if variable != 42
#   }
split_before_dict_set_generator=False

# Split before the '.' if we need to split a longer expression:
#
#   foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d))
#
# would reformat to something like:
#
#   foo = ('This is a really long string: {}, {}, {}, {}'
#          .format(a, b, c, d))
split_before_dot=False

# Split after the opening paren which surrounds an expression if it doesn't
# fit on a single line.
split_before_expression_after_opening_paren=True

# If an argument / parameter list is going to be split, then split before
# the first argument.
split_before_first_argument=False

# Set to True to prefer splitting before 'and' or 'or' rather than
# after.
split_before_logical_operator=False

# Split named assignments onto individual lines.
split_before_named_assigns=True

# Set to True to split list comprehensions and generators that have
# non-trivial expressions and multiple clauses before each of these
# clauses. For example:
#
#   result = [
#       a_long_var + 100 for a_long_var in xrange(1000)
#       if a_long_var % 10]
#
# would reformat to something like:
#
#   result = [
#       a_long_var + 100
#       for a_long_var in xrange(1000)
#       if a_long_var % 10]
split_complex_comprehension=True

# The penalty for splitting right after the opening bracket.
split_penalty_after_opening_bracket=300

# The penalty for splitting the line after a unary operator.
split_penalty_after_unary_operator=10000

# The penalty of splitting the line around the '+', '-', '*', '/', '//',
# '%', and '@' operators.
split_penalty_arithmetic_operator=300

# The penalty for splitting right before an if expression.
split_penalty_before_if_expr=0

# The penalty of splitting the line around the '&', '|', and '^'
# operators.
split_penalty_bitwise_operator=300

# The penalty for splitting a list comprehension or generator
# expression.
split_penalty_comprehension=2100

# The penalty for characters over the column limit.
split_penalty_excess_character=7000

# The penalty incurred by adding a line split to the unwrapped line. The
# more line splits added the higher the penalty.
split_penalty_for_added_line_split=30

# The penalty of splitting a list of "import as" names. For example:
#
#   from a_very_long_or_indented_module_name_yada_yad import (long_argument_1,
#                                                             long_argument_2,
#                                                             long_argument_3)
#
# would reformat to something like:
#
#   from a_very_long_or_indented_module_name_yada_yad import (
#       long_argument_1, long_argument_2, long_argument_3)
split_penalty_import_names=0

# The penalty of splitting the line around the 'and' and 'or'
# operators.
split_penalty_logical_operator=300

# Use the Tab character for indentation.
use_tabs=False
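To see the style in action (an illustration, not part of the config file): with `column_limit=80` and `split_before_named_assigns=True`, yapf splits an over-long call one named argument per line, which is the same reflow that appears throughout the generate.py and gradio diffs below. `some_function` is a stand-in name.

```python
# Before: a call that exceeds column_limit=80.
result = some_function(argument_one=1, argument_two=2, argument_three=3, argument_four=4)

# After yapf with the style above: named arguments split onto their own
# lines (split_before_named_assigns=True), continuations indented by
# continuation_indent_width=4, and the closing bracket kept on the last
# line (dedent_closing_brackets=False).
result = some_function(
    argument_one=1,
    argument_two=2,
    argument_three=3,
    argument_four=4)
```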
I2V-FastAPI文档.md (new file, 151 lines)
@@ -0,0 +1,151 @@

# Image-to-Video Generation Service API Documentation

## 1. Overview

Image-to-video generation based on the Wan2.1-I2V-14B-480P model. Core features:

1. **Asynchronous task queue**: multi-task queuing with concurrency control (at most 2 parallel tasks)
2. **Smart resolution adaptation**:
   - Automatically computes the best resolution (preserving the input image's aspect ratio)
   - Also accepts a manually specified resolution (480x832 / 832x480)
3. **Resource management**:
   - VRAM optimization (bfloat16 precision)
   - Automatic cleanup of generated files (after 1 hour by default)
4. **Authentication**: Bearer token validation against an API key
5. **Task control**: task submission / status query / cancellation

Tech stack:

- FastAPI framework
- CUDA acceleration
- Asynchronous task processing
- Diffusers inference library

---

## 2. Endpoints

### 1. Submit a generation task

**POST /video/submit**

```json
{
  "model": "Wan2.1-I2V-14B-480P",
  "prompt": "A dancing cat in the style of Van Gogh",
  "image_url": "https://example.com/input.jpg",
  "image_size": "auto",
  "num_frames": 81,
  "guidance_scale": 3.0,
  "infer_steps": 30
}
```

**Example response**:

```json
{
  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```

### 2. Query task status

**POST /video/status**

```json
{
  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```

**Example response**:

```json
{
  "status": "Succeed",
  "results": {
    "videos": [{"url": "http://localhost:8088/videos/abcd1234.mp4"}],
    "timings": {"inference": 90},
    "seed": 123456
  }
}
```

### 3. Cancel a task

**POST /video/cancel**

```json
{
  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```

**Example response**:

```json
{
  "status": "Succeed"
}
```

---

## 3. Postman Guide

### 1. Basic configuration

- Server address: `http://<server-ip>:8088`
- Authentication: Bearer Token
- Token value: replace with a valid API key

### 2. Submitting a task

1. Choose the POST method and set the URL to `/video/submit`
2. Add the headers:

```text
Authorization: Bearer YOUR_API_KEY
Content-Type: application/json
```

3. Example body (image-to-video):

```json
{
  "prompt": "Sunset scene with mountains",
  "image_url": "https://example.com/mountain.jpg",
  "image_size": "auto",
  "num_frames": 50
}
```

### 3. Error cases

- **Image download failure**: returns a 400 error with the specific cause (e.g. invalid URL, timeout)
- **Out of VRAM**: returns a 500 error suggesting a lower resolution

---

## 4. Parameters

| Parameter      | Allowed values               | Required | Description                                     |
|----------------|------------------------------|----------|-------------------------------------------------|
| image_url      | valid HTTP/HTTPS URL         | yes      | input image address                             |
| prompt         | 10-500 characters            | yes      | description of the video content                |
| image_size     | "480x832", "832x480", "auto" | yes      | "auto" adapts to the input image's aspect ratio |
| num_frames     | 24-120                       | yes      | total number of video frames                    |
| guidance_scale | 1.0-20.0                     | yes      | text guidance strength                          |
| infer_steps    | 20-100                       | yes      | number of inference steps                       |
| seed           | 0-2147483647                 | no       | random seed                                     |
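For reference, a minimal sketch of how this table could map onto a Pydantic model. The service's own imports include `BaseModel`, `Field`, and `field_validator` (see i2v_api.py below), but the model name, defaults, and validator here are hypothetical:

```python
from typing import Optional

from pydantic import BaseModel, Field, field_validator


# Hypothetical request model mirroring the parameter table above.
class VideoSubmitRequest(BaseModel):
    image_url: str  # must be a reachable HTTP/HTTPS URL
    prompt: str = Field(min_length=10, max_length=500)
    image_size: str = "auto"  # "480x832", "832x480", or "auto"
    num_frames: int = Field(81, ge=24, le=120)
    guidance_scale: float = Field(3.0, ge=1.0, le=20.0)
    infer_steps: int = Field(30, ge=20, le=100)
    seed: Optional[int] = Field(None, ge=0, le=2147483647)

    @field_validator("image_size")
    @classmethod
    def check_size(cls, v: str) -> str:
        if v not in {"480x832", "832x480", "auto"}:
            raise ValueError("image_size must be 480x832, 832x480, or auto")
        return v
```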
---

## 5. Status Codes

| Code | Meaning                                          |
|------|--------------------------------------------------|
| 202  | task accepted                                    |
| 400  | image download failed / invalid parameters       |
| 401  | authentication failed                            |
| 404  | task not found                                   |
| 422  | parameter validation failed                      |
| 500  | server error (out of VRAM, model failure, etc.)  |

---

## 6. Special Features

1. **Smart resolution adaptation**:
   - With `image_size="auto"`, the service automatically computes the best resolution that satisfies the model's constraints
   - The input image's aspect ratio is preserved and the pixel area is kept at or below 399,360 (about 640x624); see the sketch after this section
2. **Image preprocessing**:
   - Automatic conversion to RGB mode
   - Proportional rescaling to the target resolution

**Important**: the input image URL must be publicly accessible; private resources must come with valid credentials.

**Tip**: visit `http://<server-address>:8088/docs` for interactive API documentation where every endpoint can be tested online.
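The auto mode above comes down to a few lines of arithmetic. A minimal sketch, assuming only the documented constraints (preserve the aspect ratio, keep width * height at or below 399,360); snapping dimensions to multiples of 16 is an assumption for illustration and is not stated in this document:

```python
import math

MAX_AREA = 399_360  # documented cap on width * height


def auto_resolution(src_w: int, src_h: int, multiple: int = 16) -> tuple[int, int]:
    """Scale (src_w, src_h) to fit under MAX_AREA, keeping the aspect ratio.

    Rounding to a multiple of 16 is an assumption for illustration; the
    service may use a different granularity.
    """
    aspect = src_h / src_w
    # Largest width whose area stays under the cap at this aspect ratio.
    width = math.sqrt(MAX_AREA / aspect)
    height = width * aspect
    # Snap both sides down to the chosen multiple.
    width = max(multiple, int(width) // multiple * multiple)
    height = max(multiple, int(height) // multiple * multiple)
    return width, height


# Example: a 1920x1080 input maps to 832x464 (16:9 roughly preserved,
# 832 * 464 = 386,048 <= 399,360).
print(auto_resolution(1920, 1080))
```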
Makefile (new file, 5 lines)
@@ -0,0 +1,5 @@

.PHONY: format

format:
	isort generate.py gradio wan
	yapf -i -r *.py generate.py gradio wan
README.md
@@ -643,7 +643,7 @@ If you find our work helpful, please cite us.

 ```
 @article{wan2025,
       title={Wan: Open and Advanced Large-Scale Video Generative Models},
-      author={Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
+      author={Team Wan and Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
       journal = {arXiv preprint arXiv:2503.20314},
       year={2025}
 }
T2V-FastAPI文档.md (new file, 133 lines)
@@ -0,0 +1,133 @@

# Video Generation Service API Documentation

## 1. Overview

Text-to-video generation based on the Wan2.1-T2V-1.3B model. Core features:

1. **Asynchronous task queue**: multi-task queuing with concurrency control (at most 2 parallel tasks)
2. **Resource management**:
   - VRAM optimization (bfloat16 precision)
   - Automatic cleanup of generated videos (deleted after 1 hour by default)
3. **Authentication**: Bearer token validation against an API key
4. **Task control**: task submission / status query / cancellation

Tech stack:

- FastAPI framework
- CUDA acceleration
- Asynchronous task processing
- Diffusers inference library

---

## 2. Endpoints

### 1. Submit a generation task

**POST /video/submit**

```json
{
  "model": "Wan2.1-T2V-1.3B",
  "prompt": "A beautiful sunset over the mountains",
  "image_size": "480x832",
  "num_frames": 81,
  "guidance_scale": 5.0,
  "infer_steps": 50
}
```

**Example response**:

```json
{
  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```

### 2. Query task status

**POST /video/status**

```json
{
  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```

**Example response**:

```json
{
  "status": "Succeed",
  "results": {
    "videos": [{"url": "http://localhost:8088/videos/abcd1234.mp4"}],
    "timings": {"inference": 120}
  }
}
```

### 3. Cancel a task

**POST /video/cancel**

```json
{
  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```

**Example response**:

```json
{
  "status": "Succeed"
}
```

---

## 3. Postman Guide

### 1. Basic configuration

- Server address: `http://<server-ip>:8088`
- Authentication: Bearer Token
- Token value: replace with a valid API key

### 2. Submitting a task

1. Choose the POST method and enter the URL `/video/submit`
2. Add the headers:

```text
Authorization: Bearer YOUR_API_KEY
Content-Type: application/json
```

3. Set the body to raw/JSON and enter the request parameters

### 3. Querying status

1. Create a new request with the URL `/video/status`
2. Use the same authentication headers
3. Include the requestId in the body

### 4. Cancelling a task

1. Create a new POST request with the URL `/video/cancel` (the endpoint above is declared as POST)
2. Include the requestId of the task to cancel in the body

### Notes

1. Every request must carry a valid API key
2. Video generation takes roughly 2-5 minutes, depending on the parameters
3. Generated videos are kept for 1 hour by default

---

## 4. Parameters

| Parameter      | Allowed values          | Required | Description                      |
|----------------|-------------------------|----------|----------------------------------|
| prompt         | 10-500 characters       | yes      | description of the video content |
| image_size     | "480x832" or "832x480"  | yes      | resolution                       |
| num_frames     | 24-120                  | yes      | total number of video frames     |
| guidance_scale | 1.0-20.0                | yes      | text guidance strength           |
| infer_steps    | 20-100                  | yes      | number of inference steps        |
| seed           | 0-2147483647            | no       | random seed                      |

---

## 5. Status Codes

| Code | Meaning                          |
|------|----------------------------------|
| 202  | task accepted                    |
| 401  | authentication failed            |
| 404  | task not found                   |
| 422  | parameter validation failed      |
| 500  | server error (out of VRAM, etc.) |

**Tip**: the Swagger UI at `http://<server-address>:8088/docs` is recommended for testing; it shows the auto-generated API documentation.
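A minimal Python client sketch of the submit-then-poll flow described above. The base URL and API key are placeholders, and failure status values are not specified in this document, so the sketch only handles the documented "Succeed" state:

```python
import time

import requests

BASE_URL = "http://<server-ip>:8088"  # placeholder, see Basic configuration
HEADERS = {
    "Authorization": "Bearer YOUR_API_KEY",
    "Content-Type": "application/json",
}

# 1. Submit a generation task.
payload = {
    "model": "Wan2.1-T2V-1.3B",
    "prompt": "A beautiful sunset over the mountains",
    "image_size": "480x832",
    "num_frames": 81,
    "guidance_scale": 5.0,
    "infer_steps": 50,
}
resp = requests.post(f"{BASE_URL}/video/submit", json=payload, headers=HEADERS)
resp.raise_for_status()
request_id = resp.json()["requestId"]

# 2. Poll the status endpoint until the task reports "Succeed".
#    Generation takes minutes, so poll sparingly.
while True:
    status = requests.post(
        f"{BASE_URL}/video/status",
        json={"requestId": request_id},
        headers=HEADERS).json()
    if status["status"] == "Succeed":
        print("video url:", status["results"]["videos"][0]["url"])
        break
    time.sleep(10)
```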
generate.py (67 lines changed)
@@ -1,28 +1,33 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
-from datetime import datetime
 import logging
 import os
 import sys
 import warnings
+from datetime import datetime
 
 warnings.filterwarnings('ignore')
 
-import torch, random
+import random
+
+import torch
 import torch.distributed as dist
 from PIL import Image
 
 import wan
-from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
+from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
-from wan.utils.utils import cache_video, cache_image, str2bool
+from wan.utils.utils import cache_image, cache_video, str2bool
 
 EXAMPLE_PROMPT = {
     "t2v-1.3B": {
-        "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
+        "prompt":
+            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
     },
     "t2v-14B": {
-        "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
+        "prompt":
+            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
     },
     "t2i-14B": {
         "prompt": "一个朴素端庄的美人",
@@ -42,12 +47,16 @@ EXAMPLE_PROMPT = {
         "examples/flf2v_input_last_frame.png",
     },
     "vace-1.3B": {
-        "src_ref_images": 'examples/girl.png,examples/snake.png',
-        "prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
+        "src_ref_images":
+            'examples/girl.png,examples/snake.png',
+        "prompt":
+            "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
     },
     "vace-14B": {
-        "src_ref_images": 'examples/girl.png,examples/snake.png',
-        "prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
+        "src_ref_images":
+            'examples/girl.png,examples/snake.png',
+        "prompt":
+            "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
     }
 }
 
@@ -64,7 +73,6 @@ def _validate_args(args):
     if "i2v" in args.task:
         args.sample_steps = 40
-
 
     if args.sample_shift is None:
         args.sample_shift = 5.0
         if "i2v" in args.task and args.size in ["832*480", "480*832"]:
@@ -72,7 +80,6 @@ def _validate_args(args):
     elif "flf2v" in args.task or "vace" in args.task:
         args.sample_shift = 16
-
 
     # The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
     if args.frame_num is None:
         args.frame_num = 1 if "t2i" in args.task else 81
@@ -167,7 +174,8 @@ def _parse_args():
         "--src_ref_images",
         type=str,
         default=None,
-        help="The file list of the source reference images. Separated by ','. Default None.")
+        help="The file list of the source reference images. Separated by ','. Default None."
+    )
     parser.add_argument(
         "--prompt",
         type=str,
@@ -209,12 +217,14 @@ def _parse_args():
         "--first_frame",
         type=str,
         default=None,
-        help="[first-last frame to video] The image (first frame) to generate the video from.")
+        help="[first-last frame to video] The image (first frame) to generate the video from."
+    )
     parser.add_argument(
         "--last_frame",
         type=str,
         default=None,
-        help="[first-last frame to video] The image (last frame) to generate the video from.")
+        help="[first-last frame to video] The image (last frame) to generate the video from."
+    )
     parser.add_argument(
         "--sample_solver",
         type=str,
@@ -281,8 +291,10 @@ def generate(args):
 
     if args.ulysses_size > 1 or args.ring_size > 1:
         assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size."
-        from xfuser.core.distributed import (initialize_model_parallel,
-                                             init_distributed_environment)
+        from xfuser.core.distributed import (
+            init_distributed_environment,
+            initialize_model_parallel,
+        )
         init_distributed_environment(
             rank=dist.get_rank(), world_size=dist.get_world_size())
 
@@ -295,7 +307,8 @@ def generate(args):
     if args.use_prompt_extend:
         if args.prompt_extend_method == "dashscope":
             prompt_expander = DashScopePromptExpander(
-                model_name=args.prompt_extend_model, is_vl="i2v" in args.task or "flf2v" in args.task)
+                model_name=args.prompt_extend_model,
+                is_vl="i2v" in args.task or "flf2v" in args.task)
         elif args.prompt_extend_method == "local_qwen":
             prompt_expander = QwenPromptExpander(
                 model_name=args.prompt_extend_model,
@@ -482,21 +495,22 @@ def generate(args):
                 sampling_steps=args.sample_steps,
                 guide_scale=args.sample_guide_scale,
                 seed=args.base_seed,
-                offload_model=args.offload_model
-            )
+                offload_model=args.offload_model)
     elif "vace" in args.task:
         if args.prompt is None:
             args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
             args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None)
             args.src_mask = EXAMPLE_PROMPT[args.task].get("src_mask", None)
-            args.src_ref_images = EXAMPLE_PROMPT[args.task].get("src_ref_images", None)
+            args.src_ref_images = EXAMPLE_PROMPT[args.task].get(
+                "src_ref_images", None)
 
         logging.info(f"Input prompt: {args.prompt}")
         if args.use_prompt_extend and args.use_prompt_extend != 'plain':
             logging.info("Extending prompt ...")
             if rank == 0:
                 prompt = prompt_expander.forward(args.prompt)
-                logging.info(f"Prompt extended from '{args.prompt}' to '{prompt}'")
+                logging.info(
+                    f"Prompt extended from '{args.prompt}' to '{prompt}'")
                 input_prompt = [prompt]
             else:
                 input_prompt = [None]
@@ -517,10 +531,11 @@ def generate(args):
             t5_cpu=args.t5_cpu,
         )
 
-        src_video, src_mask, src_ref_images = wan_vace.prepare_source([args.src_video],
-                                                                      [args.src_mask],
-                                                                      [None if args.src_ref_images is None else args.src_ref_images.split(',')],
-                                                                      args.frame_num, SIZE_CONFIGS[args.size], device)
+        src_video, src_mask, src_ref_images = wan_vace.prepare_source(
+            [args.src_video], [args.src_mask], [
+                None if args.src_ref_images is None else
+                args.src_ref_images.split(',')
+            ], args.frame_num, SIZE_CONFIGS[args.size], device)
 
         logging.info(f"Generating video...")
         video = wan_vace.generate(
gradio/fl2v_14B_singleGPU.py
@@ -1,8 +1,8 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
 import gc
-import os.path as osp
 import os
+import os.path as osp
 import sys
 import warnings
 
@@ -11,7 +11,8 @@ import gradio as gr
 warnings.filterwarnings('ignore')
 
 # Model
-sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
+sys.path.insert(
+    0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
 import wan
 from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
@@ -69,13 +70,13 @@ def prompt_enc(prompt, img_first, img_last, tar_lang):
     return prompt_output.prompt
 
 
-def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps,
-                     guide_scale, shift_scale, seed, n_prompt):
+def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
+                     resolution, sd_steps, guide_scale, shift_scale, seed,
+                     n_prompt):
 
     if resolution == '------':
         print(
-            'Please specify the resolution ckpt dir or specify the resolution'
-        )
+            'Please specify the resolution ckpt dir or specify the resolution')
         return None
 
     else:
@@ -94,9 +95,7 @@ def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, re
             offload_model=True)
         pass
     else:
-        print(
-            'Sorry, currently only 720P is supported.'
-        )
+        print('Sorry, currently only 720P is supported.')
         return None
 
     cache_video(
@@ -191,14 +190,17 @@ def gradio_interface():
 
     run_p_button.click(
         fn=prompt_enc,
-        inputs=[flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, tar_lang],
+        inputs=[
+            flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
+            tar_lang
+        ],
         outputs=[flf2vid_prompt])
 
     run_flf2v_button.click(
         fn=flf2v_generation,
         inputs=[
-            flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps,
-            guide_scale, shift_scale, seed, n_prompt
+            flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
+            resolution, sd_steps, guide_scale, shift_scale, seed, n_prompt
         ],
         outputs=[result_gallery],
     )
gradio/i2v_14B_singleGPU.py
@@ -1,8 +1,8 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
 import gc
-import os.path as osp
 import os
+import os.path as osp
 import sys
 import warnings
 
@@ -11,7 +11,8 @@ import gradio as gr
 warnings.filterwarnings('ignore')
 
 # Model
-sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
+sys.path.insert(
+    0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
 import wan
 from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
gradio/t2i_14B_singleGPU.py
@@ -1,7 +1,7 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
-import os.path as osp
 import os
+import os.path as osp
 import sys
 import warnings
 
@@ -10,7 +10,8 @@ import gradio as gr
 warnings.filterwarnings('ignore')
 
 # Model
-sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
+sys.path.insert(
+    0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
 import wan
 from wan.configs import WAN_CONFIGS
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
gradio/t2v_1.3B_singleGPU.py
@@ -1,7 +1,7 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
-import os.path as osp
 import os
+import os.path as osp
 import sys
 import warnings
 
@@ -10,7 +10,8 @@ import gradio as gr
 warnings.filterwarnings('ignore')
 
 # Model
-sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
+sys.path.insert(
+    0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
 import wan
 from wan.configs import WAN_CONFIGS
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
gradio/t2v_14B_singleGPU.py
@@ -1,7 +1,7 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
-import os.path as osp
 import os
+import os.path as osp
 import sys
 import warnings
 
@@ -10,7 +10,8 @@ import gradio as gr
 warnings.filterwarnings('ignore')
 
 # Model
-sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
+sys.path.insert(
+    0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
 import wan
 from wan.configs import WAN_CONFIGS
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
gradio/vace.py (140 lines changed)
@@ -2,36 +2,48 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import argparse
+import datetime
 import os
 import sys
-import datetime
+
 import imageio
 import numpy as np
 import torch
 
 import gradio as gr
 
-sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
+sys.path.insert(
+    0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
 import wan
 from wan import WanVace, WanVaceMP
-from wan.configs import WAN_CONFIGS, SIZE_CONFIGS
+from wan.configs import SIZE_CONFIGS, WAN_CONFIGS
 
 
 class FixedSizeQueue:
+
     def __init__(self, max_size):
         self.max_size = max_size
         self.queue = []
 
     def add(self, item):
         self.queue.insert(0, item)
         if len(self.queue) > self.max_size:
             self.queue.pop()
 
     def get(self):
         return self.queue
 
     def __repr__(self):
         return str(self.queue)
 
 
 class VACEInference:
-    def __init__(self, cfg, skip_load=False, gallery_share=True, gallery_share_limit=5):
+
+    def __init__(self,
+                 cfg,
+                 skip_load=False,
+                 gallery_share=True,
+                 gallery_share_limit=5):
         self.cfg = cfg
         self.save_dir = cfg.save_dir
         self.gallery_share = gallery_share
@@ -53,9 +65,7 @@ class VACEInference:
                 checkpoint_dir=cfg.ckpt_dir,
                 use_usp=True,
                 ulysses_size=cfg.ulysses_size,
-                ring_size=cfg.ring_size
-            )
+                ring_size=cfg.ring_size)
-
 
     def create_ui(self, *args, **kwargs):
         gr.Markdown("""
@@ -80,7 +90,8 @@ class VACEInference:
         with gr.Row(variant='panel', equal_height=True):
             with gr.Column(scale=1, min_width=0):
                 with gr.Row(equal_height=True):
-                    self.src_ref_image_1 = gr.Image(label='src_ref_image_1',
+                    self.src_ref_image_1 = gr.Image(
+                        label='src_ref_image_1',
                         height=200,
                         interactive=True,
                         type='filepath',
@@ -88,7 +99,8 @@ class VACEInference:
                         sources=['upload'],
                         elem_id="src_ref_image_1",
                         format='png')
-                    self.src_ref_image_2 = gr.Image(label='src_ref_image_2',
+                    self.src_ref_image_2 = gr.Image(
+                        label='src_ref_image_2',
                         height=200,
                         interactive=True,
                         type='filepath',
@@ -96,7 +108,8 @@ class VACEInference:
                         sources=['upload'],
                         elem_id="src_ref_image_2",
                         format='png')
-                    self.src_ref_image_3 = gr.Image(label='src_ref_image_3',
+                    self.src_ref_image_3 = gr.Image(
+                        label='src_ref_image_3',
                         height=200,
                         interactive=True,
                         type='filepath',
@@ -158,10 +171,8 @@ class VACEInference:
                         step=0.5,
                         value=5.0,
                         interactive=True)
-                    self.infer_seed = gr.Slider(minimum=-1,
-                                                maximum=10000000,
-                                                value=2025,
-                                                label="Seed")
+                    self.infer_seed = gr.Slider(
+                        minimum=-1, maximum=10000000, value=2025, label="Seed")
                 #
                 with gr.Accordion(label="Usable without source video", open=False):
                     with gr.Row(equal_height=True):
@@ -176,13 +187,9 @@ class VACEInference:
                             value=1280,
                             interactive=True)
                         self.frame_rate = gr.Textbox(
-                            label='frame_rate',
-                            value=16,
-                            interactive=True)
+                            label='frame_rate', value=16, interactive=True)
                         self.num_frames = gr.Textbox(
-                            label='num_frames',
-                            value=81,
-                            interactive=True)
+                            label='num_frames', value=81, interactive=True)
                 #
                 with gr.Row(equal_height=True):
                     with gr.Column(scale=5):
@@ -201,14 +208,19 @@ class VACEInference:
                         allow_preview=True,
                         preview=True)
 
-    def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, src_ref_image_2, src_ref_image_3, prompt, negative_prompt, shift_scale, sample_steps, context_scale, guide_scale, infer_seed, output_height, output_width, frame_rate, num_frames):
-        output_height, output_width, frame_rate, num_frames = int(output_height), int(output_width), int(frame_rate), int(num_frames)
-        src_ref_images = [x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3] if
-                          x is not None]
-        src_video, src_mask, src_ref_images = self.pipe.prepare_source([src_video],
-                                                                       [src_mask],
-                                                                       [src_ref_images],
+    def generate(self, output_gallery, src_video, src_mask, src_ref_image_1,
+                 src_ref_image_2, src_ref_image_3, prompt, negative_prompt,
+                 shift_scale, sample_steps, context_scale, guide_scale,
+                 infer_seed, output_height, output_width, frame_rate,
+                 num_frames):
+        output_height, output_width, frame_rate, num_frames = int(
+            output_height), int(output_width), int(frame_rate), int(num_frames)
+        src_ref_images = [
+            x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3]
+            if x is not None
+        ]
+        src_video, src_mask, src_ref_images = self.pipe.prepare_source(
+            [src_video], [src_mask], [src_ref_images],
             num_frames=num_frames,
             image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"],
             device=self.pipe.device)
@@ -228,10 +240,17 @@ class VACEInference:
 
         name = '{0:%Y%m%d%-H%M%S}'.format(datetime.datetime.now())
         video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
-        video_frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)
+        video_frames = (
+            torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) *
+            255).cpu().numpy().astype(np.uint8)
 
         try:
-            writer = imageio.get_writer(video_path, fps=frame_rate, codec='libx264', quality=8, macro_block_size=1)
+            writer = imageio.get_writer(
+                video_path,
+                fps=frame_rate,
+                codec='libx264',
+                quality=8,
+                macro_block_size=1)
             for frame in video_frames:
                 writer.append_data(frame)
             writer.close()
@@ -246,25 +265,57 @@ class VACEInference:
             return [video_path]
 
     def set_callbacks(self, **kwargs):
-        self.gen_inputs = [self.output_gallery, self.src_video, self.src_mask, self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3, self.prompt, self.negative_prompt, self.shift_scale, self.sample_steps, self.context_scale, self.guide_scale, self.infer_seed, self.output_height, self.output_width, self.frame_rate, self.num_frames]
+        self.gen_inputs = [
+            self.output_gallery, self.src_video, self.src_mask,
+            self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3,
+            self.prompt, self.negative_prompt, self.shift_scale,
+            self.sample_steps, self.context_scale, self.guide_scale,
+            self.infer_seed, self.output_height, self.output_width,
+            self.frame_rate, self.num_frames
+        ]
         self.gen_outputs = [self.output_gallery]
-        self.generate_button.click(self.generate,
+        self.generate_button.click(
+            self.generate,
             inputs=self.gen_inputs,
             outputs=self.gen_outputs,
             queue=True)
-        self.refresh_button.click(lambda x: self.gallery_share_data.get() if self.gallery_share else x, inputs=[self.output_gallery], outputs=[self.output_gallery])
+        self.refresh_button.click(
+            lambda x: self.gallery_share_data.get()
+            if self.gallery_share else x,
+            inputs=[self.output_gallery],
+            outputs=[self.output_gallery])
 
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Argparser for VACE-WAN Demo:\n')
-    parser.add_argument('--server_port', dest='server_port', help='', type=int, default=7860)
-    parser.add_argument('--server_name', dest='server_name', help='', default='0.0.0.0')
+    parser = argparse.ArgumentParser(
+        description='Argparser for VACE-WAN Demo:\n')
+    parser.add_argument(
+        '--server_port', dest='server_port', help='', type=int, default=7860)
+    parser.add_argument(
+        '--server_name', dest='server_name', help='', default='0.0.0.0')
     parser.add_argument('--root_path', dest='root_path', help='', default=None)
     parser.add_argument('--save_dir', dest='save_dir', help='', default='cache')
-    parser.add_argument("--mp", action="store_true", help="Use Multi-GPUs",)
-    parser.add_argument("--model_name", type=str, default="vace-14B", choices=list(WAN_CONFIGS.keys()), help="The model name to run.")
-    parser.add_argument("--ulysses_size", type=int, default=1, help="The size of the ulysses parallelism in DiT.")
-    parser.add_argument("--ring_size", type=int, default=1, help="The size of the ring attention parallelism in DiT.")
+    parser.add_argument(
+        "--mp",
+        action="store_true",
+        help="Use Multi-GPUs",
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="vace-14B",
+        choices=list(WAN_CONFIGS.keys()),
+        help="The model name to run.")
+    parser.add_argument(
+        "--ulysses_size",
+        type=int,
+        default=1,
+        help="The size of the ulysses parallelism in DiT.")
+    parser.add_argument(
+        "--ring_size",
+        type=int,
+        default=1,
+        help="The size of the ring attention parallelism in DiT.")
     parser.add_argument(
         "--ckpt_dir",
         type=str,
@@ -284,12 +335,15 @@ if __name__ == '__main__':
     os.makedirs(args.save_dir, exist_ok=True)
 
     with gr.Blocks() as demo:
-        infer_gr = VACEInference(args, skip_load=False, gallery_share=True, gallery_share_limit=5)
+        infer_gr = VACEInference(
+            args, skip_load=False, gallery_share=True, gallery_share_limit=5)
         infer_gr.create_ui()
         infer_gr.set_callbacks()
        allowed_paths = [args.save_dir]
-        demo.queue(status_update_rate=1).launch(server_name=args.server_name,
+        demo.queue(status_update_rate=1).launch(
+            server_name=args.server_name,
             server_port=args.server_port,
             root_path=args.root_path,
             allowed_paths=allowed_paths,
-            show_error=True, debug=True)
+            show_error=True,
+            debug=True)
i2v_api.py (new file, 526 lines)
@@ -0,0 +1,526 @@
import os
import torch
import uuid
import time
import asyncio
import numpy as np
from threading import Lock
from typing import Optional, Dict, List
from fastapi import FastAPI, HTTPException, status, Depends
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field, field_validator, ValidationError
from diffusers.utils import export_to_video, load_image
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from transformers import CLIPVisionModel
from PIL import Image
import requests
from io import BytesIO
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from contextlib import asynccontextmanager
from requests.exceptions import (RequestException, Timeout, ConnectionError,
                                 HTTPError)

# Create the storage directories
os.makedirs("generated_videos", exist_ok=True)
os.makedirs("temp_images", exist_ok=True)


# ======================
# Lifecycle management
# ======================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Resource manager"""
    try:
        # Initialize the authentication system
        app.state.valid_api_keys = {
            "<api-key>"  # placeholder for the real key
        }

        # Initialize the model
        model_id = "./Wan2.1-I2V-14B-480P-Diffusers"

        # Load the image encoder
        image_encoder = CLIPVisionModel.from_pretrained(
            model_id,
            subfolder="image_encoder",
            torch_dtype=torch.float32
        )

        # Load the VAE
        vae = AutoencoderKLWan.from_pretrained(
            model_id,
            subfolder="vae",
            torch_dtype=torch.float32
        )

        # Configure the scheduler
        scheduler = UniPCMultistepScheduler(
            prediction_type='flow_prediction',
            use_flow_sigmas=True,
            num_train_timesteps=1000,
            flow_shift=3.0
        )

        # Create the pipeline
        app.state.pipe = WanImageToVideoPipeline.from_pretrained(
            model_id,
            vae=vae,
            image_encoder=image_encoder,
            torch_dtype=torch.bfloat16
        ).to("cuda")
        app.state.pipe.scheduler = scheduler

        # Initialize the task system
        app.state.tasks: Dict[str, dict] = {}
        app.state.pending_queue: List[str] = []
        app.state.model_lock = Lock()
        app.state.task_lock = Lock()
        app.state.base_url = "<ip-address:port>"  # placeholder
        app.state.semaphore = asyncio.Semaphore(2)  # concurrency limit

        # Start the background processor
        asyncio.create_task(task_processor())

        print("✅ System initialized")
        yield

    finally:
        # Release resources
        if hasattr(app.state, 'pipe'):
            del app.state.pipe
        torch.cuda.empty_cache()
        print("♻️ Resources released")


# ======================
# FastAPI application
# ======================
app = FastAPI(lifespan=lifespan)
app.mount("/videos", StaticFiles(directory="generated_videos"), name="videos")
# Authentication module
security = HTTPBearer(auto_error=False)


# ======================
# Data models -- request parameter models
# ======================
class VideoSubmitRequest(BaseModel):
    model: str = Field(
        default="Wan2.1-I2V-14B-480P",
        description="Model version"
    )
    prompt: str = Field(
        ...,
        min_length=10,
        max_length=500,
        description="Video description prompt, 10-500 characters"
    )
    image_url: str = Field(
        ...,
        description="Input image URL; HTTP/HTTPS must be supported"
    )
    image_size: str = Field(
        default="auto",
        description="Output resolution, formatted as WIDTHxHEIGHT, or auto (computed automatically)"
    )
    negative_prompt: Optional[str] = Field(
        default=None,
        max_length=500,
        description="Content to exclude"
    )
    seed: Optional[int] = Field(
        default=None,
        ge=0,
        le=2147483647,
        description="Random seed in the range 0-2147483647"
    )
    num_frames: int = Field(
        default=81,
        ge=24,
        le=120,
        description="Number of video frames, 24-120"
    )
    guidance_scale: float = Field(
        default=3.0,
        ge=1.0,
        le=20.0,
        description="Guidance scale, 1.0-20.0"
    )
    infer_steps: int = Field(
        default=30,
        ge=20,
        le=100,
        description="Number of inference steps, 20-100"
    )

    @field_validator('image_size')
    def validate_image_size(cls, v):
        allowed_sizes = {"480x832", "832x480", "auto"}
        if v not in allowed_sizes:
            raise ValueError(f"Supported resolutions: {', '.join(allowed_sizes)}")
        return v


class VideoStatusRequest(BaseModel):
    requestId: str = Field(
        ...,
        min_length=32,
        max_length=32,
        description="32-character task ID"
    )


class VideoStatusResponse(BaseModel):
    status: str = Field(..., description="Task status: Succeed, InQueue, InProgress, Failed, Cancelled")
    reason: Optional[str] = Field(None, description="Failure reason")
    results: Optional[dict] = Field(None, description="Generation results")
    queue_position: Optional[int] = Field(None, description="Position in the queue")


class VideoCancelRequest(BaseModel):
    requestId: str = Field(
        ...,
        min_length=32,
        max_length=32,
        description="32-character task ID"
    )


# ======================
# Core logic
# ======================
async def verify_auth(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Unified authentication check"""
    if not credentials:
        raise HTTPException(
            status_code=401,
            detail={"status": "Failed", "reason": "Missing authorization header"},
            headers={"WWW-Authenticate": "Bearer"}
        )
    if credentials.scheme != "Bearer":
        raise HTTPException(
            status_code=401,
            detail={"status": "Failed", "reason": "Invalid authentication scheme"},
            headers={"WWW-Authenticate": "Bearer"}
        )
    if credentials.credentials not in app.state.valid_api_keys:
        raise HTTPException(
            status_code=401,
            detail={"status": "Failed", "reason": "Invalid API key"},
            headers={"WWW-Authenticate": "Bearer"}
        )
    return True


async def task_processor():
    """Task processor"""
    while True:
        async with app.state.semaphore:
            task_id = await get_next_task()
            if task_id:
                await process_task(task_id)
            else:
                await asyncio.sleep(0.5)


async def get_next_task():
    """Fetch the next task"""
    with app.state.task_lock:
        return app.state.pending_queue.pop(0) if app.state.pending_queue else None


async def process_task(task_id: str):
    """Process a single task"""
    task = app.state.tasks.get(task_id)
    if not task:
        return

    try:
        # Update the task status
        task['status'] = 'InProgress'
        task['started_at'] = int(time.time())
        print(task['request'].image_url)
        # Download the input image
        image = await download_image(task['request'].image_url)
        image_path = f"temp_images/{task_id}.jpg"
        image.save(image_path)

        # Generate the video
        video_path = await generate_video(task['request'], task_id, image)

        # Build the download link
        download_url = f"{app.state.base_url}/videos/{os.path.basename(video_path)}"

        # Update the task status
        task.update({
            'status': 'Succeed',
            'download_url': download_url,
            'completed_at': int(time.time())
        })

        # Schedule cleanup
        asyncio.create_task(cleanup_files([image_path, video_path]))
    except Exception as e:
        handle_task_error(task, e)


def handle_task_error(task: dict, error: Exception):
    """Error handling (with detailed error messages)"""
    error_msg = str(error)

    # 1. Out-of-memory errors
    if isinstance(error, torch.cuda.OutOfMemoryError):
        error_msg = "Out of GPU memory; lower the resolution"

    # 2. Network-related errors
    elif isinstance(error, (RequestException, HTTPException)):
        # Extract specifics from the exception
        if isinstance(error, HTTPException):
            # For an HTTPException, read its detail field
            error_detail = getattr(error, "detail", "")
            error_msg = f"Image download failed: {error_detail}"

        elif isinstance(error, Timeout):
            error_msg = "Image download timed out; check the network"

        elif isinstance(error, ConnectionError):
            error_msg = "Could not connect to the server; check the URL"

        elif isinstance(error, HTTPError):
            # requests' HTTPError (e.g. a 4xx/5xx status code)
            status_code = error.response.status_code
            error_msg = f"Server returned error status code: {status_code}"

        else:
            # Other RequestException errors
            error_msg = f"Image download failed: {str(error)}"

    # 3. Other unknown errors
    else:
        error_msg = f"Unknown error: {str(error)}"

    # Update the task status
    task.update({
        'status': 'Failed',
        'reason': error_msg,
        'completed_at': int(time.time())
    })


# ======================
# Video generation logic
# ======================
async def download_image(url: str) -> Image.Image:
    """Download an image asynchronously (with detailed error messages)"""
    loop = asyncio.get_event_loop()
    try:
        response = await loop.run_in_executor(
            None,
            lambda: requests.get(url, timeout=30)  # pass a timeout to requests.get
        )

        # Raise an HTTPException explicitly on non-200 status codes
        if response.status_code != 200:
            raise HTTPException(
                status_code=response.status_code,
                detail=f"Server returned status code {response.status_code}"
            )

        return Image.open(BytesIO(response.content)).convert("RGB")

    except RequestException as e:
        # Re-raise with the original requests error message
        raise HTTPException(
            status_code=500,
            detail=f"Request failed: {str(e)}"
        )


async def generate_video(request: VideoSubmitRequest, task_id: str, image: Image.Image):
    """Asynchronous generation entry point"""
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(
        None,
        sync_generate_video,
        request,
        task_id,
        image
    )


def sync_generate_video(request: VideoSubmitRequest, task_id: str, image: Image.Image):
    """Synchronous generation core"""
    with app.state.model_lock:
        try:
            # Resolve the resolution
            mod_value = 16  # modulus required by the model
            print(request.image_size)
            print('--------------------------------')
            if request.image_size == "auto":
                # Original automatic sizing logic
                aspect_ratio = image.height / image.width
                print(image.height, image.width)
                max_area = 399360  # base resolution of the model

                # Compute the ideal size
                height = round(np.sqrt(max_area * aspect_ratio))
                width = round(np.sqrt(max_area / aspect_ratio))

                # Apply the modulus adjustment
                height = height // mod_value * mod_value
                width = width // mod_value * mod_value
                resized_image = image.resize((width, height))
            else:
                width_str, height_str = request.image_size.split('x')
                width = int(width_str)
                height = int(height_str)
                mod_value = 16
                # Resize the image
                resized_image = image.resize((width, height))

            # Set the random seed
            generator = None
            # Change 1: read the seed via attribute access
            if request.seed is not None:
                generator = torch.Generator(device="cuda")
                generator.manual_seed(request.seed)  # Change 2
                print(f"🔮 Using random seed: {request.seed}")
            print(resized_image)
            print(height, width)

            # Run inference
            output = app.state.pipe(
                image=resized_image,
                prompt=request.prompt,
                negative_prompt=request.negative_prompt,
                height=height,
                width=width,
                num_frames=request.num_frames,
                guidance_scale=request.guidance_scale,
                num_inference_steps=request.infer_steps,
                generator=generator
            ).frames[0]

            # Export the video
            video_id = uuid.uuid4().hex
            output_path = f"generated_videos/{video_id}.mp4"
            export_to_video(output, output_path, fps=16)
            return output_path
        except Exception as e:
            raise RuntimeError(f"Video generation failed: {str(e)}") from e


# ======================
# API endpoints
# ======================
@app.post("/video/submit",
          response_model=dict,
          status_code=status.HTTP_202_ACCEPTED,
          tags=["Video generation"])
async def submit_task(
    request: VideoSubmitRequest,
    auth: bool = Depends(verify_auth)
):
    """Submit a generation task"""
    # Parameter validation
    if request.image_url is None:
        raise HTTPException(
            status_code=422,
            detail={"status": "Failed", "reason": "An image URL is required"}
        )

    # Create the task record
    task_id = uuid.uuid4().hex
    with app.state.task_lock:
        app.state.tasks[task_id] = {
            "request": request,
            "status": "InQueue",
            "created_at": int(time.time())
        }
        app.state.pending_queue.append(task_id)

    return {"requestId": task_id}


@app.post("/video/status",
          response_model=VideoStatusResponse,
          tags=["Video generation"])
async def get_status(
    request: VideoStatusRequest,
    auth: bool = Depends(verify_auth)
):
    """Query task status"""
    task = app.state.tasks.get(request.requestId)
    if not task:
        raise HTTPException(
            status_code=404,
            detail={"status": "Failed", "reason": "Invalid task ID"}
        )

    # Compute the queue position (only while queued)
    queue_pos = 0
    if task['status'] == "InQueue" and request.requestId in app.state.pending_queue:
        queue_pos = app.state.pending_queue.index(request.requestId) + 1

    response = {
        "status": task['status'],
        "reason": task.get('reason'),
        "queue_position": queue_pos if task['status'] == "InQueue" else None  # null when not queued
    }

    # Special handling for the success state
    if task['status'] == "Succeed":
        response["results"] = {
            "videos": [{"url": task['download_url']}],
            "timings": {
                "inference": task['completed_at'] - task['started_at']
            },
            "seed": task['request'].seed
        }
    # Extra information for the cancelled state
    elif task['status'] == "Cancelled":
        response["reason"] = task.get('reason', "Cancelled by user")  # make sure the reason field exists

    return response


@app.post("/video/cancel",
          response_model=dict,
          tags=["Video generation"])
async def cancel_task(
    request: VideoCancelRequest,
    auth: bool = Depends(verify_auth)
):
    """Cancel a queued generation task"""
    task_id = request.requestId

    with app.state.task_lock:
        task = app.state.tasks.get(task_id)

        # Check that the task exists
        if not task:
            raise HTTPException(
                status_code=404,
                detail={"status": "Failed", "reason": "Invalid task ID"}
            )

        current_status = task['status']

        # Only queued tasks may be cancelled
        if current_status != "InQueue":
            raise HTTPException(
                status_code=400,
                detail={"status": "Failed", "reason": f"Only queued tasks can be cancelled; current status: {current_status}"}
            )

        # Remove the task from the queue
        try:
            app.state.pending_queue.remove(task_id)
        except ValueError:
            pass  # it may already have been picked up

        # Update the task status
        task.update({
            "status": "Cancelled",
            "reason": "Cancelled by user",
            "completed_at": int(time.time())
        })

    return {"status": "Succeed"}


async def cleanup_files(paths: List[str], delay: int = 3600):
    """Clean up files after a delay"""
    await asyncio.sleep(delay)
    for path in paths:
        try:
            if os.path.exists(path):
                os.remove(path)
        except Exception as e:
            print(f"Cleanup failed for {path}: {str(e)}")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8088)
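The "auto" branch in sync_generate_video above picks the largest 16-aligned size that preserves the input aspect ratio within a pixel budget of max_area = 399360. A minimal sketch of that arithmetic, assuming a hypothetical 1280x720 input:

import numpy as np

# Worked example of the "auto" sizing in i2v_api.py (illustration only;
# the 1280x720 input is a made-up value).
max_area, mod_value = 399360, 16
aspect_ratio = 720 / 1280                         # 0.5625
height = round(np.sqrt(max_area * aspect_ratio))  # 474
width = round(np.sqrt(max_area / aspect_ratio))   # 843
height = height // mod_value * mod_value          # 464
width = width // mod_value * mod_value            # 832
print(width, height)  # 832 464 -> the input is resized to 832x464

Note that 832 * 464 = 386048 stays just under the 399360 budget while keeping both sides multiples of 16.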
t2v-api.py (new file, 450 lines)
@@ -0,0 +1,450 @@
import os
import torch
import uuid
import time
import asyncio
from enum import Enum
from threading import Lock
from typing import Optional, Dict, List
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, status, Depends
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field, field_validator, ValidationError
from diffusers.utils import export_to_video
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import JSONResponse

# Create the video output directory
os.makedirs("generated_videos", exist_ok=True)


# Lifespan manager
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the application lifecycle"""
    # Initialize the model and resources
    try:
        # Initialize the authentication keys
        app.state.valid_api_keys = {
            "<api-key>"  # placeholder for the real key
        }

        # Initialize the video generation model
        model_id = "./Wan2.1-T2V-1.3B-Diffusers"
        vae = AutoencoderKLWan.from_pretrained(
            model_id,
            subfolder="vae",
            torch_dtype=torch.float32
        )

        scheduler = UniPCMultistepScheduler(
            prediction_type='flow_prediction',
            use_flow_sigmas=True,
            num_train_timesteps=1000,
            flow_shift=3.0
        )

        app.state.pipe = WanPipeline.from_pretrained(
            model_id,
            vae=vae,
            torch_dtype=torch.bfloat16
        ).to("cuda")
        app.state.pipe.scheduler = scheduler

        # Initialize the task system
        app.state.tasks: Dict[str, dict] = {}
        app.state.pending_queue: List[str] = []
        app.state.model_lock = Lock()
        app.state.task_lock = Lock()
        app.state.base_url = "<ip-address:port>"  # placeholder
        app.state.max_concurrent = 2
        app.state.semaphore = asyncio.Semaphore(app.state.max_concurrent)

        # Start the background task processor
        asyncio.create_task(task_processor())

        print("✅ Application initialized")
        yield

    finally:
        # Clean up resources
        if hasattr(app.state, 'pipe'):
            del app.state.pipe
        torch.cuda.empty_cache()
        print("♻️ Model resources released")


# Create the FastAPI application
app = FastAPI(lifespan=lifespan)
app.mount("/videos", StaticFiles(directory="generated_videos"), name="videos")

# Authentication module
security = HTTPBearer(auto_error=False)


# ======================
# Data models -- request parameter models
# ======================
class VideoSubmitRequest(BaseModel):
    model: str = Field(default="Wan2.1-T2V-1.3B", description="Model version to use")
    prompt: str = Field(
        ...,
        min_length=10,
        max_length=500,
        description="Video description prompt, 10-500 characters"
    )
    image_size: str = Field(
        ...,
        description="Video resolution; only 480x832 or 832x480 are supported"
    )
    negative_prompt: Optional[str] = Field(
        default=None,
        max_length=500,
        description="Content to exclude"
    )
    seed: Optional[int] = Field(
        default=None,
        ge=0,
        le=2147483647,
        description="Random seed in the range 0-2147483647"
    )
    num_frames: int = Field(
        default=81,
        ge=24,
        le=120,
        description="Number of video frames, 24-120"
    )
    guidance_scale: float = Field(
        default=5.0,
        ge=1.0,
        le=20.0,
        description="Guidance scale, 1.0-20.0"
    )
    infer_steps: int = Field(
        default=50,
        ge=20,
        le=100,
        description="Number of inference steps, 20-100"
    )

    @field_validator('image_size', mode='before')
    @classmethod
    def validate_image_size(cls, value):
        allowed_sizes = {"480x832", "832x480"}
        if value not in allowed_sizes:
            raise ValueError(f"Only these resolutions are supported: {', '.join(allowed_sizes)}")
        return value


class VideoStatusRequest(BaseModel):
    requestId: str = Field(
        ...,
        min_length=32,
        max_length=32,
        description="32-character task ID"
    )


class VideoSubmitResponse(BaseModel):
    requestId: str


class VideoStatusResponse(BaseModel):
    status: str = Field(..., description="Task status: Succeed, InQueue, InProgress, Failed, Cancelled")
    reason: Optional[str] = Field(None, description="Failure reason")
    results: Optional[dict] = Field(None, description="Generation results")
    queue_position: Optional[int] = Field(None, description="Position in the queue")


class VideoCancelRequest(BaseModel):
    requestId: str = Field(
        ...,
        min_length=32,
        max_length=32,
        description="32-character task ID"
    )


# # Custom HTTP exception handler
# @app.exception_handler(HTTPException)
# async def http_exception_handler(request, exc):
#     return JSONResponse(
#         status_code=exc.status_code,
#         content=exc.detail,  # return the detail content directly (without wrapping it in a detail field)
#         headers=exc.headers
#     )


# ======================
# Background task handling
# ======================
async def task_processor():
    """Process the task queue"""
    while True:
        async with app.state.semaphore:
            task_id = await get_next_task()
            if task_id:
                await process_task(task_id)
            else:
                await asyncio.sleep(0.5)


async def get_next_task():
    """Fetch the next pending task"""
    with app.state.task_lock:
        if app.state.pending_queue:
            return app.state.pending_queue.pop(0)
        return None


async def process_task(task_id: str):
    """Process a single task"""
    task = app.state.tasks.get(task_id)
    if not task:
        return

    try:
        # Update the task status
        task['status'] = 'InProgress'
        task['started_at'] = int(time.time())

        # Run video generation
        video_path = await generate_video(task['request'], task_id)

        # Build the download link
        download_url = f"{app.state.base_url}/videos/{os.path.basename(video_path)}"

        # Update the task status
        task.update({
            'status': 'Succeed',
            'download_url': download_url,
            'completed_at': int(time.time())
        })

        # Schedule automatic cleanup
        asyncio.create_task(auto_cleanup(video_path))
    except Exception as e:
        handle_task_error(task, e)


def handle_task_error(task: dict, error: Exception):
    """Handle task errors uniformly"""
    error_msg = str(error)
    if isinstance(error, torch.cuda.OutOfMemoryError):
        error_msg = "Out of GPU memory; lower the resolution or reduce the frame count"
    elif isinstance(error, ValidationError):
        error_msg = "Parameter validation failed: " + str(error)

    task.update({
        'status': 'Failed',
        'reason': error_msg,
        'completed_at': int(time.time())
    })


# ======================
# Core video generation logic
# ======================
async def generate_video(request: dict, task_id: str) -> str:
    """Run video generation asynchronously"""
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(
        None,
        sync_generate_video,
        request,
        task_id
    )


def sync_generate_video(request: dict, task_id: str) -> str:
    """Generate the video synchronously"""
    with app.state.model_lock:
        try:
            generator = None
            if request.get('seed') is not None:
                generator = torch.Generator(device="cuda")
                generator.manual_seed(request['seed'])
                print(f"🔮 Using random seed: {request['seed']}")

            # Run model inference
            result = app.state.pipe(
                prompt=request['prompt'],
                negative_prompt=request['negative_prompt'],
                height=request['height'],
                width=request['width'],
                num_frames=request['num_frames'],
                guidance_scale=request['guidance_scale'],
                num_inference_steps=request['infer_steps'],
                generator=generator
            )

            # Export the video file
            video_id = uuid.uuid4().hex
            output_path = f"generated_videos/{video_id}.mp4"
            export_to_video(result.frames[0], output_path, fps=16)

            return output_path
        except Exception as e:
            raise RuntimeError(f"Video generation failed: {str(e)}") from e


# ======================
# API endpoints
# ======================
async def verify_auth(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Authentication check"""
    if not credentials:
        raise HTTPException(
            status_code=401,
            detail={"status": "Failed", "reason": "Missing authorization header"},
            headers={"WWW-Authenticate": "Bearer"}
        )
    if credentials.scheme != "Bearer":
        raise HTTPException(
            status_code=401,
            detail={"status": "Failed", "reason": "Invalid authentication scheme"},
            headers={"WWW-Authenticate": "Bearer"}
        )
    if credentials.credentials not in app.state.valid_api_keys:
        raise HTTPException(
            status_code=401,
            detail={"status": "Failed", "reason": "Invalid API key"},
            headers={"WWW-Authenticate": "Bearer"}
        )
    return True


@app.post("/video/submit",
          response_model=VideoSubmitResponse,
          status_code=status.HTTP_202_ACCEPTED,
          tags=["Video generation"],
          summary="Submit a video generation request")
async def submit_video_task(
    request: VideoSubmitRequest,
    auth: bool = Depends(verify_auth)
):
    """Submit a new video generation task"""
    try:
        # Parse the resolution parameter
        width, height = map(int, request.image_size.split('x'))

        # Create the task record
        task_id = uuid.uuid4().hex
        task_data = {
            'request': {
                'prompt': request.prompt,
                'negative_prompt': request.negative_prompt,
                'width': width,
                'height': height,
                'num_frames': request.num_frames,
                'guidance_scale': request.guidance_scale,
                'infer_steps': request.infer_steps,
                'seed': request.seed
            },
            'status': 'InQueue',
            'created_at': int(time.time())
        }

        # Enqueue the task
        with app.state.task_lock:
            app.state.tasks[task_id] = task_data
            app.state.pending_queue.append(task_id)

        return {"requestId": task_id}

    except ValidationError as e:
        raise HTTPException(
            status_code=422,
            detail={"status": "Failed", "reason": str(e)}
        )


@app.post("/video/status",
          response_model=VideoStatusResponse,
          tags=["Video generation"],
          summary="Query task status")
async def get_video_status(
    request: VideoStatusRequest,
    auth: bool = Depends(verify_auth)
):
    """Query the status of a task"""
    task = app.state.tasks.get(request.requestId)
    if not task:
        raise HTTPException(
            status_code=404,
            detail={"status": "Failed", "reason": "Invalid task ID"}
        )

    # Compute the queue position (only while queued)
    queue_pos = 0
    if task['status'] == "InQueue" and request.requestId in app.state.pending_queue:
        queue_pos = app.state.pending_queue.index(request.requestId) + 1

    response = {
        "status": task['status'],
        "reason": task.get('reason'),
        "queue_position": queue_pos if task['status'] == "InQueue" else None  # null when not queued
    }

    # Special handling for the success state
    if task['status'] == "Succeed":
        response["results"] = {
            "videos": [{"url": task['download_url']}],
            "timings": {
                "inference": task['completed_at'] - task['started_at']
            },
            "seed": task['request']['seed']
        }
    # Extra information for the cancelled state
    elif task['status'] == "Cancelled":
        response["reason"] = task.get('reason', "Cancelled by user")  # make sure the reason field exists

    return response


@app.post("/video/cancel",
          response_model=dict,
          tags=["Video generation"])
async def cancel_task(
    request: VideoCancelRequest,
    auth: bool = Depends(verify_auth)
):
    """Cancel a queued generation task"""
    task_id = request.requestId

    with app.state.task_lock:
        task = app.state.tasks.get(task_id)

        # Check that the task exists
        if not task:
            raise HTTPException(
                status_code=404,
                detail={"status": "Failed", "reason": "Invalid task ID"}
            )

        current_status = task['status']

        # Only queued tasks may be cancelled
        if current_status != "InQueue":
            raise HTTPException(
                status_code=400,
                detail={"status": "Failed", "reason": f"Only queued tasks can be cancelled; current status: {current_status}"}
            )

        # Remove the task from the queue
        try:
            app.state.pending_queue.remove(task_id)
        except ValueError:
            pass  # it may already have been processed

        # Update the task status
        task.update({
            "status": "Cancelled",
            "reason": "Cancelled by user",
            "completed_at": int(time.time())
        })

    return {"status": "Succeed"}


# ======================
# Utility functions
# ======================
async def auto_cleanup(file_path: str, delay: int = 3600):
    """Automatically clean up generated video files"""
    await asyncio.sleep(delay)
    try:
        if os.path.exists(file_path):
            os.remove(file_path)
            print(f"Cleaned up file: {file_path}")
    except Exception as e:
        print(f"File cleanup failed: {str(e)}")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8088)
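Both services above expose the same submit/status/cancel protocol and differ only in the payload: the i2v service additionally takes an image_url and accepts image_size="auto", while the t2v service requires an explicit image_size. A minimal client sketch, assuming the server runs at http://localhost:8088 and that "my-secret-key" has been added to valid_api_keys (both values are placeholders):

import time

import requests

BASE = "http://localhost:8088"                       # placeholder address
HEADERS = {"Authorization": "Bearer my-secret-key"}  # placeholder key

# Submit a text-to-video task (add "image_url" for the i2v service).
payload = {
    "prompt": "A red fox trotting through fresh snow in a pine forest",
    "image_size": "480x832",
    "num_frames": 81,
    "guidance_scale": 5.0,
    "infer_steps": 50,
}
request_id = requests.post(
    f"{BASE}/video/submit", json=payload, headers=HEADERS).json()["requestId"]

# Poll /video/status until the task reaches a terminal state.
while True:
    info = requests.post(
        f"{BASE}/video/status", json={"requestId": request_id},
        headers=HEADERS).json()
    if info["status"] in ("Succeed", "Failed", "Cancelled"):
        break
    time.sleep(5)

if info["status"] == "Succeed":
    print(info["results"]["videos"][0]["url"])  # download link for the MP4
else:
    print(info["status"], info.get("reason"))

# A queued task could instead be cancelled while it is still "InQueue":
# requests.post(f"{BASE}/video/cancel", json={"requestId": request_id},
#               headers=HEADERS)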
@@ -1,5 +1,5 @@
from . import configs, distributed, modules
from .first_last_frame2video import WanFLF2V
from .image2video import WanI2V
from .text2video import WanT2V
from .vace import WanVace, WanVaceMP
@@ -8,6 +8,7 @@ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage


def shard_model(
    model,
    device_id,
@@ -32,6 +33,7 @@ def shard_model(
        sync_module_states=sync_module_states)
    return model


def free_model(model):
    for m in model.modules():
        if isinstance(m, FSDP):
@@ -1,9 +1,11 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp
from xfuser.core.distributed import (
    get_sequence_parallel_rank,
    get_sequence_parallel_world_size,
    get_sp_group,
)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention

from ..modules.model import sinusoidal_embedding_1d
@@ -63,19 +65,13 @@ def rope_apply(x, grid_sizes, freqs):
    return torch.stack(output).float()


def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs):
    # embeddings
    c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
    c = [u.flatten(2).transpose(1, 2) for u in c]
    c = torch.cat([
        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
        for u in c
    ])

    # arguments
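The torch.cat in usp_dit_forward_vace above right-pads every VACE context sequence with zeros to the shared seq_len before stacking. A shape-only sketch of that padding step (the sizes are made up for illustration):

import torch

seq_len, C = 8, 4
u = torch.randn(1, 5, C)  # one context of length 5
padded = torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
print(padded.shape)  # torch.Size([1, 8, 4])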
@@ -21,8 +21,11 @@ from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (
    FlowDPMSolverMultistepScheduler,
    get_sampling_sigmas,
    retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


@@ -103,11 +106,12 @@ class WanFLF2V:
            init_on_cpu = False

        if use_usp:
            from xfuser.core.distributed import get_sequence_parallel_world_size

            from .distributed.xdit_context_parallel import (
                usp_attn_forward,
                usp_dit_forward,
            )
            for block in self.model.blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward, block.self_attn)
@@ -181,8 +185,10 @@ class WanFLF2V:
        """
        first_frame_size = first_frame.size
        last_frame_size = last_frame.size
        first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
            self.device)
        last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(
            self.device)

        F = frame_num
        first_frame_h, first_frame_w = first_frame.shape[1:]
@@ -199,8 +205,7 @@ class WanFLF2V:
            # 1. resize
            last_frame_resize_ratio = max(
                first_frame_size[0] / last_frame_size[0],
                first_frame_size[1] / last_frame_size[1])
            last_frame_size = [
                round(last_frame_size[0] * last_frame_resize_ratio),
                round(last_frame_size[1] * last_frame_resize_ratio),
@@ -216,8 +221,7 @@ class WanFLF2V:
        seed_g = torch.Generator(device=self.device)
        seed_g.manual_seed(seed)
        noise = torch.randn(
            16, (F - 1) // 4 + 1,
            lat_h,
            lat_w,
            dtype=torch.float32,
@@ -226,7 +230,10 @@ class WanFLF2V:

        msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
        msk[:, 1:-1] = 0
        msk = torch.concat([
            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
        ],
                           dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]

@@ -247,7 +254,8 @@ class WanFLF2V:
        context_null = [t.to(self.device) for t in context_null]

        self.clip.model.to(self.device)
        clip_context = self.clip.visual(
            [first_frame[:, None, :, :], last_frame[:, None, :, :]])
        if offload_model:
            self.clip.model.cpu()

@@ -256,15 +264,14 @@ class WanFLF2V:
                torch.nn.functional.interpolate(
                    first_frame[None].cpu(),
                    size=(first_frame_h, first_frame_w),
                    mode='bicubic').transpose(0, 1),
                torch.zeros(3, F - 2, first_frame_h, first_frame_w),
                torch.nn.functional.interpolate(
                    last_frame[None].cpu(),
                    size=(first_frame_h, first_frame_w),
                    mode='bicubic').transpose(0, 1),
            ],
                         dim=1).to(self.device)
        ])[0]
        y = torch.concat([msk, y])
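The mask construction in the @@ -226,7 +230,10 @@ hunk above repeats the first of the 81 per-frame masks four times, giving 84 entries that regroup exactly into 21 latent frames of 4. A shape-only sketch with a tiny stand-in latent size:

import torch

lat_h = lat_w = 2  # made-up latent size for illustration
msk = torch.ones(1, 81, lat_h, lat_w)
msk[:, 1:-1] = 0   # keep only the first and last frame masks
msk = torch.concat(
    [torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]],
    dim=1)
print(msk.shape)  # torch.Size([1, 84, 2, 2])
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w).transpose(1, 2)[0]
print(msk.shape)  # torch.Size([4, 21, 2, 2])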
@@ -21,8 +21,11 @@ from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (
    FlowDPMSolverMultistepScheduler,
    get_sampling_sigmas,
    retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


@@ -103,11 +106,12 @@ class WanI2V:
        init_on_cpu = False

        if use_usp:
            from xfuser.core.distributed import get_sequence_parallel_world_size

            from .distributed.xdit_context_parallel import (
                usp_attn_forward,
                usp_dit_forward,
            )
            for block in self.model.blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward, block.self_attn)
@@ -196,8 +200,7 @@ class WanI2V:
        seed_g = torch.Generator(device=self.device)
        seed_g.manual_seed(seed)
        noise = torch.randn(
            16, (F - 1) // 4 + 1,
            lat_h,
            lat_w,
            dtype=torch.float32,
@@ -357,7 +357,8 @@ class MLPProj(torch.nn.Module):
            torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
            torch.nn.LayerNorm(out_dim))
        if flf_pos_emb:  # NOTE: we only use this for `flf2v`
            self.emb_pos = nn.Parameter(
                torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))

    def forward(self, image_embeds):
        if hasattr(self, 'emb_pos'):
@@ -3,12 +3,13 @@ import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import register_to_config

from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d


class VaceWanAttentionBlock(WanAttentionBlock):

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
@@ -17,9 +18,9 @@ class VaceWanAttentionBlock(WanAttentionBlock):
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6,
                 block_id=0):
        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size,
                         qk_norm, cross_attn_norm, eps)
        self.block_id = block_id
        if block_id == 0:
            self.before_proj = nn.Linear(self.dim, self.dim)
@@ -39,8 +40,8 @@ class VaceWanAttentionBlock(WanAttentionBlock):


class BaseWanAttentionBlock(WanAttentionBlock):

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
@@ -49,9 +50,9 @@ class BaseWanAttentionBlock(WanAttentionBlock):
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6,
                 block_id=None):
        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size,
                         qk_norm, cross_attn_norm, eps)
        self.block_id = block_id

    def forward(self, x, hints, context_scale=1.0, **kwargs):
@@ -62,6 +63,7 @@ class BaseWanAttentionBlock(WanAttentionBlock):


class VaceWanModel(WanModel):

    @register_to_config
    def __init__(self,
                 vace_layers=None,
@@ -81,42 +83,57 @@ class VaceWanModel(WanModel):
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6):
        super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim,
                         freq_dim, text_dim, out_dim, num_heads, num_layers,
                         window_size, qk_norm, cross_attn_norm, eps)

        self.vace_layers = [i for i in range(0, self.num_layers, 2)
                           ] if vace_layers is None else vace_layers
        self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim

        assert 0 in self.vace_layers
        self.vace_layers_mapping = {
            i: n for n, i in enumerate(self.vace_layers)
        }

        # blocks
        self.blocks = nn.ModuleList([
            BaseWanAttentionBlock(
                't2v_cross_attn',
                self.dim,
                self.ffn_dim,
                self.num_heads,
                self.window_size,
                self.qk_norm,
                self.cross_attn_norm,
                self.eps,
                block_id=self.vace_layers_mapping[i]
                if i in self.vace_layers else None)
            for i in range(self.num_layers)
        ])

        # vace blocks
        self.vace_blocks = nn.ModuleList([
            VaceWanAttentionBlock(
                't2v_cross_attn',
                self.dim,
                self.ffn_dim,
                self.num_heads,
                self.window_size,
                self.qk_norm,
                self.cross_attn_norm,
                self.eps,
                block_id=i) for i in self.vace_layers
        ])

        # vace patch embeddings
        self.vace_patch_embedding = nn.Conv3d(
            self.vace_in_dim,
            self.dim,
            kernel_size=self.patch_size,
            stride=self.patch_size)

    def forward_vace(self, x, vace_context, seq_len, kwargs):
        # embeddings
        c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
        c = [u.flatten(2).transpose(1, 2) for u in c]
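By default VaceWanModel pairs a VACE block with every second transformer layer, and vace_layers_mapping converts a transformer layer index into the index of its VACE block. A small sketch with a hypothetical 8-layer model:

num_layers = 8  # hypothetical; the real value comes from the model config
vace_layers = list(range(0, num_layers, 2))          # [0, 2, 4, 6]
mapping = {i: n for n, i in enumerate(vace_layers)}  # {0: 0, 2: 1, 4: 2, 6: 3}
assert 0 in vace_layers  # mirrors the assert in VaceWanModel.__init__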
@@ -18,8 +18,11 @@ from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (
    FlowDPMSolverMultistepScheduler,
    get_sampling_sigmas,
    retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


@@ -85,11 +88,12 @@ class WanT2V:
        self.model.eval().requires_grad_(False)

        if use_usp:
            from xfuser.core.distributed import get_sequence_parallel_world_size

            from .distributed.xdit_context_parallel import (
                usp_attn_forward,
                usp_dit_forward,
            )
            for block in self.model.blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward, block.self_attn)
@@ -1,5 +1,8 @@
from .fm_solvers import (
    FlowDPMSolverMultistepScheduler,
    get_sampling_sigmas,
    retrieve_timesteps,
)
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
from .vace_processor import VaceVideoProcessor
@@ -9,9 +9,11 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (
    KarrasDiffusionSchedulers,
    SchedulerMixin,
    SchedulerOutput,
)
from diffusers.utils import deprecate, is_scipy_available
from diffusers.utils.torch_utils import randn_tensor
@@ -8,9 +8,11 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (
    KarrasDiffusionSchedulers,
    SchedulerMixin,
    SchedulerOutput,
)
from diffusers.utils import deprecate, is_scipy_available

if is_scipy_available():
@@ -7,7 +7,7 @@ import sys
 import tempfile
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Optional, Union, List
+from typing import List, Optional, Union
 
 import dashscope
 import torch
@@ -96,7 +96,6 @@ VL_EN_SYS_PROMPT = \
    '''4. In the style of an American drama promotional poster, Walter White sits in a metal folding chair wearing a yellow protective suit, with the words "Breaking Bad" written in sans-serif English above him, surrounded by piles of dollar bills and blue plastic storage boxes. He wears glasses, staring forward, dressed in a yellow jumpsuit, with his hands resting on his knees, exuding a calm and confident demeanor. The background shows an abandoned, dim factory with light filtering through the windows. There’s a noticeable grainy texture. A medium shot with a straight-on close-up of the character.\n''' \
    '''Directly output the rewritten English text.'''
 
-
 VL_ZH_SYS_PROMPT_FOR_MULTI_IMAGES = """你是一位Prompt优化师,旨在参考用户输入的图像的细节内容,把用户输入的Prompt改写为优质Prompt,使其更完整、更具表现力,同时不改变原意。你需要综合用户输入的照片内容和输入的Prompt进行改写,严格参考示例的格式进行改写
 任务要求:
 1. 用户会输入两张图片,第一张是视频的第一帧,第二张时视频的最后一帧,你需要综合两个照片的内容进行优化改写
@@ -198,8 +197,8 @@ class PromptExpander:
         if system_prompt is None:
             system_prompt = self.decide_system_prompt(
                 tar_lang=tar_lang,
-                multi_images_input=isinstance(image, (list, tuple)) and len(image) > 1
-            )
+                multi_images_input=isinstance(image, (list, tuple)) and
+                len(image) > 1)
         if seed < 0:
             seed = random.randint(0, sys.maxsize)
         if image is not None and self.is_vl:
@@ -289,7 +288,8 @@ class DashScopePromptExpander(PromptExpander):
     def extend_with_img(self,
                         prompt,
                         system_prompt,
-                        image: Union[List[Image.Image], List[str], Image.Image, str] = None,
+                        image: Union[List[Image.Image], List[str], Image.Image,
+                                     str] = None,
                         seed=-1,
                         *args,
                         **kwargs):
@@ -308,13 +308,15 @@ class DashScopePromptExpander(PromptExpander):
                 _image.save(f.name)
                 image_path = f"file://{f.name}"
                 return image_path
 
         if not isinstance(image, (list, tuple)):
             image = [image]
         image_path_list = [ensure_image(_image) for _image in image]
-        role_content = [
-            {"text": prompt},
-            *[{"image": image_path} for image_path in image_path_list]
-        ]
+        role_content = [{
+            "text": prompt
+        }, *[{
+            "image": image_path
+        } for image_path in image_path_list]]
         system_content = [{"text": system_prompt}]
         prompt = f"{prompt}"
         messages = [
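Note: the role_content hunk above is a pure re-wrap of the same expression. A standalone check, with hypothetical placeholder paths, that both layouts build the identical DashScope message payload:

# Both layouts are the same expression, only wrapped differently; the paths
# here are hypothetical placeholders, not files from this repo.
prompt = "a harbor at dusk"
image_path_list = ["file:///tmp/first.png", "file:///tmp/last.png"]

role_content = [{
    "text": prompt
}, *[{
    "image": image_path
} for image_path in image_path_list]]

assert role_content == [{"text": prompt},
                        {"image": "file:///tmp/first.png"},
                        {"image": "file:///tmp/last.png"}]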
@@ -393,8 +395,11 @@ class QwenPromptExpander(PromptExpander):
 
         if self.is_vl:
             # default: Load the model on the available device(s)
-            from transformers import (AutoProcessor, AutoTokenizer,
-                                      Qwen2_5_VLForConditionalGeneration)
+            from transformers import (
+                AutoProcessor,
+                AutoTokenizer,
+                Qwen2_5_VLForConditionalGeneration,
+            )
             try:
                 from .qwen_vl_utils import process_vision_info
             except:
@@ -459,7 +464,8 @@ class QwenPromptExpander(PromptExpander):
     def extend_with_img(self,
                         prompt,
                         system_prompt,
-                        image: Union[List[Image.Image], List[str], Image.Image, str] = None,
+                        image: Union[List[Image.Image], List[str], Image.Image,
+                                     str] = None,
                         seed=-1,
                         *args,
                         **kwargs):
@@ -468,26 +474,19 @@ class QwenPromptExpander(PromptExpander):
         if not isinstance(image, (list, tuple)):
             image = [image]
 
-        system_content = [{
-            "type": "text",
-            "text": system_prompt
-        }]
-        role_content = [
-            {
-                "type": "text",
-                "text": prompt
-            },
-            *[
-                {"image": image_path} for image_path in image
-            ]
-        ]
+        system_content = [{"type": "text", "text": system_prompt}]
+        role_content = [{
+            "type": "text",
+            "text": prompt
+        }, *[{
+            "image": image_path
+        } for image_path in image]]
 
         messages = [{
             'role': 'system',
             'content': system_content,
         }, {
-            "role":
-                "user",
+            "role": "user",
             "content": role_content,
         }]
 
@@ -611,25 +610,38 @@ if __name__ == "__main__":
     print("VL qwen vl en result -> en",
           qwen_result.prompt)  # , qwen_result.system_prompt)
     # test multi images
-    image = ["./examples/flf2v_input_first_frame.png", "./examples/flf2v_input_last_frame.png"]
+    image = [
+        "./examples/flf2v_input_first_frame.png",
+        "./examples/flf2v_input_last_frame.png"
+    ]
     prompt = "无人机拍摄,镜头快速推进,然后拉远至全景俯瞰,展示一个宁静美丽的海港。海港内停满了游艇,水面清澈透蓝。周围是起伏的山丘和错落有致的建筑,整体景色宁静而美丽。"
-    en_prompt = ("Shot from a drone perspective, the camera rapidly zooms in before pulling back to reveal a panoramic "
-                 "aerial view of a serene and picturesque harbor. The tranquil bay is dotted with numerous yachts "
-                 "resting on crystal-clear blue waters. Surrounding the harbor are rolling hills and well-spaced "
-                 "architectural structures, combining to create a tranquil and breathtaking coastal landscape.")
+    en_prompt = (
+        "Shot from a drone perspective, the camera rapidly zooms in before pulling back to reveal a panoramic "
+        "aerial view of a serene and picturesque harbor. The tranquil bay is dotted with numerous yachts "
+        "resting on crystal-clear blue waters. Surrounding the harbor are rolling hills and well-spaced "
+        "architectural structures, combining to create a tranquil and breathtaking coastal landscape."
+    )
 
-    dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True)
-    dashscope_result = dashscope_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
+    dashscope_prompt_expander = DashScopePromptExpander(
+        model_name=ds_model_name, is_vl=True)
+    dashscope_result = dashscope_prompt_expander(
+        prompt, tar_lang="zh", image=image, seed=seed)
     print("VL dashscope result -> zh", dashscope_result.prompt)
 
-    dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True)
-    dashscope_result = dashscope_prompt_expander(en_prompt, tar_lang="zh", image=image, seed=seed)
+    dashscope_prompt_expander = DashScopePromptExpander(
+        model_name=ds_model_name, is_vl=True)
+    dashscope_result = dashscope_prompt_expander(
+        en_prompt, tar_lang="zh", image=image, seed=seed)
     print("VL dashscope en result -> zh", dashscope_result.prompt)
 
-    qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0)
-    qwen_result = qwen_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
+    qwen_prompt_expander = QwenPromptExpander(
+        model_name=qwen_model_name, is_vl=True, device=0)
+    qwen_result = qwen_prompt_expander(
+        prompt, tar_lang="zh", image=image, seed=seed)
     print("VL qwen result -> zh", qwen_result.prompt)
 
-    qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0)
-    qwen_result = qwen_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
+    qwen_prompt_expander = QwenPromptExpander(
+        model_name=qwen_model_name, is_vl=True, device=0)
+    qwen_result = qwen_prompt_expander(
+        prompt, tar_lang="zh", image=image, seed=seed)
     print("VL qwen en result -> zh", qwen_result.prompt)
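Taken together, the hunks above wire a first-frame/last-frame image pair through the prompt expanders. A condensed usage sketch, assuming the surrounding module's QwenPromptExpander is importable and that qwen_model_name points at a local Qwen2.5-VL checkpoint (both are placeholders here, mirroring the test block above):

# Condensed from the __main__ tests above; qwen_model_name is a placeholder
# for a local Qwen2.5-VL checkpoint, and running this requires the repo.
image = [
    "./examples/flf2v_input_first_frame.png",
    "./examples/flf2v_input_last_frame.png"
]
qwen_prompt_expander = QwenPromptExpander(
    model_name=qwen_model_name, is_vl=True, device=0)
qwen_result = qwen_prompt_expander(
    "Drone shot pulling back over a quiet harbor.",
    tar_lang="en",
    image=image,
    seed=42)
print(qwen_result.prompt)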
@@ -1,12 +1,13 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import numpy as np
-from PIL import Image
 import torch
 import torch.nn.functional as F
 import torchvision.transforms.functional as TF
+from PIL import Image
 
 
 class VaceImageProcessor(object):
+
     def __init__(self, downsample=None, seq_len=None):
         self.downsample = downsample
         self.seq_len = seq_len
@@ -16,7 +17,8 @@ class VaceImageProcessor(object):
         if image.mode == 'P':
             image = image.convert(f'{cvt_type}A')
         if image.mode == f'{cvt_type}A':
-            bg = Image.new(cvt_type,
-                           size=(image.width, image.height),
-                           color=(255, 255, 255))
+            bg = Image.new(
+                cvt_type,
+                size=(image.width, image.height),
+                color=(255, 255, 255))
             bg.paste(image, (0, 0), mask=image)
@@ -41,10 +43,8 @@ class VaceImageProcessor(object):
         if iw != ow or ih != oh:
             # resize
             scale = max(ow / iw, oh / ih)
-            img = img.resize(
-                (round(scale * iw), round(scale * ih)),
-                resample=Image.Resampling.LANCZOS
-            )
+            img = img.resize((round(scale * iw), round(scale * ih)),
+                             resample=Image.Resampling.LANCZOS)
         assert img.width >= ow and img.height >= oh
 
         # center crop
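The resize hunk above pairs with the center-crop step that follows it in the source. A minimal PIL sketch of the full scale-then-center-crop, with the helper name invented for illustration:

# Minimal sketch of the scale-then-center-crop used by VaceImageProcessor;
# resize_and_center_crop is an illustrative name, not a repo function.
from PIL import Image

def resize_and_center_crop(img, ow, oh):
    iw, ih = img.size
    if iw != ow or ih != oh:
        # scale so both dimensions cover the target, then crop the excess
        scale = max(ow / iw, oh / ih)
        img = img.resize((round(scale * iw), round(scale * ih)),
                         resample=Image.Resampling.LANCZOS)
    assert img.width >= ow and img.height >= oh
    x1 = (img.width - ow) // 2
    y1 = (img.height - oh) // 2
    return img.crop((x1, y1, x1 + ow, y1 + oh))

print(resize_and_center_crop(Image.new('RGB', (1920, 1080)), 1280, 720).size)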
@@ -66,7 +66,11 @@ class VaceImageProcessor(object):
     def load_image_pair(self, data_key, data_key2, **kwargs):
         return self.load_image_batch(data_key, data_key2, **kwargs)
 
-    def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs):
+    def load_image_batch(self,
+                         *data_key_batch,
+                         normalize=True,
+                         seq_len=None,
+                         **kwargs):
         seq_len = self.seq_len if seq_len is None else seq_len
         imgs = []
         for data_key in data_key_batch:
@@ -85,7 +89,9 @@ class VaceImageProcessor(object):
 
 
 class VaceVideoProcessor(object):
-    def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs):
+
+    def __init__(self, downsample, min_area, max_area, min_fps, max_fps,
+                 zero_start, seq_len, keep_last, **kwargs):
         self.downsample = downsample
         self.min_area = min_area
         self.max_area = max_area
@@ -130,8 +136,7 @@ class VaceVideoProcessor(object):
             video,
             size=(round(scale * ih), round(scale * iw)),
             mode='bicubic',
-            antialias=True
-        )
+            antialias=True)
         assert video.size(3) >= ow and video.size(2) >= oh
 
         # center crop
@@ -146,7 +151,8 @@ class VaceVideoProcessor(object):
     def _video_preprocess(self, video, oh, ow):
         return self.resize_crop(video, oh, ow)
 
-    def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng):
+    def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box,
+                                  rng):
         target_fps = min(fps, self.max_fps)
         duration = frame_timestamps[-1].mean()
         x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
@@ -154,11 +160,10 @@ class VaceVideoProcessor(object):
         ratio = h / w
         df, dh, dw = self.downsample
 
-        area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
-        of = min(
-            (int(duration * target_fps) - 1) // df + 1,
-            int(self.seq_len / area_z)
-        )
+        area_z = min(self.seq_len, self.max_area / (dh * dw),
+                     (h // dh) * (w // dw))
+        of = min((int(duration * target_fps) - 1) // df + 1,
+                 int(self.seq_len / area_z))
 
         # deduce target shape of the [latent video]
         target_area_z = min(area_z, int(self.seq_len / of))
@@ -170,26 +175,27 @@ class VaceVideoProcessor(object):
 
         # sample frame ids
         target_duration = of / target_fps
-        begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration)
+        begin = 0. if self.zero_start else rng.uniform(
+            0, duration - target_duration)
         timestamps = np.linspace(begin, begin + target_duration, of)
-        frame_ids = np.argmax(np.logical_and(
-            timestamps[:, None] >= frame_timestamps[None, :, 0],
-            timestamps[:, None] < frame_timestamps[None, :, 1]
-        ), axis=1).tolist()
+        frame_ids = np.argmax(
+            np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
                           timestamps[:, None] < frame_timestamps[None, :, 1]),
+            axis=1).tolist()
         return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
 
-    def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng):
+    def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w,
+                                      crop_box, rng):
         duration = frame_timestamps[-1].mean()
         x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
         h, w = y2 - y1, x2 - x1
         ratio = h / w
         df, dh, dw = self.downsample
 
-        area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
-        of = min(
-            (len(frame_timestamps) - 1) // df + 1,
-            int(self.seq_len / area_z)
-        )
+        area_z = min(self.seq_len, self.max_area / (dh * dw),
+                     (h // dh) * (w // dw))
+        of = min((len(frame_timestamps) - 1) // df + 1,
+                 int(self.seq_len / area_z))
 
         # deduce target shape of the [latent video]
         target_area_z = min(area_z, int(self.seq_len / of))
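The frame-sampling expression above picks, for each target timestamp, the index of the source frame whose [start, end) interval contains it. A small numpy sketch with synthetic timestamps:

# Sketch of the timestamp-to-frame-id mapping used above, with synthetic data:
# each source frame i covers [frame_timestamps[i, 0], frame_timestamps[i, 1]).
import numpy as np

fps = 4.0
num_src_frames = 8
starts = np.arange(num_src_frames) / fps
frame_timestamps = np.stack([starts, starts + 1 / fps], axis=1)  # (8, 2)

of = 4  # number of frames to sample
timestamps = np.linspace(0.0, 1.5, of)

frame_ids = np.argmax(
    np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
                   timestamps[:, None] < frame_timestamps[None, :, 1]),
    axis=1).tolist()
print(frame_ids)  # [0, 2, 4, 6]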
@@ -203,27 +209,39 @@ class VaceVideoProcessor(object):
             target_duration = duration
         target_fps = of / target_duration
         timestamps = np.linspace(0., target_duration, of)
-        frame_ids = np.argmax(np.logical_and(
-            timestamps[:, None] >= frame_timestamps[None, :, 0],
-            timestamps[:, None] <= frame_timestamps[None, :, 1]
-        ), axis=1).tolist()
+        frame_ids = np.argmax(
+            np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
                           timestamps[:, None] <= frame_timestamps[None, :, 1]),
+            axis=1).tolist()
         # print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids))
         return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
 
-
     def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
         if self.keep_last:
-            return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng)
+            return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h,
+                                                      w, crop_box, rng)
         else:
-            return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng)
+            return self._get_frameid_bbox_default(fps, frame_timestamps, h, w,
+                                                  crop_box, rng)
 
     def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
-        return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
+        return self.load_video_batch(
+            data_key, crop_box=crop_box, seed=seed, **kwargs)
 
-    def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
-        return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
+    def load_video_pair(self,
                        data_key,
                        data_key2,
                        crop_box=None,
                        seed=2024,
                        **kwargs):
+        return self.load_video_batch(
+            data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
 
-    def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs):
+    def load_video_batch(self,
                         *data_key_batch,
                         crop_box=None,
                         seed=2024,
                         **kwargs):
         rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
         # read video
         import decord
@@ -235,36 +253,53 @@ class VaceVideoProcessor(object):
 
         fps = readers[0].get_avg_fps()
         length = min([len(r) for r in readers])
-        frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
+        frame_timestamps = [
+            readers[0].get_frame_timestamp(i) for i in range(length)
+        ]
         frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
         h, w = readers[0].next().shape[:2]
-        frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng)
+        frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(
+            fps, frame_timestamps, h, w, crop_box, rng)
 
         # preprocess video
-        videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
+        videos = [
+            reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :]
+            for reader in readers
+        ]
         videos = [self._video_preprocess(video, oh, ow) for video in videos]
         return *videos, frame_ids, (oh, ow), fps
         # return videos if len(videos) > 1 else videos[0]
 
 
-def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
+def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size,
+                   device):
     for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
         if sub_src_video is None and sub_src_mask is None:
-            src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
-            src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device)
+            src_video[i] = torch.zeros(
+                (3, num_frames, image_size[0], image_size[1]), device=device)
+            src_mask[i] = torch.ones(
+                (1, num_frames, image_size[0], image_size[1]), device=device)
     for i, ref_images in enumerate(src_ref_images):
         if ref_images is not None:
             for j, ref_img in enumerate(ref_images):
                 if ref_img is not None and ref_img.shape[-2:] != image_size:
                     canvas_height, canvas_width = image_size
                     ref_height, ref_width = ref_img.shape[-2:]
-                    white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device)  # [-1, 1]
-                    scale = min(canvas_height / ref_height, canvas_width / ref_width)
+                    white_canvas = torch.ones(
+                        (3, 1, canvas_height, canvas_width),
+                        device=device)  # [-1, 1]
+                    scale = min(canvas_height / ref_height,
+                                canvas_width / ref_width)
                     new_height = int(ref_height * scale)
                     new_width = int(ref_width * scale)
-                    resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
+                    resized_image = F.interpolate(
+                        ref_img.squeeze(1).unsqueeze(0),
+                        size=(new_height, new_width),
+                        mode='bilinear',
+                        align_corners=False).squeeze(0).unsqueeze(1)
                     top = (canvas_height - new_height) // 2
                     left = (canvas_width - new_width) // 2
-                    white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
+                    white_canvas[:, :, top:top + new_height,
                                 left:left + new_width] = resized_image
                     src_ref_images[i][j] = white_canvas
     return src_video, src_mask, src_ref_images
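prepare_source pads each reference image onto a white canvas while preserving aspect ratio. A self-contained sketch of that letterboxing step, with a dummy tensor standing in for a decoded reference image:

# Letterboxing sketch matching prepare_source above: scale a (3, 1, H, W)
# reference tensor to fit the canvas, then paste it centered on white.
import torch
import torch.nn.functional as F

ref_img = torch.zeros(3, 1, 300, 500)  # dummy image in [-1, 1]
canvas_height, canvas_width = 720, 1280

white_canvas = torch.ones(3, 1, canvas_height, canvas_width)  # [-1, 1]
scale = min(canvas_height / 300, canvas_width / 500)
new_height, new_width = int(300 * scale), int(500 * scale)
resized = F.interpolate(
    ref_img.squeeze(1).unsqueeze(0),  # (1, 3, H, W) for interpolate
    size=(new_height, new_width),
    mode='bilinear',
    align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized
print(white_canvas.shape)  # torch.Size([3, 1, 720, 1280])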
241
wan/vace.py
@@ -1,32 +1,41 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
-import os
-import sys
 import gc
-import math
-import time
-import random
-import types
 import logging
+import math
+import os
+import random
+import sys
+import time
 import traceback
+import types
 from contextlib import contextmanager
 from functools import partial
 
-from PIL import Image
-import torchvision.transforms.functional as TF
 import torch
-import torch.nn.functional as F
 import torch.cuda.amp as amp
 import torch.distributed as dist
 import torch.multiprocessing as mp
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+from PIL import Image
 from tqdm import tqdm
 
-from .text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler,
-                         get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler)
 from .modules.vace_model import VaceWanModel
+from .text2video import (
+    FlowDPMSolverMultistepScheduler,
+    FlowUniPCMultistepScheduler,
+    T5EncoderModel,
+    WanT2V,
+    WanVAE,
+    get_sampling_sigmas,
+    retrieve_timesteps,
+    shard_model,
+)
 from .utils.vace_processor import VaceVideoProcessor
 
 
 class WanVace(WanT2V):
+
     def __init__(
         self,
         config,
@@ -87,12 +96,13 @@ class WanVace(WanT2V):
         self.model.eval().requires_grad_(False)
 
         if use_usp:
-            from xfuser.core.distributed import \
-                get_sequence_parallel_world_size
+            from xfuser.core.distributed import get_sequence_parallel_world_size
 
-            from .distributed.xdit_context_parallel import (usp_attn_forward,
-                                                            usp_dit_forward,
-                                                            usp_dit_forward_vace)
+            from .distributed.xdit_context_parallel import (
+                usp_attn_forward,
+                usp_dit_forward,
+                usp_dit_forward_vace,
+            )
             for block in self.model.blocks:
                 block.self_attn.forward = types.MethodType(
                     usp_attn_forward, block.self_attn)
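The USP branch installs sequence-parallel forwards by rebinding methods at runtime with types.MethodType. A toy sketch of that rebinding pattern, detached from the real model classes:

# Toy sketch of the types.MethodType rebinding used above to install a
# replacement forward on an instance without touching its class.
import types

class Block:
    def forward(self, x):
        return x + 1

def usp_forward(self, x):
    # stand-in for the real sequence-parallel forward
    return x * 2

block = Block()
print(block.forward(3))                               # 4
block.forward = types.MethodType(usp_forward, block)  # rebind on the instance
print(block.forward(3))                               # 6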
@@ -100,7 +110,8 @@ class WanVace(WanT2V):
                 block.self_attn.forward = types.MethodType(
                     usp_attn_forward, block.self_attn)
             self.model.forward = types.MethodType(usp_dit_forward, self.model)
-            self.model.forward_vace = types.MethodType(usp_dit_forward_vace, self.model)
+            self.model.forward_vace = types.MethodType(usp_dit_forward_vace,
                                                       self.model)
             self.sp_size = get_sequence_parallel_world_size()
         else:
             self.sp_size = 1
@@ -114,7 +125,9 @@ class WanVace(WanT2V):
 
         self.sample_neg_prompt = config.sample_neg_prompt
 
-        self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
+        self.vid_proc = VaceVideoProcessor(
+            downsample=tuple(
+                [x * y for x, y in zip(config.vae_stride, self.patch_size)]),
             min_area=720 * 1280,
             max_area=720 * 1280,
             min_fps=config.sample_fps,
@@ -138,7 +151,9 @@ class WanVace(WanT2V):
         reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
         inactive = vae.encode(inactive)
         reactive = vae.encode(reactive)
-        latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
+        latents = [
+            torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)
+        ]
 
         cat_latents = []
         for latent, refs in zip(latents, ref_images):
@@ -147,7 +162,10 @@ class WanVace(WanT2V):
                 ref_latent = vae.encode(refs)
             else:
                 ref_latent = vae.encode(refs)
-                ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
+                ref_latent = [
+                    torch.cat((u, torch.zeros_like(u)), dim=0)
+                    for u in ref_latent
+                ]
             assert all([x.shape[1] == 1 for x in ref_latent])
             latent = torch.cat([*ref_latent, latent], dim=1)
             cat_latents.append(latent)
@@ -169,16 +187,17 @@ class WanVace(WanT2V):
 
         # reshape
         mask = mask[0, :, :, :]
-        mask = mask.view(
-            depth, height, vae_stride[1], width, vae_stride[1]
-        )  # depth, height, 8, width, 8
+        mask = mask.view(depth, height, vae_stride[1], width,
                         vae_stride[1])  # depth, height, 8, width, 8
         mask = mask.permute(2, 4, 0, 1, 3)  # 8, 8, depth, height, width
-        mask = mask.reshape(
-            vae_stride[1] * vae_stride[2], depth, height, width
-        )  # 8*8, depth, height, width
+        mask = mask.reshape(vae_stride[1] * vae_stride[2], depth, height,
                            width)  # 8*8, depth, height, width
 
         # interpolation
-        mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
+        mask = F.interpolate(
+            mask.unsqueeze(0),
+            size=(new_depth, height, width),
+            mode='nearest-exact').squeeze(0)
 
         if refs is not None:
             length = len(refs)
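The mask hunk folds each 8x8 spatial patch into the channel dimension before resampling the temporal axis. A shape-only sketch with dummy values, assuming stride 8 on both spatial axes as the source comments indicate:

# Shape sketch of the mask pixel-shuffle above: fold 8x8 spatial patches into
# channels, then resample the temporal axis with nearest-exact interpolation.
import torch
import torch.nn.functional as F

depth, height, width = 9, 4, 6  # toy latent-sized dimensions
stride = 8                      # vae_stride[1] == vae_stride[2] == 8 here
new_depth = 3

mask = torch.rand(depth, height * stride, width * stride)
mask = mask.view(depth, height, stride, width, stride)
mask = mask.permute(2, 4, 0, 1, 3)  # 8, 8, depth, height, width
mask = mask.reshape(stride * stride, depth, height, width)

mask = F.interpolate(
    mask.unsqueeze(0),  # add a batch dim for interpolate
    size=(new_depth, height, width),
    mode='nearest-exact').squeeze(0)
print(mask.shape)  # torch.Size([64, 3, 4, 6])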
@@ -190,7 +209,8 @@ class WanVace(WanT2V):
     def vace_latent(self, z, m):
         return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
 
-    def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device):
+    def prepare_source(self, src_video, src_mask, src_ref_images, num_frames,
                       image_size, device):
         area = image_size[0] * image_size[1]
         self.vid_proc.set_area(area)
         if area == 720 * 1280:
@@ -198,19 +218,26 @@ class WanVace(WanT2V):
         elif area == 480 * 832:
             self.vid_proc.set_seq_len(32760)
         else:
-            raise NotImplementedError(f'image_size {image_size} is not supported')
+            raise NotImplementedError(
+                f'image_size {image_size} is not supported')
 
         image_size = (image_size[1], image_size[0])
         image_sizes = []
-        for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
+        for i, (sub_src_video,
                sub_src_mask) in enumerate(zip(src_video, src_mask)):
             if sub_src_mask is not None and sub_src_video is not None:
-                src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask)
+                src_video[i], src_mask[
+                    i], _, _, _ = self.vid_proc.load_video_pair(
+                        sub_src_video, sub_src_mask)
                 src_video[i] = src_video[i].to(device)
                 src_mask[i] = src_mask[i].to(device)
-                src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
+                src_mask[i] = torch.clamp(
+                    (src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
                 image_sizes.append(src_video[i].shape[2:])
             elif sub_src_video is None:
-                src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
+                src_video[i] = torch.zeros(
+                    (3, num_frames, image_size[0], image_size[1]),
+                    device=device)
                 src_mask[i] = torch.ones_like(src_video[i], device=device)
                 image_sizes.append(image_size)
             else:
@@ -225,18 +252,27 @@ class WanVace(WanT2V):
             for j, ref_img in enumerate(ref_images):
                 if ref_img is not None:
                     ref_img = Image.open(ref_img).convert("RGB")
-                    ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
+                    ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(
+                        0.5).unsqueeze(1)
                     if ref_img.shape[-2:] != image_size:
                         canvas_height, canvas_width = image_size
                         ref_height, ref_width = ref_img.shape[-2:]
-                        white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device)  # [-1, 1]
-                        scale = min(canvas_height / ref_height, canvas_width / ref_width)
+                        white_canvas = torch.ones(
+                            (3, 1, canvas_height, canvas_width),
+                            device=device)  # [-1, 1]
+                        scale = min(canvas_height / ref_height,
                                    canvas_width / ref_width)
                         new_height = int(ref_height * scale)
                         new_width = int(ref_width * scale)
-                        resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
+                        resized_image = F.interpolate(
+                            ref_img.squeeze(1).unsqueeze(0),
+                            size=(new_height, new_width),
+                            mode='bilinear',
+                            align_corners=False).squeeze(0).unsqueeze(1)
                         top = (canvas_height - new_height) // 2
                         left = (canvas_width - new_width) // 2
-                        white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
+                        white_canvas[:, :, top:top + new_height,
                                     left:left + new_width] = resized_image
                         ref_img = white_canvas
                 src_ref_images[i][j] = ref_img.to(device)
         return src_video, src_mask, src_ref_images
@@ -256,8 +292,6 @@ class WanVace(WanT2V):
 
         return vae.decode(trimed_zs)
 
-
-
     def generate(self,
                  input_prompt,
                  input_frames,
@@ -335,7 +369,8 @@ class WanVace(WanT2V):
         context_null = [t.to(self.device) for t in context_null]
 
         # vace context encode
-        z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks)
+        z0 = self.vace_encode_frames(
+            input_frames, input_ref_images, masks=input_masks)
         m0 = self.vace_encode_masks(input_masks, input_ref_images)
         z = self.vace_latent(z0, m0)
 
@@ -399,9 +434,17 @@ class WanVace(WanT2V):
 
                 self.model.to(self.device)
                 noise_pred_cond = self.model(
-                    latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[0]
+                    latent_model_input,
+                    t=timestep,
+                    vace_context=z,
+                    vace_context_scale=context_scale,
+                    **arg_c)[0]
                 noise_pred_uncond = self.model(
-                    latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,**arg_null)[0]
+                    latent_model_input,
+                    t=timestep,
+                    vace_context=z,
+                    vace_context_scale=context_scale,
+                    **arg_null)[0]
 
                 noise_pred = noise_pred_uncond + guide_scale * (
                     noise_pred_cond - noise_pred_uncond)
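The two model calls above implement standard classifier-free guidance: one pass with the text condition, one without, blended by guide_scale. The blend in isolation, with dummy tensors standing in for the DiT outputs:

# Classifier-free guidance blend used above, with dummy tensors standing in
# for the conditional and unconditional noise predictions.
import torch

guide_scale = 5.0
noise_pred_cond = torch.randn(16, 4, 8, 8)
noise_pred_uncond = torch.randn(16, 4, 8, 8)

noise_pred = noise_pred_uncond + guide_scale * (
    noise_pred_cond - noise_pred_uncond)
print(noise_pred.shape)  # same shape as either prediction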
@@ -433,14 +476,13 @@ class WanVace(WanT2V):
 
 
 class WanVaceMP(WanVace):
-    def __init__(
-        self,
-        config,
-        checkpoint_dir,
-        use_usp=False,
-        ulysses_size=None,
-        ring_size=None
-    ):
+
+    def __init__(self,
+                 config,
+                 checkpoint_dir,
+                 use_usp=False,
+                 ulysses_size=None,
+                 ring_size=None):
         self.config = config
         self.checkpoint_dir = checkpoint_dir
         self.use_usp = use_usp
@@ -457,7 +499,8 @@ class WanVaceMP(WanVace):
 
         self.device = 'cpu' if torch.cuda.is_available() else 'cpu'
         self.vid_proc = VaceVideoProcessor(
-            downsample=tuple([x * y for x, y in zip(config.vae_stride, config.patch_size)]),
+            downsample=tuple(
+                [x * y for x, y in zip(config.vae_stride, config.patch_size)]),
             min_area=480 * 832,
             max_area=480 * 832,
             min_fps=self.config.sample_fps,
@@ -466,20 +509,30 @@ class WanVaceMP(WanVace):
             seq_len=32760,
             keep_last=True)
 
-
     def dynamic_load(self):
         if hasattr(self, 'inference_pids') and self.inference_pids is not None:
             return
-        gpu_infer = os.environ.get('LOCAL_WORLD_SIZE') or torch.cuda.device_count()
+        gpu_infer = os.environ.get(
+            'LOCAL_WORLD_SIZE') or torch.cuda.device_count()
         pmi_rank = int(os.environ['RANK'])
         pmi_world_size = int(os.environ['WORLD_SIZE'])
-        in_q_list = [torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)]
+        in_q_list = [
+            torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)
+        ]
         out_q = torch.multiprocessing.Manager().Queue()
-        initialized_events = [torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)]
-        context = mp.spawn(self.mp_worker, nprocs=gpu_infer, args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, self), join=False)
+        initialized_events = [
+            torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)
+        ]
+        context = mp.spawn(
+            self.mp_worker,
+            nprocs=gpu_infer,
+            args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q,
+                  initialized_events, self),
+            join=False)
         all_initialized = False
         while not all_initialized:
-            all_initialized = all(event.is_set() for event in initialized_events)
+            all_initialized = all(
+                event.is_set() for event in initialized_events)
             if not all_initialized:
                 time.sleep(0.1)
         print('Inference model is initialized', flush=True)
@@ -495,12 +548,19 @@ class WanVaceMP(WanVace):
         if isinstance(data, torch.Tensor):
             data = data.to(device)
         elif isinstance(data, list):
-            data = [self.transfer_data_to_cuda(subdata, device) for subdata in data]
+            data = [
+                self.transfer_data_to_cuda(subdata, device)
+                for subdata in data
+            ]
         elif isinstance(data, dict):
-            data = {key: self.transfer_data_to_cuda(val, device) for key, val in data.items()}
+            data = {
+                key: self.transfer_data_to_cuda(val, device)
+                for key, val in data.items()
+            }
         return data
 
-    def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, work_env):
+    def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list,
                  out_q, initialized_events, work_env):
         try:
             world_size = pmi_world_size * gpu_infer
             rank = pmi_rank * gpu_infer + gpu
|
|||||||
backend='nccl',
|
backend='nccl',
|
||||||
init_method='env://',
|
init_method='env://',
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size
|
world_size=world_size)
|
||||||
)
|
|
||||||
|
|
||||||
from xfuser.core.distributed import (initialize_model_parallel,
|
from xfuser.core.distributed import (
|
||||||
init_distributed_environment)
|
init_distributed_environment,
|
||||||
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
init_distributed_environment(
|
init_distributed_environment(
|
||||||
rank=dist.get_rank(), world_size=dist.get_world_size())
|
rank=dist.get_rank(), world_size=dist.get_world_size())
|
||||||
|
|
||||||
initialize_model_parallel(
|
initialize_model_parallel(
|
||||||
sequence_parallel_degree=dist.get_world_size(),
|
sequence_parallel_degree=dist.get_world_size(),
|
||||||
ring_degree=self.ring_size or 1,
|
ring_degree=self.ring_size or 1,
|
||||||
ulysses_degree=self.ulysses_size or 1
|
ulysses_degree=self.ulysses_size or 1)
|
||||||
)
|
|
||||||
|
|
||||||
num_train_timesteps = self.config.num_train_timesteps
|
num_train_timesteps = self.config.num_train_timesteps
|
||||||
param_dtype = self.config.param_dtype
|
param_dtype = self.config.param_dtype
|
||||||
@ -532,14 +592,17 @@ class WanVaceMP(WanVace):
|
|||||||
text_len=self.config.text_len,
|
text_len=self.config.text_len,
|
||||||
dtype=self.config.t5_dtype,
|
dtype=self.config.t5_dtype,
|
||||||
device=torch.device('cpu'),
|
device=torch.device('cpu'),
|
||||||
checkpoint_path=os.path.join(self.checkpoint_dir, self.config.t5_checkpoint),
|
checkpoint_path=os.path.join(self.checkpoint_dir,
|
||||||
tokenizer_path=os.path.join(self.checkpoint_dir, self.config.t5_tokenizer),
|
self.config.t5_checkpoint),
|
||||||
|
tokenizer_path=os.path.join(self.checkpoint_dir,
|
||||||
|
self.config.t5_tokenizer),
|
||||||
shard_fn=shard_fn if True else None)
|
shard_fn=shard_fn if True else None)
|
||||||
text_encoder.model.to(gpu)
|
text_encoder.model.to(gpu)
|
||||||
vae_stride = self.config.vae_stride
|
vae_stride = self.config.vae_stride
|
||||||
patch_size = self.config.patch_size
|
patch_size = self.config.patch_size
|
||||||
vae = WanVAE(
|
vae = WanVAE(
|
||||||
vae_pth=os.path.join(self.checkpoint_dir, self.config.vae_checkpoint),
|
vae_pth=os.path.join(self.checkpoint_dir,
|
||||||
|
self.config.vae_checkpoint),
|
||||||
device=gpu)
|
device=gpu)
|
||||||
logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}")
|
logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}")
|
||||||
model = VaceWanModel.from_pretrained(self.checkpoint_dir)
|
model = VaceWanModel.from_pretrained(self.checkpoint_dir)
|
||||||
@ -547,9 +610,12 @@ class WanVaceMP(WanVace):
|
|||||||
|
|
||||||
if self.use_usp:
|
if self.use_usp:
|
||||||
from xfuser.core.distributed import get_sequence_parallel_world_size
|
from xfuser.core.distributed import get_sequence_parallel_world_size
|
||||||
from .distributed.xdit_context_parallel import (usp_attn_forward,
|
|
||||||
|
from .distributed.xdit_context_parallel import (
|
||||||
|
usp_attn_forward,
|
||||||
usp_dit_forward,
|
usp_dit_forward,
|
||||||
usp_dit_forward_vace)
|
usp_dit_forward_vace,
|
||||||
|
)
|
||||||
for block in model.blocks:
|
for block in model.blocks:
|
||||||
block.self_attn.forward = types.MethodType(
|
block.self_attn.forward = types.MethodType(
|
||||||
usp_attn_forward, block.self_attn)
|
usp_attn_forward, block.self_attn)
|
||||||
@ -557,7 +623,8 @@ class WanVaceMP(WanVace):
|
|||||||
block.self_attn.forward = types.MethodType(
|
block.self_attn.forward = types.MethodType(
|
||||||
usp_attn_forward, block.self_attn)
|
usp_attn_forward, block.self_attn)
|
||||||
model.forward = types.MethodType(usp_dit_forward, model)
|
model.forward = types.MethodType(usp_dit_forward, model)
|
||||||
model.forward_vace = types.MethodType(usp_dit_forward_vace, model)
|
model.forward_vace = types.MethodType(usp_dit_forward_vace,
|
||||||
|
model)
|
||||||
sp_size = get_sequence_parallel_world_size()
|
sp_size = get_sequence_parallel_world_size()
|
||||||
else:
|
else:
|
||||||
sp_size = 1
|
sp_size = 1
|
||||||
@ -577,7 +644,8 @@ class WanVaceMP(WanVace):
|
|||||||
shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model = item
|
shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model = item
|
||||||
input_frames = self.transfer_data_to_cuda(input_frames, gpu)
|
input_frames = self.transfer_data_to_cuda(input_frames, gpu)
|
||||||
input_masks = self.transfer_data_to_cuda(input_masks, gpu)
|
input_masks = self.transfer_data_to_cuda(input_masks, gpu)
|
||||||
input_ref_images = self.transfer_data_to_cuda(input_ref_images, gpu)
|
input_ref_images = self.transfer_data_to_cuda(
|
||||||
|
input_ref_images, gpu)
|
||||||
|
|
||||||
if n_prompt == "":
|
if n_prompt == "":
|
||||||
n_prompt = sample_neg_prompt
|
n_prompt = sample_neg_prompt
|
||||||
@ -589,8 +657,10 @@ class WanVaceMP(WanVace):
|
|||||||
context_null = text_encoder([n_prompt], gpu)
|
context_null = text_encoder([n_prompt], gpu)
|
||||||
|
|
||||||
# vace context encode
|
# vace context encode
|
||||||
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, vae=vae)
|
z0 = self.vace_encode_frames(
|
||||||
m0 = self.vace_encode_masks(input_masks, input_ref_images, vae_stride=vae_stride)
|
input_frames, input_ref_images, masks=input_masks, vae=vae)
|
||||||
|
m0 = self.vace_encode_masks(
|
||||||
|
input_masks, input_ref_images, vae_stride=vae_stride)
|
||||||
z = self.vace_latent(z0, m0)
|
z = self.vace_latent(z0, m0)
|
||||||
|
|
||||||
target_shape = list(z0[0].shape)
|
target_shape = list(z0[0].shape)
|
||||||
@ -616,7 +686,8 @@ class WanVaceMP(WanVace):
|
|||||||
no_sync = getattr(model, 'no_sync', noop_no_sync)
|
no_sync = getattr(model, 'no_sync', noop_no_sync)
|
||||||
|
|
||||||
# evaluation mode
|
# evaluation mode
|
||||||
with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync():
|
with amp.autocast(
|
||||||
|
dtype=param_dtype), torch.no_grad(), no_sync():
|
||||||
|
|
||||||
if sample_solver == 'unipc':
|
if sample_solver == 'unipc':
|
||||||
sample_scheduler = FlowUniPCMultistepScheduler(
|
sample_scheduler = FlowUniPCMultistepScheduler(
|
||||||
@ -631,7 +702,8 @@ class WanVaceMP(WanVace):
|
|||||||
num_train_timesteps=num_train_timesteps,
|
num_train_timesteps=num_train_timesteps,
|
||||||
shift=1,
|
shift=1,
|
||||||
use_dynamic_shifting=False)
|
use_dynamic_shifting=False)
|
||||||
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
|
sampling_sigmas = get_sampling_sigmas(
|
||||||
|
sampling_steps, shift)
|
||||||
timesteps, _ = retrieve_timesteps(
|
timesteps, _ = retrieve_timesteps(
|
||||||
sample_scheduler,
|
sample_scheduler,
|
||||||
device=gpu,
|
device=gpu,
|
||||||
@ -653,10 +725,16 @@ class WanVaceMP(WanVace):
|
|||||||
|
|
||||||
model.to(gpu)
|
model.to(gpu)
|
||||||
noise_pred_cond = model(
|
noise_pred_cond = model(
|
||||||
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[
|
latent_model_input,
|
||||||
0]
|
t=timestep,
|
||||||
|
vace_context=z,
|
||||||
|
vace_context_scale=context_scale,
|
||||||
|
**arg_c)[0]
|
||||||
noise_pred_uncond = model(
|
noise_pred_uncond = model(
|
||||||
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,
|
latent_model_input,
|
||||||
|
t=timestep,
|
||||||
|
vace_context=z,
|
||||||
|
vace_context_scale=context_scale,
|
||||||
**arg_null)[0]
|
**arg_null)[0]
|
||||||
|
|
||||||
noise_pred = noise_pred_uncond + guide_scale * (
|
noise_pred = noise_pred_uncond + guide_scale * (
|
||||||
@ -673,7 +751,8 @@ class WanVaceMP(WanVace):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
x0 = latents
|
x0 = latents
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
videos = self.decode_latent(x0, input_ref_images, vae=vae)
|
videos = self.decode_latent(
|
||||||
|
x0, input_ref_images, vae=vae)
|
||||||
|
|
||||||
del noise, latents
|
del noise, latents
|
||||||
del sample_scheduler
|
del sample_scheduler
|
||||||
@ -691,8 +770,6 @@ class WanVaceMP(WanVace):
|
|||||||
print(trace_info, flush=True)
|
print(trace_info, flush=True)
|
||||||
print(e, flush=True)
|
print(e, flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def generate(self,
|
def generate(self,
|
||||||
input_prompt,
|
input_prompt,
|
||||||
input_frames,
|
input_frames,
|
||||||
@ -709,8 +786,10 @@ class WanVaceMP(WanVace):
|
|||||||
seed=-1,
|
seed=-1,
|
||||||
offload_model=True):
|
offload_model=True):
|
||||||
|
|
||||||
input_data = (input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale,
|
input_data = (input_prompt, input_frames, input_masks, input_ref_images,
|
||||||
shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model)
|
size, frame_num, context_scale, shift, sample_solver,
|
||||||
|
sampling_steps, guide_scale, n_prompt, seed,
|
||||||
|
offload_model)
|
||||||
for in_q in self.in_q_list:
|
for in_q in self.in_q_list:
|
||||||
in_q.put(input_data)
|
in_q.put(input_data)
|
||||||
value_output = self.out_q.get()
|
value_output = self.out_q.get()
|
||||||
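Every hunk in this comparison re-wraps existing statements and sorts imports; behavior is unchanged. A hedged sketch of reproducing that kind of rewrite with yapf's Python API (the target path is illustrative, and the return tuple shown is an assumption from yapf's documented interface):

# Illustrative only: reformat one file against a yapf style config.
from yapf.yapflib.yapf_api import FormatFile

reformatted, encoding, changed = FormatFile(
    'wan/vace.py', style_config='.style.yapf', in_place=False)
print(changed)  # True if yapf would rewrite the file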