Added Low VRAM support for RTX 10XX and RTX 20XX GPUs

This commit is contained in:
DeepBeepMeep 2025-04-15 01:02:06 +02:00
parent 5efddd626d
commit c62beb7d9d
13 changed files with 279 additions and 1215 deletions

View File

@ -15,12 +15,13 @@
## 🔥 Latest News!!
* April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you!
- A new queuing system that lets you stack in a queue as many text2video and image2video tasks as you want. Each task can rely on completely different generation parameters (different number of frames, steps, loras, ...).
- Temporal upsampling (Rife) and spatial upsampling (Lanczos) for a smoother video (32 fps or 64 fps) and to enlarge your video by x2 or x4. Check these new advanced options.
- Wan Vace ControlNet support: with Vace you can inject people or objects into the scene, animate a person, perform inpainting or outpainting, continue a video, ... I have provided an introduction guide below.
- Integrated *Matanyone* tool directly inside WanGP so that you can easily create inpainting masks
- Sliding Window generation for Vace, create windows that can last dozens of seconds
- A new UI: tabs were replaced by a dropdown box to easily switch models
- A new queuing system that lets you stack in a queue as many text2video, image2video, ... tasks as you want. Each task can rely on completely different generation parameters (different number of frames, steps, loras, ...). Many thanks to *Tophness* for being a big contributor on this new feature
- Temporal upsampling (Rife) and spatial upsampling (Lanczos) for a smoother video (32 fps or 64 fps) and to enlarge your video by x2 or x4. Check these new advanced options.
- Wan Vace ControlNet support: with Vace you can inject people or objects into the scene, animate a person, perform inpainting or outpainting, continue a video, ... I have provided an introduction guide below.
- Integrated *Matanyone* tool directly inside WanGP so that you can easily create the inpainting masks used in Vace
- Sliding Window generation for Vace, create windows that can last dozens of seconds
- New optimisations for old generation GPUs: generate 5s (81 frames, 15 steps) of Vace 1.3B with only 5 GB of VRAM in only 6 minutes on an RTX 2080 Ti, and 5s of t2v 14B in less than 10 minutes.
* Mar 27 2025: 👋 Added support for the new Wan Fun InP models (image2video). The 14B Fun InP probably has better end-image support, but unfortunately existing loras do not work so well with it. The great novelty is the Fun InP image2video 1.3B model: image2video is now accessible to even lower hardware configurations. It is not as good as the 14B models but very impressive for its size. You can choose any of those models in the Configuration tab. Many thanks to the VideoX-Fun team (https://github.com/aigc-apps/VideoX-Fun)
* Mar 26 2025: 👋 Good news! Official support for RTX 50xx, please check the installation instructions below.
@ -303,6 +304,20 @@ Vace provides on its github (https://github.com/ali-vilab/VACE/tree/main/vace/gr
There is also a guide that describes the various combinations of hints (https://github.com/ali-vilab/VACE/blob/main/UserGuide.md). Good luck!
It seems you will get better results if you turn on "Skip Layer Guidance" with its default configuration.
### VACE Sliding Window
With this mode (which for the moment works only with Vace) you can merge multiple videos to form a very long video (up to 1 min). What is very nice about this feature is that the resulting video can be driven by the same control video. For instance the first 0-4s of the control video will be used to generate the first window, then the next 4-8s of the control video will be used to generate the second window, and so on. So if your control video contains a person walking, your generated video could contain up to one minute of this person walking.
To turn on sliding windows, go to the Advanced Settings tab *Sliding Window* and set the iteration number to a number greater than 1. This number corresponds to the default number of windows. You can still increase the number during the generation by clicking the "One More Sample, Please !" button.
Each window duration will be set by the *Number of frames (16 = 1s)* form field. However the actual number of new frames generated by each iteration will be smaller, because of the *overlap frames* and *discard last frames*:
- *overlap frames*: the first frames of a new window are filled with the last frames of the previous window in order to ensure continuity between the two windows
- *discard last frames*: quite often the last frames of a window have a worse quality. You decide here how many ending frames of a new window should be dropped.
Number of Generated Frames = [Number of Iterations] * ([Number of Frames] - [Overlap Frames] - [Discard Last Frames]) + [Overlap Frames]
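To make this arithmetic concrete, here is a small illustrative Python sketch of the formula (the function name and the example values are made up for illustration, they are not WanGP defaults):

```python
def total_generated_frames(iterations, frames_per_window, overlap_frames, discard_last_frames):
    # Each window contributes its frame count minus the overlapping and discarded
    # frames; the overlap is then counted once for the finished video.
    return iterations * (frames_per_window - overlap_frames - discard_last_frames) + overlap_frames

# Example: 3 windows of 81 frames each, with 8 overlap frames and 4 discarded frames
print(total_generated_frames(3, 81, 8, 4))  # 215 frames, roughly 13s at 16 fps
```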
Experimental: if your prompt is broken into multiple lines (each line separated by a carriage return), then each line of the prompt will be used for a new window. If there are more windows to generate than prompt lines, the last prompt line will be repeated.
### Command line parameters for Gradio Server
--i2v : launch the image to video generator\
--t2v : launch the text to video generator (default defined in the configuration)\
@ -324,7 +339,7 @@ It seems you will get better results if you turn on "Skip Layer Guidance" with i
--compile : turn on pytorch compilation\
--attention mode: force attention mode among sdpa, flash, sage, sage2\
--profile no : profile number to use, between 1 and 5 (default 4)\
--preload no : number of megabytes of the diffusion model to preload in VRAM; may offer slight speed gains, especially on older hardware. Works only with profiles 2 and 4.\
--preload no : number of megabytes of the diffusion model to preload in VRAM; may offer speed gains on older hardware. On recent hardware (RTX 30XX, RTX 40XX and RTX 50XX) the speed gain is only 10% and not worth it. Works only with profiles 2 and 4.\
--seed no : set default seed value\
--frames no : set the default number of frames to generate\
--steps no : set the default number of denoising steps\
@ -333,7 +348,11 @@ It seems you will get better results if you turn on "Skip Layer Guidance" with i
--check-loras : filter loras that are incompatible (will take a few seconds while refreshing the lora list or while starting the app)\
--advanced : turn on the advanced mode while launching the app\
--listen : make server accessible on network\
--gpu device : run Wan on device for instance "cuda:1"
--gpu device : run Wan on device for instance "cuda:1"\
--settings: path to a folder that contains the default settings for all the models\
--fp16: force the use of fp16 versions of the models instead of bf16 versions\
--perc-reserved-mem-max float_less_than_1 : max percentage of RAM to allocate to reserved RAM, which allows faster RAM<->VRAM transfers. The value should remain below 0.5 to keep the OS stable\
--theme theme_name: load the UI with the specified theme name; so far only two are supported, "default" and "gradio". You may submit your own nice-looking Gradio theme and I will add it
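For example, a launch combining several of these options could look like `python wgp.py --profile 4 --fp16 --perc-reserved-mem-max 0.3 --theme gradio` (assuming wgp.py, the script that defines these flags in this commit, is used as the launcher).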
### Profiles (for power users only)
You can choose between 5 profiles, but two are really relevant here :

View File

@ -1,306 +0,0 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import gc
import os.path as osp
import os
import sys
import warnings
import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video
# Global Var
prompt_expander = None
wan_i2v_480P = None
wan_i2v_720P = None
# Button Func
def load_i2v_model(value):
global wan_i2v_480P, wan_i2v_720P
from mmgp import offload
if value == '------':
print("No model loaded")
return '------'
if value == '720P':
if args.ckpt_dir_720p is None:
print("Please specify the checkpoint directory for 720P model")
return '------'
if wan_i2v_720P is not None:
pass
else:
del wan_i2v_480P
gc.collect()
wan_i2v_480P = None
print("load 14B-720P i2v model...", end='', flush=True)
cfg = WAN_CONFIGS['i2v-14B']
wan_i2v_720P = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir_720p,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
i2v720p= True
)
print("done", flush=True)
pipe = {"transformer": wan_i2v_720P.model, "text_encoder" : wan_i2v_720P.text_encoder.model, "text_encoder_2": wan_i2v_720P.clip.model, "vae": wan_i2v_720P.vae.model } #
offload.profile(pipe, profile_no=4, budgets = {"transformer":100, "*":3000}, verboseLevel=2, compile="transformer", quantizeTransformer = False, pinnedMemory = False)
return '720P'
if value == '480P':
if args.ckpt_dir_480p is None:
print("Please specify the checkpoint directory for 480P model")
return '------'
if wan_i2v_480P is not None:
pass
else:
del wan_i2v_720P
gc.collect()
wan_i2v_720P = None
print("load 14B-480P i2v model...", end='', flush=True)
cfg = WAN_CONFIGS['i2v-14B']
wan_i2v_480P = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir_480p,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
i2v720p= False
)
print("done", flush=True)
pipe = {"transformer": wan_i2v_480P.model, "text_encoder" : wan_i2v_480P.text_encoder.model, "text_encoder_2": wan_i2v_480P.clip.model, "vae": wan_i2v_480P.vae.model } #
offload.profile(pipe, profile_no=4, budgets = {"model":100, "*":3000}, verboseLevel=2, compile="transformer" )
return '480P'
def prompt_enc(prompt, img, tar_lang):
print('prompt extend...')
if img is None:
print('Please upload an image')
return prompt
global prompt_expander
prompt_output = prompt_expander(
prompt, image=img, tar_lang=tar_lang.lower())
if prompt_output.status == False:
return prompt
else:
return prompt_output.prompt
def i2v_generation(img2vid_prompt, img2vid_image, res, sd_steps,
guide_scale, shift_scale, seed, n_prompt):
# print(f"{img2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
global resolution
from PIL import Image
img2vid_image = Image.open("d:\mammoth2.jpg")
if resolution == '------':
print(
'Please specify at least one resolution ckpt dir or specify the resolution'
)
return None
else:
if resolution == '720P':
global wan_i2v_720P
video = wan_i2v_720P.generate(
img2vid_prompt,
img2vid_image,
max_area=MAX_AREA_CONFIGS['720*1280'],
shift=shift_scale,
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=False)
else:
global wan_i2v_480P
video = wan_i2v_480P.generate(
img2vid_prompt,
img2vid_image,
max_area=MAX_AREA_CONFIGS['480*832'],
shift=3.0, #shift_scale
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=False)
cache_video(
tensor=video[None],
save_file="example.mp4",
fps=16,
nrow=1,
normalize=True,
value_range=(-1, 1))
return "example.mp4"
# Interface
def gradio_interface():
with gr.Blocks() as demo:
gr.Markdown("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Wan2.1 (I2V-14B)
</div>
<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
Wan: Open and Advanced Large-Scale Video Generative Models.
</div>
""")
with gr.Row():
with gr.Column():
resolution = gr.Dropdown(
label='Resolution',
choices=['------', '720P', '480P'],
value='------')
img2vid_image = gr.Image(
type="pil",
label="Upload Input Image",
elem_id="image_upload",
)
img2vid_prompt = gr.Textbox(
label="Prompt",
value="Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.",
placeholder="Describe the video you want to generate",
)
tar_lang = gr.Radio(
choices=["CH", "EN"],
label="Target language of prompt enhance",
value="CH")
run_p_button = gr.Button(value="Prompt Enhance")
with gr.Accordion("Advanced Options", open=True):
with gr.Row():
sd_steps = gr.Slider(
label="Diffusion steps",
minimum=1,
maximum=1000,
value=50,
step=1)
guide_scale = gr.Slider(
label="Guide scale",
minimum=0,
maximum=20,
value=5.0,
step=1)
with gr.Row():
shift_scale = gr.Slider(
label="Shift scale",
minimum=0,
maximum=10,
value=5.0,
step=1)
seed = gr.Slider(
label="Seed",
minimum=-1,
maximum=2147483647,
step=1,
value=-1)
n_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Describe the negative prompt you want to add"
)
run_i2v_button = gr.Button("Generate Video")
with gr.Column():
result_gallery = gr.Video(
label='Generated Video', interactive=False, height=600)
resolution.input(
fn=load_model, inputs=[resolution], outputs=[resolution])
run_p_button.click(
fn=prompt_enc,
inputs=[img2vid_prompt, img2vid_image, tar_lang],
outputs=[img2vid_prompt])
run_i2v_button.click(
fn=i2v_generation,
inputs=[
img2vid_prompt, img2vid_image, resolution, sd_steps,
guide_scale, shift_scale, seed, n_prompt
],
outputs=[result_gallery],
)
return demo
# Main
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt or image using Gradio")
parser.add_argument(
"--ckpt_dir_720p",
type=str,
default=None,
help="The path to the checkpoint directory.")
parser.add_argument(
"--ckpt_dir_480p",
type=str,
default=None,
help="The path to the checkpoint directory.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
args = parser.parse_args()
args.ckpt_dir_720p = "../ckpts" # os.path.join("ckpt")
args.ckpt_dir_480p = "../ckpts" # os.path.join("ckpt")
assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory."
return args
if __name__ == '__main__':
args = _parse_args()
global resolution
# load_model('720P')
# resolution = '720P'
resolution = '480P'
load_i2v_model(resolution)
print("Step1: Init prompt_expander...", end='', flush=True)
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl=True)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model, is_vl=True, device=0)
else:
raise NotImplementedError(
f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
print("done", flush=True)
demo = gradio_interface()
demo.launch(server_name="0.0.0.0", share=False, server_port=7860)

View File

@ -1,206 +0,0 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import sys
import warnings
import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_image
# Global Var
prompt_expander = None
wan_t2i = None
# Button Func
def prompt_enc(prompt, tar_lang):
global prompt_expander
prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
if prompt_output.status == False:
return prompt
else:
return prompt_output.prompt
def t2i_generation(txt2img_prompt, resolution, sd_steps, guide_scale,
shift_scale, seed, n_prompt):
global wan_t2i
# print(f"{txt2img_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
W = int(resolution.split("*")[0])
H = int(resolution.split("*")[1])
video = wan_t2i.generate(
txt2img_prompt,
size=(W, H),
frame_num=1,
shift=shift_scale,
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=True)
cache_image(
tensor=video.squeeze(1)[None],
save_file="example.png",
nrow=1,
normalize=True,
value_range=(-1, 1))
return "example.png"
# Interface
def gradio_interface():
with gr.Blocks() as demo:
gr.Markdown("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Wan2.1 (T2I-14B)
</div>
<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
Wan: Open and Advanced Large-Scale Video Generative Models.
</div>
""")
with gr.Row():
with gr.Column():
txt2img_prompt = gr.Textbox(
label="Prompt",
placeholder="Describe the image you want to generate",
)
tar_lang = gr.Radio(
choices=["CH", "EN"],
label="Target language of prompt enhance",
value="CH")
run_p_button = gr.Button(value="Prompt Enhance")
with gr.Accordion("Advanced Options", open=True):
resolution = gr.Dropdown(
label='Resolution(Width*Height)',
choices=[
'720*1280', '1280*720', '960*960', '1088*832',
'832*1088', '480*832', '832*480', '624*624',
'704*544', '544*704'
],
value='720*1280')
with gr.Row():
sd_steps = gr.Slider(
label="Diffusion steps",
minimum=1,
maximum=1000,
value=50,
step=1)
guide_scale = gr.Slider(
label="Guide scale",
minimum=0,
maximum=20,
value=5.0,
step=1)
with gr.Row():
shift_scale = gr.Slider(
label="Shift scale",
minimum=0,
maximum=10,
value=5.0,
step=1)
seed = gr.Slider(
label="Seed",
minimum=-1,
maximum=2147483647,
step=1,
value=-1)
n_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Describe the negative prompt you want to add"
)
run_t2i_button = gr.Button("Generate Image")
with gr.Column():
result_gallery = gr.Image(
label='Generated Image', interactive=False, height=600)
run_p_button.click(
fn=prompt_enc,
inputs=[txt2img_prompt, tar_lang],
outputs=[txt2img_prompt])
run_t2i_button.click(
fn=t2i_generation,
inputs=[
txt2img_prompt, resolution, sd_steps, guide_scale, shift_scale,
seed, n_prompt
],
outputs=[result_gallery],
)
return demo
# Main
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a image from a text prompt or image using Gradio")
parser.add_argument(
"--ckpt_dir",
type=str,
default="cache",
help="The path to the checkpoint directory.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
args = parser.parse_args()
return args
if __name__ == '__main__':
args = _parse_args()
print("Step1: Init prompt_expander...", end='', flush=True)
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl=False)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model, is_vl=False, device=0)
else:
raise NotImplementedError(
f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
print("done", flush=True)
print("Step2: Init 14B t2i model...", end='', flush=True)
cfg = WAN_CONFIGS['t2i-14B']
# cfg = WAN_CONFIGS['t2v-1.3B']
wan_t2i = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
)
print("done", flush=True)
demo = gradio_interface()
demo.launch(server_name="0.0.0.0", share=False, server_port=7860)

View File

@ -1,207 +0,0 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import sys
import warnings
import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video
# Global Var
prompt_expander = None
wan_t2v = None
# Button Func
def prompt_enc(prompt, tar_lang):
global prompt_expander
prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
if prompt_output.status == False:
return prompt
else:
return prompt_output.prompt
def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
shift_scale, seed, n_prompt):
global wan_t2v
# print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
W = int(resolution.split("*")[0])
H = int(resolution.split("*")[1])
video = wan_t2v.generate(
txt2vid_prompt,
size=(W, H),
shift=shift_scale,
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=True)
cache_video(
tensor=video[None],
save_file="example.mp4",
fps=16,
nrow=1,
normalize=True,
value_range=(-1, 1))
return "example.mp4"
# Interface
def gradio_interface():
with gr.Blocks() as demo:
gr.Markdown("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Wan2.1 (T2V-1.3B)
</div>
<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
Wan: Open and Advanced Large-Scale Video Generative Models.
</div>
""")
with gr.Row():
with gr.Column():
txt2vid_prompt = gr.Textbox(
label="Prompt",
placeholder="Describe the video you want to generate",
)
tar_lang = gr.Radio(
choices=["CH", "EN"],
label="Target language of prompt enhance",
value="CH")
run_p_button = gr.Button(value="Prompt Enhance")
with gr.Accordion("Advanced Options", open=True):
resolution = gr.Dropdown(
label='Resolution(Width*Height)',
choices=[
'480*832',
'832*480',
'624*624',
'704*544',
'544*704',
],
value='480*832')
with gr.Row():
sd_steps = gr.Slider(
label="Diffusion steps",
minimum=1,
maximum=1000,
value=50,
step=1)
guide_scale = gr.Slider(
label="Guide scale",
minimum=0,
maximum=20,
value=6.0,
step=1)
with gr.Row():
shift_scale = gr.Slider(
label="Shift scale",
minimum=0,
maximum=20,
value=8.0,
step=1)
seed = gr.Slider(
label="Seed",
minimum=-1,
maximum=2147483647,
step=1,
value=-1)
n_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Describe the negative prompt you want to add"
)
run_t2v_button = gr.Button("Generate Video")
with gr.Column():
result_gallery = gr.Video(
label='Generated Video', interactive=False, height=600)
run_p_button.click(
fn=prompt_enc,
inputs=[txt2vid_prompt, tar_lang],
outputs=[txt2vid_prompt])
run_t2v_button.click(
fn=t2v_generation,
inputs=[
txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
seed, n_prompt
],
outputs=[result_gallery],
)
return demo
# Main
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt or image using Gradio")
parser.add_argument(
"--ckpt_dir",
type=str,
default="cache",
help="The path to the checkpoint directory.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
args = parser.parse_args()
return args
if __name__ == '__main__':
args = _parse_args()
print("Step1: Init prompt_expander...", end='', flush=True)
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl=False)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model, is_vl=False, device=0)
else:
raise NotImplementedError(
f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
print("done", flush=True)
print("Step2: Init 1.3B t2v model...", end='', flush=True)
cfg = WAN_CONFIGS['t2v-1.3B']
wan_t2v = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
)
print("done", flush=True)
demo = gradio_interface()
demo.launch(server_name="0.0.0.0", share=False, server_port=7860)

View File

@ -1,216 +0,0 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import sys
import warnings
import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video
# Global Var
prompt_expander = None
wan_t2v = None
# Button Func
def prompt_enc(prompt, tar_lang):
global prompt_expander
prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
if prompt_output.status == False:
return prompt
else:
return prompt_output.prompt
def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
shift_scale, seed, n_prompt):
global wan_t2v
# print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
W = int(resolution.split("*")[0])
H = int(resolution.split("*")[1])
video = wan_t2v.generate(
txt2vid_prompt,
size=(W, H),
shift=shift_scale,
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=False)
cache_video(
tensor=video[None],
save_file="example.mp4",
fps=16,
nrow=1,
normalize=True,
value_range=(-1, 1))
return "example.mp4"
# Interface
def gradio_interface():
with gr.Blocks() as demo:
gr.Markdown("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Wan2.1 (T2V-14B)
</div>
<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
Wan: Open and Advanced Large-Scale Video Generative Models.
</div>
""")
with gr.Row():
with gr.Column():
txt2vid_prompt = gr.Textbox(
label="Prompt",
placeholder="Describe the video you want to generate",
)
tar_lang = gr.Radio(
choices=["CH", "EN"],
label="Target language of prompt enhance",
value="CH")
run_p_button = gr.Button(value="Prompt Enhance")
with gr.Accordion("Advanced Options", open=True):
resolution = gr.Dropdown(
label='Resolution(Width*Height)',
choices=[
'720*1280', '1280*720', '960*960', '1088*832',
'832*1088', '480*832', '832*480', '624*624',
'704*544', '544*704'
],
value='720*1280')
with gr.Row():
sd_steps = gr.Slider(
label="Diffusion steps",
minimum=1,
maximum=1000,
value=50,
step=1)
guide_scale = gr.Slider(
label="Guide scale",
minimum=0,
maximum=20,
value=5.0,
step=1)
with gr.Row():
shift_scale = gr.Slider(
label="Shift scale",
minimum=0,
maximum=10,
value=5.0,
step=1)
seed = gr.Slider(
label="Seed",
minimum=-1,
maximum=2147483647,
step=1,
value=-1)
n_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Describe the negative prompt you want to add"
)
run_t2v_button = gr.Button("Generate Video")
with gr.Column():
result_gallery = gr.Video(
label='Generated Video', interactive=False, height=600)
run_p_button.click(
fn=prompt_enc,
inputs=[txt2vid_prompt, tar_lang],
outputs=[txt2vid_prompt])
run_t2v_button.click(
fn=t2v_generation,
inputs=[
txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
seed, n_prompt
],
outputs=[result_gallery],
)
return demo
# Main
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt or image using Gradio")
parser.add_argument(
"--ckpt_dir",
type=str,
default="cache",
help="The path to the checkpoint directory.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
args = parser.parse_args()
return args
if __name__ == '__main__':
args = _parse_args()
print("Step1: Init prompt_expander...", end='', flush=True)
prompt_expander = None
# if args.prompt_extend_method == "dashscope":
# prompt_expander = DashScopePromptExpander(
# model_name=args.prompt_extend_model, is_vl=False)
# elif args.prompt_extend_method == "local_qwen":
# prompt_expander = QwenPromptExpander(
# model_name=args.prompt_extend_model, is_vl=False, device=0)
# else:
# raise NotImplementedError(
# f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
# print("done", flush=True)
from mmgp import offload
print("Step2: Init 14B t2v model...", end='', flush=True)
cfg = WAN_CONFIGS['t2v-14B']
# cfg = WAN_CONFIGS['t2v-1.3B']
wan_t2v = wan.WanT2V(
config=cfg,
checkpoint_dir="../ckpts",
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
)
pipe = {"transformer": wan_t2v.model, "text_encoder" : wan_t2v.text_encoder.model, "vae": wan_t2v.vae.model } #
# offload.profile(pipe, profile_no=4, budgets = {"transformer":100, "*":3000}, verboseLevel=2, quantizeTransformer = False, compile = "transformer") #
offload.profile(pipe, profile_no=4, budgets = {"transformer":100, "*":3000}, verboseLevel=2, quantizeTransformer = False) #
# offload.profile(pipe, profile_no=4, budgets = {"transformer":3000, "*":3000}, verboseLevel=2, quantizeTransformer = False)
print("done", flush=True)
demo = gradio_interface()
demo.launch(server_name="0.0.0.0", share=False, server_port=7860)

View File

@ -24,6 +24,7 @@ from .matanyone_wrapper import matanyone
arg_device = "cuda"
arg_sam_model_type="vit_h"
arg_mask_save = False
model_loaded = False
model = None
matanyone_model = None
@ -409,36 +410,42 @@ def restart():
gr.update(visible=False), gr.update(visible=False, choices=[], value=[]), "", gr.update(visible=False)
def load_unload_models(selected):
global model_loaded
global model
global matanyone_model
if selected:
# args, defined in track_anything.py
sam_checkpoint_url_dict = {
'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
}
# os.path.join('.')
if model_loaded:
model.samcontroler.sam_controler.model.to(arg_device)
matanyone_model.to(arg_device)
else:
# args, defined in track_anything.py
sam_checkpoint_url_dict = {
'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
}
# os.path.join('.')
from mmgp import offload
from mmgp import offload
# sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[arg_sam_model_type], ".")
sam_checkpoint = None
# sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[arg_sam_model_type], ".")
sam_checkpoint = None
transfer_stream = torch.cuda.Stream()
with torch.cuda.stream(transfer_stream):
# initialize sams
model = MaskGenerator(sam_checkpoint, "cuda")
from .matanyone.model.matanyone import MatAnyone
matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone")
# pipe ={"mat" : matanyone_model, "sam" :model.samcontroler.sam_controler.model }
# offload.profile(pipe)
matanyone_model = matanyone_model.to(arg_device).eval()
matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
transfer_stream = torch.cuda.Stream()
with torch.cuda.stream(transfer_stream):
# initialize sams
model = MaskGenerator(sam_checkpoint, arg_device)
from .matanyone.model.matanyone import MatAnyone
matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone")
# pipe ={"mat" : matanyone_model, "sam" :model.samcontroler.sam_controler.model }
# offload.profile(pipe)
matanyone_model = matanyone_model.to(arg_device).eval()
matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
model_loaded = True
else:
import gc
model = None
matanyone_model = None
model.samcontroler.sam_controler.model.to("cpu")
matanyone_model.to("cpu")
gc.collect()
torch.cuda.empty_cache()
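The hunk above changes load_unload_models so that, once initialized, the SAM and MatAnyone models stay resident in RAM and are only moved to the GPU while the Matanyone tool is in use, releasing VRAM otherwise. A generic, hypothetical sketch of that load/offload pattern (the function below is illustrative, not one of the functions used here):

```python
import gc
import torch

def set_gpu_resident(model: torch.nn.Module, on_gpu: bool) -> None:
    # Shuttle an already-initialized model between GPU and CPU instead of
    # re-creating it, and release cached VRAM when parking it in RAM.
    model.to("cuda" if on_gpu else "cpu")
    if not on_gpu:
        gc.collect()
        torch.cuda.empty_cache()
```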
@ -451,10 +458,13 @@ def export_to_vace_video_input(foreground_video_output):
return "V#" + str(time.time()), foreground_video_output
def export_to_vace_video_mask(foreground_video_output, alpha_video_output):
gr.Info("Masked Video Input and Full Mask transferred to Vace For Stronger Inpainting")
gr.Info("Masked Video Input and Full Mask transferred to Vace For Inpainting")
return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
def display(vace_video_input, vace_video_mask, video_prompt_video_guide_trigger):
def teleport_to_vace():
return gr.Tabs(selected="video_gen"), gr.Dropdown(value="vace_1.3B")
def display(tabs, model_choice, vace_video_input, vace_video_mask, video_prompt_video_guide_trigger):
# my_tab.select(fn=load_unload_models, inputs=[], outputs=[])
media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/"
@ -576,18 +586,23 @@ def display(vace_video_input, vace_video_mask, video_prompt_video_guide_trigger)
gr.Markdown("")
# output video
with gr.Row(equal_height=True) as output_row:
with gr.Column(scale=2):
foreground_video_output = gr.Video(label="Masked Video Output", visible=False, elem_classes="video")
foreground_output_button = gr.Button(value="Black & White Video Output", visible=False, elem_classes="new_button")
export_to_vace_video_input_btn = gr.Button("Export to Vace Video Input Video For Inpainting", visible= False)
with gr.Column(scale=2):
alpha_video_output = gr.Video(label="B & W Mask Video Output", visible=False, elem_classes="video")
alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
export_to_vace_video_mask_btn = gr.Button("Export to Vace Video Input and Video Mask for stronger Inpainting", visible= False)
with gr.Column() as output_row: #equal_height=True
with gr.Row():
with gr.Column(scale=2):
foreground_video_output = gr.Video(label="Masked Video Output", visible=False, elem_classes="video")
foreground_output_button = gr.Button(value="Black & White Video Output", visible=False, elem_classes="new_button")
with gr.Column(scale=2):
alpha_video_output = gr.Video(label="B & W Mask Video Output", visible=False, elem_classes="video")
alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
with gr.Row():
with gr.Row(visible= False):
export_to_vace_video_input_btn = gr.Button("Export to Vace Video Input Video For Inpainting", visible= False)
with gr.Row(visible= True):
export_to_vace_video_mask_btn = gr.Button("Export to Vace Video Input and Video Mask", visible= False)
export_to_vace_video_input_btn.click(fn=export_to_vace_video_input, inputs= [foreground_video_output], outputs= [video_prompt_video_guide_trigger, vace_video_input])
export_to_vace_video_mask_btn.click(fn=export_to_vace_video_mask, inputs= [foreground_video_output, alpha_video_output], outputs= [video_prompt_video_guide_trigger, vace_video_input, vace_video_mask])
export_to_vace_video_mask_btn.click(fn=export_to_vace_video_mask, inputs= [foreground_video_output, alpha_video_output], outputs= [video_prompt_video_guide_trigger, vace_video_input, vace_video_mask]).then(
fn=teleport_to_vace, inputs=[], outputs=[tabs, model_choice])
# first step: get the video information
extract_frames_button.click(
fn=get_frames_from_video,

View File

@ -16,7 +16,7 @@ gradio>=5.0.0
numpy>=1.23.5,<2
einops
moviepy==1.0.3
mmgp==3.3.4
mmgp==3.4.0
peft==0.14.0
mutagen
decord
@ -25,7 +25,6 @@ rembg[gpu]==2.0.65
matplotlib
timm
segment-anything
ffmpeg-python
omegaconf
hydra-core
# rembg==2.0.65

View File

@ -48,7 +48,6 @@ class WanI2V:
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
@ -58,6 +57,8 @@ class WanI2V:
i2v720p= True,
model_filename ="",
text_encoder_filename="",
quantizeTransformer = False,
dtype = torch.bfloat16
):
r"""
Initializes the image-to-video generation model components.
@ -82,22 +83,22 @@ class WanI2V:
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
init_on_cpu (`bool`, *optional*, defaults to True):
"""
self.device = torch.device(f"cuda:{device_id}")
self.device = torch.device(f"cuda")
self.config = config
self.rank = rank
self.use_usp = use_usp
self.t5_cpu = t5_cpu
self.dtype = dtype
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
shard_fn = partial(shard_model, device_id=device_id)
# shard_fn = partial(shard_model, device_id=device_id)
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=text_encoder_filename,
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None,
shard_fn=None,
)
self.vae_stride = config.vae_stride
@ -116,34 +117,16 @@ class WanI2V:
logging.info(f"Creating WanModel from {model_filename}")
from mmgp import offload
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False) #forcedConfigPath= "ckpts/config2.json",
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
if self.dtype == torch.float16 and not "fp16" in model_filename:
self.model.to(self.dtype)
# offload.save_model(self.model, "i2v_720p_fp16.safetensors",do_quantize=True)
if self.dtype == torch.float16:
self.vae.model.to(self.dtype)
# offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
self.model.eval().requires_grad_(False)
if t5_fsdp or dit_fsdp or use_usp:
init_on_cpu = False
if use_usp:
from xfuser.core.distributed import \
get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
self.model.forward = types.MethodType(usp_dit_forward, self.model)
self.sp_size = get_sequence_parallel_world_size()
else:
self.sp_size = 1
# if dist.is_initialized():
# dist.barrier()
# if dit_fsdp:
# self.model = shard_fn(self.model)
# else:
# if not init_on_cpu:
# self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
@ -229,16 +212,15 @@ class WanI2V:
w = lat_w * self.vae_stride[2]
clip_image_size = self.clip.model.image_size
img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device)
img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device, self.dtype)
img = resize_lanczos(img, clip_image_size, clip_image_size)
img = img.sub_(0.5).div_(0.5).to(self.device)
img = img.sub_(0.5).div_(0.5).to(self.device, self.dtype)
if img2!= None:
img_interpolated2 = resize_lanczos(img2, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device)
img_interpolated2 = resize_lanczos(img2, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device, self.dtype)
img2 = resize_lanczos(img2, clip_image_size, clip_image_size)
img2 = img2.sub_(0.5).div_(0.5).to(self.device)
img2 = img2.sub_(0.5).div_(0.5).to(self.device, self.dtype)
max_seq_len = lat_frames * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2])
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
@ -275,6 +257,9 @@ class WanI2V:
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
context = [u.to(self.dtype) for u in context]
context_null = [u.to(self.dtype) for u in context_null]
clip_context = self.clip.visual([img[:, None, :, :]])
if offload_model:
self.clip.model.cpu()
@ -285,13 +270,13 @@ class WanI2V:
mean2 = 0
enc= torch.concat([
img_interpolated,
torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= torch.bfloat16),
torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= self.dtype),
img_interpolated2,
], dim=1).to(self.device)
else:
enc= torch.concat([
img_interpolated,
torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= torch.bfloat16)
torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= self.dtype)
], dim=1).to(self.device)
lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
@ -447,7 +432,7 @@ class WanI2V:
callback(i, False)
x0 = [latent.to(self.device, dtype=torch.bfloat16)]
x0 = [latent.to(self.device, dtype=self.dtype)]
if offload_model:
self.model.cpu()

View File

@ -5,6 +5,11 @@ from mmgp import offload
import torch.nn.functional as F
try:
from xformers.ops import memory_efficient_attention
except ImportError:
memory_efficient_attention = None
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
@ -123,13 +128,13 @@ def get_attention_modes():
ret = ["sdpa", "auto"]
if flash_attn != None:
ret.append("flash")
# if memory_efficient_attention != None:
# ret.append("xformers")
if memory_efficient_attention != None:
ret.append("xformers")
if sageattn_varlen_wrapper != None:
ret.append("sage")
if sageattn != None and version("sageattention").startswith("2") :
ret.append("sage2")
return ret
def get_supported_attention_modes():
@ -338,6 +343,14 @@ def pay_attention(
deterministic=deterministic).unflatten(0, (b, lq))
# output
elif attn=="xformers":
x = memory_efficient_attention(
q.unsqueeze(0),
k.unsqueeze(0),
v.unsqueeze(0),
) #.unsqueeze(0)
return x.type(out_dtype)
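This hunk re-enables xformers as a selectable attention mode and adds the corresponding branch in pay_attention. For reference, a minimal standalone sketch of how xformers' memory_efficient_attention is typically invoked (the tensor sizes below are arbitrary examples, not values taken from this code):

```python
import torch
from xformers.ops import memory_efficient_attention

# q, k, v in the [batch, sequence_length, num_heads, head_dim] layout expected by xformers
q = torch.randn(1, 1024, 12, 128, device="cuda", dtype=torch.float16)
k = torch.randn(1, 1024, 12, 128, device="cuda", dtype=torch.float16)
v = torch.randn(1, 1024, 12, 128, device="cuda", dtype=torch.float16)

out = memory_efficient_attention(q, k, v)  # output keeps the same layout as the inputs
```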

View File

@ -77,73 +77,6 @@ def rope_params_riflex(max_seq_len, dim, theta=10000, L_test=30, k=6):
def rope_apply_(x, grid_sizes, freqs):
assert x.shape[0]==1
n, c = x.size(2), x.size(3) // 2
# split freqs
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0]
seq_len = f * h * w
x_i = x[0, :seq_len, :, :]
x_i = x_i.to(torch.float32)
x_i = x_i.reshape(seq_len, n, -1, 2)
x_i = torch.view_as_complex(x_i)
freqs_i = torch.cat([
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1)
freqs_i= freqs_i.reshape(seq_len, 1, -1)
# apply rotary embedding
x_i *= freqs_i
x_i = torch.view_as_real(x_i).flatten(2)
x[0, :seq_len, :, :] = x_i.to(torch.bfloat16)
# x_i = torch.cat([x_i, x[0, seq_len:]])
return x
# @amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs):
n, c = x.size(2), x.size(3) // 2
# split freqs
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
# loop over samples
output = []
for i, (f, h, w) in enumerate(grid_sizes):
seq_len = f * h * w
# precompute multipliers
# x_i = x[i, :seq_len]
x_i = x[i]
x_i = x_i[:seq_len, :, :]
x_i = x_i.to(torch.float32)
x_i = x_i.reshape(seq_len, n, -1, 2)
x_i = torch.view_as_complex(x_i)
freqs_i = torch.cat([
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
],
dim=-1).reshape(seq_len, 1, -1)
# apply rotary embedding
x_i *= freqs_i
x_i = torch.view_as_real(x_i).flatten(2)
x_i = x_i.to(torch.bfloat16)
x_i = torch.cat([x_i, x[i, seq_len:]])
# append to collection
output.append(x_i)
return torch.stack(output) #.float()
def relative_l1_distance(last_tensor, current_tensor):
l1_distance = torch.abs(last_tensor - current_tensor).mean()
norm = torch.abs(last_tensor).mean()
@ -256,8 +189,6 @@ class WanSelfAttention(nn.Module):
k = k.view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
del x
# rope_apply_(q, grid_sizes, freqs)
# rope_apply_(k, grid_sizes, freqs)
qklist = [q,k]
del q,k
q,k = apply_rotary_emb(qklist, freqs, head_first=False)
@ -568,9 +499,9 @@ class Head(nn.Module):
e(Tensor): Shape [B, C]
"""
# assert e.dtype == torch.float32
dtype = x.dtype
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
x = self.norm(x).to(torch.bfloat16)
x = self.norm(x).to(dtype)
x *= (1 + e[1])
x += e[0]
x = self.head(x)
@ -857,7 +788,7 @@ class WanModel(ModelMixin, ConfigMixin):
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t))
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(torch.bfloat16)
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
# context
context_lens = None

View File

@ -51,10 +51,11 @@ class RMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
def forward(self, x):
dtype = x.dtype
x = F.normalize(
x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
x = x.to(torch.bfloat16)
x = x.to(dtype)
return x
class Upsample(nn.Upsample):
@ -208,6 +209,7 @@ class ResidualBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
dtype = x.dtype
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
@ -219,11 +221,11 @@ class ResidualBlock(nn.Module):
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx]).to(torch.bfloat16)
x = layer(x, feat_cache[idx]).to(dtype)
feat_cache[idx] = cache_x#.to("cpu")
feat_idx[0] += 1
else:
x = layer(x).to(torch.bfloat16)
x = layer(x).to(dtype)
return x + h
@ -323,6 +325,7 @@ class Encoder3d(nn.Module):
CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
dtype = x.dtype
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
@ -333,7 +336,7 @@ class Encoder3d(nn.Module):
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx]).to(torch.bfloat16)
x = self.conv1(x, feat_cache[idx]).to(dtype)
feat_cache[idx] = cache_x
del cache_x
feat_idx[0] += 1

View File

@ -47,14 +47,15 @@ class WanT2V:
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
model_filename = None,
text_encoder_filename = None
text_encoder_filename = None,
quantizeTransformer = False,
dtype = torch.bfloat16
):
r"""
Initializes the Wan text-to-video generation model components.
@ -77,25 +78,24 @@ class WanT2V:
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
"""
self.device = torch.device(f"cuda:{device_id}")
self.device = torch.device(f"cuda")
self.config = config
self.rank = rank
self.t5_cpu = t5_cpu
self.dtype = dtype
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
shard_fn = partial(shard_model, device_id=device_id)
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=text_encoder_filename,
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None)
shard_fn= None)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.patch_size = config.patch_size
self.vae = WanVAE(
@ -105,31 +105,14 @@ class WanT2V:
logging.info(f"Creating WanModel from {model_filename}")
from mmgp import offload
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False)
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
if self.dtype == torch.float16 and not "fp16" in model_filename:
self.model.to(self.dtype)
# offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True)
if self.dtype == torch.float16:
self.vae.model.to(self.dtype)
self.model.eval().requires_grad_(False)
if use_usp:
from xfuser.core.distributed import \
get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
self.model.forward = types.MethodType(usp_dit_forward, self.model)
self.sp_size = get_sequence_parallel_world_size()
else:
self.sp_size = 1
# if dist.is_initialized():
# dist.barrier()
# if dit_fsdp:
# self.model = shard_fn(self.model)
# else:
# self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
@ -389,8 +372,10 @@ class WanT2V:
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1] / self.sp_size) * self.sp_size
target_shape[1])
context = [u.to(self.dtype) for u in context]
context_null = [u.to(self.dtype) for u in context_null]
noise = [
torch.randn(

227
wgp.py
View File

@ -20,7 +20,7 @@ from wan.modules.attention import get_attention_modes, get_supported_attention_m
import torch
import gc
import traceback
import math
import math
import typing
import asyncio
import inspect
@ -32,6 +32,8 @@ import zipfile
import tempfile
import atexit
import shutil
import glob
global_queue_ref = []
AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10
@ -203,6 +205,7 @@ def process_prompt_and_add_tasks(state, model_choice):
if isinstance(image_refs, list):
image_refs = [ convert_image(tup[0]) for tup in image_refs ]
os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")
from wan.utils.utils import resize_and_remove_background
image_refs = resize_and_remove_background(image_refs, width, height, inputs["remove_background_image_ref"] ==1)
@ -921,7 +924,7 @@ def autoload_queue(state):
update_global_queue_ref(original_queue)
dataframe_update = update_queue_data(original_queue)
else:
print(f"Autoload skipped: {AUTOSAVE_FILENAME} not found.")
# print(f"Autoload skipped: {AUTOSAVE_FILENAME} not found.")
update_global_queue_ref([])
dataframe_update = update_queue_data([])
@ -1090,19 +1093,13 @@ def _parse_args():
help="Lora preset to preload"
)
# parser.add_argument(
# "--i2v-settings",
# type=str,
# default="i2v_settings.json",
# help="Path to settings file for i2v"
# )
parser.add_argument(
"--settings",
type=str,
default="settings",
help="Path to settings folder"
)
# parser.add_argument(
# "--t2v-settings",
# type=str,
# default="t2v_settings.json",
# help="Path to settings file for t2v"
# )
# parser.add_argument(
# "--lora-preset-i2v",
@ -1152,6 +1149,12 @@ def _parse_args():
help="Access advanced options by default"
)
parser.add_argument(
"--fp16",
action="store_true",
help="For using fp16 transformer model"
)
parser.add_argument(
"--server-port",
type=str,
@ -1159,6 +1162,22 @@ def _parse_args():
help="Server port"
)
parser.add_argument(
"--theme",
type=str,
default="",
help="set UI Theme"
)
parser.add_argument(
"--perc-reserved-mem-max",
type=float,
default=0,
help="% of RAM allocated to Reserved RAM"
)
parser.add_argument(
"--server-name",
type=str,
@ -1307,6 +1326,12 @@ transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "
transformer_choices = transformer_choices_t2v + transformer_choices_i2v
text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
server_config_filename = "wgp_config.json"
if not os.path.isdir("settings"):
os.mkdir("settings")
if os.path.isfile("t2v_settings.json"):
for f in glob.glob(os.path.join(".", "*_settings.json*")):
target_file = os.path.join("settings", Path(f).parts[-1] )
shutil.move(f, target_file)
if not os.path.isfile(server_config_filename) and os.path.isfile("gradio_config.json"):
shutil.move("gradio_config.json", server_config_filename)
@ -1321,10 +1346,11 @@ if not Path(server_config_filename).is_file():
"metadata_type": "metadata",
"default_ui": "t2v",
"boost" : 1,
"clear_file_list" : 0,
"clear_file_list" : 5,
"vae_config": 0,
"profile" : profile_type.LowRAM_LowVRAM,
"preload_model_policy": [] }
"preload_model_policy": [],
"UI_theme": "default" }
with open(server_config_filename, "w", encoding="utf-8") as writer:
writer.write(json.dumps(server_config))
@ -1380,7 +1406,7 @@ def get_model_filename(model_type, quantization):
return choices[0]
def get_settings_file_name(model_filename):
return get_model_type(model_filename) + "_settings.json"
return os.path.join(args.settings, get_model_type(model_filename) + "_settings.json")
def get_default_settings(filename):
def get_default_prompt(i2v):
@ -1388,11 +1414,11 @@ def get_default_settings(filename):
return "Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field."
else:
return "A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect."
i2v = "image2video" in file_name
i2v = "image2video" in filename or "Fun_InP" in filename
defaults_filename = get_settings_file_name(filename)
if not Path(defaults_filename).is_file():
ui_defaults = {
"prompts": get_default_prompt(i2v),
"prompt": get_default_prompt(i2v),
"resolution": "832x480",
"video_length": 81,
"num_inference_steps": 30,
@ -1651,7 +1677,6 @@ def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset,
if lora_dir != None:
import glob
dir_loras = glob.glob( os.path.join(lora_dir , "*.sft") ) + glob.glob( os.path.join(lora_dir , "*.safetensors") )
dir_loras.sort()
loras += [element for element in dir_loras if element not in loras ]
@ -1676,7 +1701,7 @@ def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset,
return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
def load_t2v_model(model_filename, value):
def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16):
cfg = WAN_CONFIGS['t2v-14B']
# cfg = WAN_CONFIGS['t2v-1.3B']
@ -1685,20 +1710,21 @@ def load_t2v_model(model_filename, value):
wan_model = wan.WanT2V(
config=cfg,
checkpoint_dir="ckpts",
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
model_filename=model_filename,
text_encoder_filename= text_encoder_filename
text_encoder_filename= text_encoder_filename,
quantizeTransformer = quantizeTransformer,
dtype = dtype
)
pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model }
return wan_model, pipe
def load_i2v_model(model_filename, value):
def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16):
print(f"Loading '{model_filename}' model...")
@ -1707,14 +1733,15 @@ def load_i2v_model(model_filename, value):
wan_model = wan.WanI2V(
config=cfg,
checkpoint_dir="ckpts",
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
i2v720p= True,
model_filename=model_filename,
text_encoder_filename=text_encoder_filename
text_encoder_filename=text_encoder_filename,
quantizeTransformer = quantizeTransformer,
dtype = dtype
)
pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } #
@ -1723,15 +1750,15 @@ def load_i2v_model(model_filename, value):
wan_model = wan.WanI2V(
config=cfg,
checkpoint_dir="ckpts",
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
i2v720p= False,
model_filename=model_filename,
text_encoder_filename=text_encoder_filename
text_encoder_filename=text_encoder_filename,
quantizeTransformer = quantizeTransformer,
dtype = dtype
)
pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } #
else:
@ -1744,12 +1771,20 @@ def load_models(model_filename):
global transformer_filename
transformer_filename = model_filename
perc_reserved_mem_max = args.perc_reserved_mem_max
major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
default_dtype = torch.float16 if major < 8 else torch.bfloat16
if default_dtype == torch.float16 or args.fp16:
print("Switching to f16 model as GPU architecture doesn't support bf16")
if "quanto" in model_filename:
model_filename = model_filename.replace("quanto_int8", "quanto_fp16_int8")
download_models(model_filename, text_encoder_filename)
if test_class_i2v(model_filename):
res720P = "720p" in model_filename
wan_model, pipe = load_i2v_model(model_filename, "720P" if res720P else "480P")
wan_model, pipe = load_i2v_model(model_filename, "720P" if res720P else "480P", quantizeTransformer = quantizeTransformer, dtype = default_dtype )
else:
wan_model, pipe = load_t2v_model(model_filename, "")
wan_model, pipe = load_t2v_model(model_filename, "", quantizeTransformer = quantizeTransformer, dtype = default_dtype)
wan_model._model_file_name = model_filename
kwargs = { "extraModelsToQuantize": None}
if profile == 2 or profile == 4:
@ -1758,7 +1793,7 @@ def load_models(model_filename):
# kwargs["partialPinning"] = True
elif profile == 3:
kwargs["budgets"] = { "*" : "70%" }
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, **kwargs)
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = default_dtype, **kwargs)
if len(args.gpu) > 0:
torch.set_default_device(args.gpu)
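The new low-VRAM path picks the transformer dtype from the GPU compute capability: pre-Ampere cards (major < 8, e.g. RTX 10xx / 20xx) fall back to fp16 and, for quantized checkpoints, to the quanto_fp16_int8 filename variant. A minimal sketch under those assumptions (the helper name is illustrative, not part of the commit):

import torch

def pick_default_dtype(model_filename: str, force_fp16: bool = False):
    # requires a visible CUDA device; returns e.g. (7, 5) on an RTX 2080 Ti
    major, _minor = torch.cuda.get_device_capability()
    dtype = torch.float16 if (major < 8 or force_fp16) else torch.bfloat16
    if dtype == torch.float16 and "quanto_int8" in model_filename:
        # assumed checkpoint naming convention taken from the diff above
        model_filename = model_filename.replace("quanto_int8", "quanto_fp16_int8")
    return dtype, model_filename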
@ -1834,6 +1869,7 @@ def apply_changes( state,
boost_choice = 1,
clear_file_list = 0,
preload_model_policy_choice = 1,
UI_theme_choice = "default"
):
if args.lock_config:
return
@ -1852,6 +1888,7 @@ def apply_changes( state,
"boost" : boost_choice,
"clear_file_list" : clear_file_list,
"preload_model_policy" : preload_model_policy_choice,
"UI_theme" : UI_theme_choice
}
if Path(server_config_filename).is_file():
@ -1874,7 +1911,7 @@ def apply_changes( state,
if v != v_old:
changes.append(k)
global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, preload_model_policy, transformer_quantization, transformer_types
global attention_mode, profile, compile, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, preload_model_policy, transformer_quantization, transformer_types
attention_mode = server_config["attention_mode"]
profile = server_config["profile"]
compile = server_config["compile"]
@ -1884,10 +1921,13 @@ def apply_changes( state,
preload_model_policy = server_config["preload_model_policy"]
transformer_quantization = server_config["transformer_quantization"]
transformer_types = server_config["transformer_types"]
transformer_type = get_model_type(transformer_filename)
if not transformer_type in transformer_types:
transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
transformer_filename = get_model_filename(transformer_type, transformer_quantization)
model_filename = state["model_filename"]
model_transformer_type = get_model_type(model_filename)
if not model_transformer_type in transformer_types:
model_transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
model_filename = get_model_filename(model_transformer_type, transformer_quantization)
state["model_filename"] = model_filename
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ):
model_choice = gr.Dropdown()
else:
@ -1990,6 +2030,15 @@ def refresh_gallery(state, msg):
start_img_md = ""
end_img_md = ""
prompt = task["prompt"]
params = task["params"]
if "\n" in prompt and params.get("sliding_window_repeat", 0) > 0:
prompts = prompt.split("\n")
repeat_no= gen.get("repeat_no",1)
if repeat_no > len(prompts):
repeat_no = len(prompts)
repeat_no -= 1
prompts[repeat_no]="<B>" + prompts[repeat_no] + "</B>"
prompt = "<BR>".join(prompts)
start_img_uri = task.get('start_image_data_base64')
start_img_uri = start_img_uri[0] if start_img_uri !=None else None
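refresh_gallery() now highlights the prompt line that drives the sliding window currently being generated. A minimal sketch of that formatting step (function and argument names are illustrative):

def highlight_current_window_prompt(prompt: str, repeat_no: int) -> str:
    # one prompt line per sliding window; bold the line of the current window
    lines = prompt.split("\n")
    idx = min(max(repeat_no, 1), len(lines)) - 1  # clamp to a valid 0-based index
    lines[idx] = "<B>" + lines[idx] + "</B>"
    return "<BR>".join(lines)

# highlight_current_window_prompt("a cat walks\na cat runs", 2)
# -> 'a cat walks<BR><B>a cat runs</B>'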
@ -2463,15 +2512,7 @@ def generate_video(
try:
start_time = time.time()
# with tracker_lock:
# progress_tracker[task_id] = {
# 'current_step': 0,
# 'total_steps': num_inference_steps,
# 'start_time': start_time,
# 'last_update': start_time,
# 'repeats': repeat_generation, # f"{video_no}/{repeat_generation}",
# 'status': "Encoding Prompt"
# }
if trans.enable_teacache:
trans.teacache_counter = 0
trans.num_steps = num_inference_steps
@ -2542,20 +2583,17 @@ def generate_video(
gc.collect()
torch.cuda.empty_cache()
s = str(e)
keyword_list = ["vram", "VRAM", "memory","allocat"]
VRAM_crash= False
if any( keyword in s for keyword in keyword_list):
VRAM_crash = True
else:
stack = traceback.extract_stack(f=None, limit=5)
for frame in stack:
if any( keyword in frame.name for keyword in keyword_list):
VRAM_crash = True
break
keyword_list = {"CUDA out of memory" : "VRAM", "Tried to allocate":"VRAM", "CUDA error: out of memory": "RAM", "CUDA error: too many resources requested": "RAM"}
crash_type = ""
for keyword, tp in keyword_list.items():
if keyword in s:
crash_type = tp
break
state["prompt"] = ""
if VRAM_crash:
if crash_type == "VRAM":
new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames."
elif crash_type == "RAM":
new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient RAM and / or Reserved RAM allocation should be reduced using 'perc_reserved_mem_max' or using a different Profile."
else:
new_error = gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
tb = traceback.format_exc().split('\n')[:-1]
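The exception handler now distinguishes VRAM exhaustion from (reserved) RAM exhaustion by matching the exception text against known CUDA messages, instead of the previous single VRAM flag. A standalone sketch of that mapping (names are illustrative; the keywords mirror the diff above):

OOM_KEYWORDS = {
    "CUDA out of memory": "VRAM",
    "Tried to allocate": "VRAM",
    "CUDA error: out of memory": "RAM",
    "CUDA error: too many resources requested": "RAM",
}

def classify_crash(message: str) -> str:
    """Return 'VRAM', 'RAM' or '' for the given exception message."""
    for keyword, crash_type in OOM_KEYWORDS.items():
        if keyword in message:
            return crash_type
    return ""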
@ -2929,12 +2967,13 @@ def refresh_lora_list(state, lset_name, loras_choices):
pos = len(loras_presets)
lset_name =""
errors = getattr(wan_model.model, "_loras_errors", "")
if errors !=None and len(errors) > 0:
error_files = [path for path, _ in errors]
gr.Info("Error while refreshing Lora List, invalid Lora files: " + ", ".join(error_files))
else:
gr.Info("Lora List has been refreshed")
if wan_model != None:
errors = getattr(wan_model.model, "_loras_errors", "")
if errors !=None and len(errors) > 0:
error_files = [path for path, _ in errors]
gr.Info("Error while refreshing Lora List, invalid Lora files: " + ", ".join(error_files))
else:
gr.Info("Lora List has been refreshed")
return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Dropdown(choices=new_loras_choices, value= lora_names_selected)
@ -3210,7 +3249,7 @@ def save_inputs(
def download_loras():
from huggingface_hub import snapshot_download
yield gr.Row(visible=True), "<B><FONT SIZE=3>Please wait while the Loras are being downloaded</FONT></B>", *[gr.Column(visible=False)] * 2
lora_dir = get_lora_dir(get_model_filename("i2v"), quantizeTransformer)
lora_dir = get_lora_dir(get_model_filename("i2v", transformer_quantization))
log_path = os.path.join(lora_dir, "log.txt")
if not os.path.isfile(log_path):
tmp_path = os.path.join(lora_dir, "tmp_lora_dowload")
@ -4047,7 +4086,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
outputs=[modal_container]
)
return (
return ( state,
loras_choices, lset_name, state, queue_df, current_gen_column,
gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
gen_info, queue_accordion, video_guide, video_mask, video_prompt_video_guide_trigger
@ -4068,9 +4107,7 @@ def generate_download_tab(lset_name,loras_choices, state):
download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
def generate_configuration_tab(header, model_choice):
state_dict = {}
state = gr.State(state_dict)
def generate_configuration_tab(state, blocks, header, model_choice):
gr.Markdown("Please click Apply Changes at the bottom so that the changes are effective. Some choices below may be locked if the app has been launched by specifying a config preset.")
with gr.Column():
model_list = []
@ -4090,7 +4127,7 @@ def generate_configuration_tab(header, model_choice):
quantization_choice = gr.Dropdown(
choices=[
("Int8 Quantization (recommended)", "int8"),
("BF16 (no quantization)", "bf16"),
("16 bits (no quantization)", "bf16"),
],
value= transformer_quantization,
label="Wan Transformer Model Quantization Type (if available)",
@ -4122,7 +4159,7 @@ def generate_configuration_tab(header, model_choice):
("Auto : pick sage2 > sage > sdpa depending on what is installed", "auto"),
("Scale Dot Product Attention: default, always available", "sdpa"),
("Flash" + check("flash")+ ": good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"),
# ("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
("Sage2" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"),
],
@ -4201,10 +4238,19 @@ def generate_configuration_tab(header, model_choice):
("Keep the last 20 videos", 20),
("Keep the last 30 videos", 30),
],
value=server_config.get("clear_file_list", 0),
value=server_config.get("clear_file_list", 5),
label="Keep Previously Generated Videos when starting a Generation Batch"
)
UI_theme_choice = gr.Dropdown(
choices=[
("Blue Sky", "default"),
("Classic Gradio", "gradio"),
],
value=server_config.get("UI_theme_choice", "default"),
label="User Interface Theme. You will need to restart the App the see new Theme."
)
msg = gr.Markdown()
apply_btn = gr.Button("Apply Changes")
@ -4224,6 +4270,7 @@ def generate_configuration_tab(header, model_choice):
boost_choice,
clear_file_list_choice,
preload_model_policy_choice,
UI_theme_choice
],
outputs= [msg , header, model_choice]
)
@ -4286,20 +4333,15 @@ def select_tab(tab_state, evt:gr.SelectData):
elif new_tab_no == tab_video_mask_creator:
if gen_in_progress:
gr.Info("Unable to access this Tab while a Generation is in Progress. Please come back later")
tab_state["tab_auto"]=old_tab_no
tab_state["tab_no"] = 0
return gr.Tabs(selected="video_gen")
else:
vmc_event_handler(True)
tab_state["tab_no"] = new_tab_no
def select_tab_auto(tab_state):
old_tab_no = tab_state.pop("tab_auto", -1)
if old_tab_no>= 0:
tab_state["tab_auto"]=old_tab_no
return gr.Tabs(selected=old_tab_no) # !! doesnt work !!
return gr.Tab()
return gr.Tabs()
def create_demo():
global vmc_event_handler
global vmc_event_handler
css = """
#model_list{
background-color:black;
@ -4532,14 +4574,21 @@ def create_demo():
pointer-events: none;
}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
UI_theme = server_config.get("UI_theme", "default")
UI_theme = args.theme if len(args.theme) > 0 else UI_theme
if UI_theme == "gradio":
theme = None
else:
theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as demo:
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
global model_list
tab_state = gr.State({ "tab_no":0 })
with gr.Tabs(selected="video_gen", ) as main_tabs:
with gr.Tab("Video Generator", id="video_gen") as t2v_tab:
with gr.Tab("Video Generator", id="video_gen"):
with gr.Row():
if args.lock_model:
gr.Markdown("<div class='title-with-lines'><div class=line></div><h2>" + get_model_name(transformer_filename) + "</h2><div class=line></div>")
@ -4551,23 +4600,23 @@ def create_demo():
with gr.Row():
header = gr.Markdown(generate_header(transformer_filename, compile, attention_mode), visible= True)
with gr.Row():
(
( state,
loras_choices, lset_name, state, queue_df, current_gen_column,
gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
gen_info, queue_accordion, video_guide, video_mask, video_prompt_type_video_trigger
) = generate_video_tab(model_choice=model_choice, header=header)
with gr.Tab("Informations"):
with gr.Tab("Informations", id="info"):
generate_info_tab()
with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator:
from preprocessing.matanyone import app as matanyone_app
vmc_event_handler = matanyone_app.get_vmc_event_handler()
matanyone_app.display(video_guide, video_mask, video_prompt_type_video_trigger)
matanyone_app.display(main_tabs, model_choice, video_guide, video_mask, video_prompt_type_video_trigger)
if not args.lock_config:
with gr.Tab("Downloads", id="downloads") as downloads_tab:
generate_download_tab(lset_name, loras_choices, state)
with gr.Tab("Configuration"):
generate_configuration_tab(header, model_choice)
with gr.Tab("Configuration", id="configuration"):
generate_configuration_tab(state, demo, header, model_choice)
with gr.Tab("About"):
generate_about_tab()
@ -4589,7 +4638,7 @@ def create_demo():
trigger_mode="always_last"
)
main_tabs.select(fn=select_tab, inputs= [tab_state], outputs= None).then(fn=select_tab_auto, inputs= [tab_state], outputs=[main_tabs])
main_tabs.select(fn=select_tab, inputs= [tab_state], outputs= main_tabs)
return demo
if __name__ == "__main__":