Mirror of https://github.com/Wan-Video/Wan2.1.git
Synced 2025-11-04 14:16:57 +00:00

Added Low VRAM support for RTX 10XX and RTX 20XX GPUs

This commit is contained in:
parent 5efddd626d
commit c62beb7d9d

README.md (33 changed lines)
@@ -15,12 +15,13 @@
 ## 🔥 Latest News!!

 * April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you!
-- A new queuing system that lets you stack in a queue as many text2video and imag2video tasks as you want. Each task can rely on complete different generation parameters (different number of frames, steps, loras, ...).
-- Temporal upsampling (Rife) and spatial upsampling (Lanczos) for a smoother video (32 fps or 64 fps) and to enlarge you video by x2 or x4. Check these new advanced options.
-- Wan Vace Control Net support : with Vace you can inject in the scene people or objects, animate a person, perform inpainting or outpainting, continue a video, ... I have provided an introduction guide below.
-- Integrated *Matanyone* tool directly inside WanGP so that you can create easily inpainting masks
-- Sliding Window generation for Vace, create windows that can last dozen of seconds
 - A new UI, tabs were replaced by a Dropdown box to easily switch models
+- A new queuing system that lets you stack in a queue as many text2video, image2video tasks, ... as you want. Each task can rely on completely different generation parameters (different number of frames, steps, loras, ...). Many thanks to *Tophness* for being a big contributor on this new feature
+- Temporal upsampling (Rife) and spatial upsampling (Lanczos) for a smoother video (32 fps or 64 fps) and to enlarge your video by x2 or x4. Check these new advanced options.
+- Wan Vace Control Net support: with Vace you can inject people or objects into the scene, animate a person, perform inpainting or outpainting, continue a video, ... I have provided an introduction guide below.
+- Integrated *Matanyone* tool directly inside WanGP so that you can easily create inpainting masks used in Vace
+- Sliding Window generation for Vace, create windows that can last dozens of seconds
+- New optimisations for old generation GPUs: generate 5s (81 frames, 15 steps) of Vace 1.3B with only 5 GB of VRAM and in only 6 minutes on an RTX 2080Ti, and 5s of t2v 14B in less than 10 minutes.

 * Mar 27 2025: 👋 Added support for the new Wan Fun InP models (image2video). The 14B Fun InP has probably better end image support but unfortunately existing loras do not work so well with it. The great novelty is the Fun InP image2video 1.3B model: Image 2 Video is now accessible to even lower hardware configurations. It is not as good as the 14B models but very impressive for its size. You can choose any of those models in the Configuration tab. Many thanks to the VideoX-Fun team (https://github.com/aigc-apps/VideoX-Fun)
 * Mar 26 2025: 👋 Good news! Official support for RTX 50xx, please check the installation instructions below.
@@ -303,6 +304,20 @@ Vace provides on its github (https://github.com/ali-vilab/VACE/tree/main/vace/gr
 There is also a guide that describes the various combinations of hints (https://github.com/ali-vilab/VACE/blob/main/UserGuide.md). Good luck!

 It seems you will get better results if you turn on "Skip Layer Guidance" with its default configuration.

+### VACE Sliding Window
+
+With this mode (which for the moment works only with Vace) you can merge multiple windows to form a very long video (up to 1 min). What is very nice about this feature is that the resulting video can be driven by the same control video. For instance the first 0-4s of the control video will be used to generate the first window, then the next 4-8s of the control video will be used to generate the second window, and so on. So if your control video contains a person walking, your generated video could contain up to one minute of this person walking.
+
+To turn on sliding window, go to the Advanced Settings tab *Sliding Window* and set the iteration number to a value greater than 1. This number corresponds to the default number of windows. You can still increase the number during the generation by clicking the "One More Sample, Please !" button.
+
+Each window duration is set by the *Number of frames (16 = 1s)* form field. However the actual number of frames generated by each iteration will be less, because of the *overlap frames* and *discard last frames* (see the sketch after this section):
+- *overlap frames*: the first frames of a new window are filled with the last frames of the previous window in order to ensure continuity between the two windows
+- *discard last frames*: quite often the last frames of a window have worse quality. You decide here how many ending frames of a new window should be dropped.
+
+Number of Generated Frames = [Number of Iterations] * ([Number of Frames] - [Overlap Frames] - [Discard Last Frames]) + [Overlap Frames]
+
+Experimental: if your prompt is broken into multiple lines (each line separated by a carriage return), then each line of the prompt will be used for a new window. If there are more windows to generate than prompt lines, the last prompt line will be repeated.
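As a quick check on the formula above, here is a minimal sketch (plain Python; the helper name and the example numbers are illustrative, not part of WanGP) of how the sliding-window settings combine into the final frame count:

```python
def total_generated_frames(iterations: int, frames_per_window: int,
                           overlap_frames: int, discard_last_frames: int) -> int:
    # Each window contributes its frames minus the overlapping start and the discarded tail;
    # the overlap of the very first window is counted once at the end.
    kept_per_window = frames_per_window - overlap_frames - discard_last_frames
    return iterations * kept_per_window + overlap_frames

# Example: 3 windows of 81 frames each, 8 overlap frames, 4 discarded last frames
print(total_generated_frames(3, 81, 8, 4))  # 3 * (81 - 8 - 4) + 8 = 215 frames (~13s at 16 fps)
```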
 ### Command line parameters for Gradio Server
 --i2v : launch the image to video generator\
 --t2v : launch the text to video generator (default defined in the configuration)\

@@ -324,7 +339,7 @@ It seems you will get better results if you turn on "Skip Layer Guidance" with i
 --compile : turn on pytorch compilation\
 --attention mode : force attention mode among sdpa, flash, sage, sage2\
 --profile no : default (4), profile number between 1 and 5\
---preload no : number in Megabytes to partially preload the diffusion model in VRAM, may offer slight speed gains especially on older hardware. Works only with profiles 2 and 4.\
+--preload no : number in Megabytes to partially preload the diffusion model in VRAM; may offer speed gains on older hardware, while on recent hardware (RTX 30XX, RTX 40XX and RTX 50XX) the gain is only about 10% and not worth it. Works only with profiles 2 and 4.\
 --seed no : set the default seed value\
 --frames no : set the default number of frames to generate\
 --steps no : set the default number of denoising steps\

@@ -333,7 +348,11 @@ It seems you will get better results if you turn on "Skip Layer Guidance" with i
 --check-loras : filter out loras that are incompatible (will take a few seconds while refreshing the lora list or while starting the app)\
 --advanced : turn on the advanced mode while launching the app\
 --listen : make the server accessible on the network\
---gpu device : run Wan on a specific device, for instance "cuda:1"
+--gpu device : run Wan on a specific device, for instance "cuda:1"\
+--settings : path to a folder that contains the default settings for all the models\
+--fp16 : force the use of fp16 versions of models instead of bf16 versions\
+--perc-reserved-mem-max float_less_than_1 : max percentage of RAM to allocate to reserved RAM, allowing faster RAM<->VRAM transfers; the value should remain below 0.5 to keep the OS stable\
+--theme theme_name : load the UI with the specified theme name; so far only two are supported, "default" and "gradio". You may submit your own nice looking Gradio theme and I will add it
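For context on the --profile and --preload flags above, the heavy lifting is done by mmgp's offload.profile. The sketch below is assembled from the calls that appear in the Gradio scripts removed later in this commit; the wan_model name and the budget values are illustrative assumptions rather than fixed WanGP defaults:

```python
from mmgp import offload  # mmgp, pinned in requirements.txt (bumped to 3.4.0 in this commit)

# wan_model is assumed to be an already constructed wan.WanT2V or wan.WanI2V instance
pipe = {
    "transformer": wan_model.model,                # diffusion transformer
    "text_encoder": wan_model.text_encoder.model,  # T5 text encoder
    "vae": wan_model.vae.model,                    # video VAE
}

# profile_no=4 with a small "transformer" budget keeps most of the diffusion model out of
# VRAM and streams it in on demand, which is the low-VRAM path for older GPUs; a larger
# budget preloads more of the model, roughly what the --preload flag exposes.
offload.profile(pipe, profile_no=4,
                budgets={"transformer": 100, "*": 3000},
                verboseLevel=2, quantizeTransformer=False)
```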
 ### Profiles (for power users only)
 You can choose between 5 profiles, but two are really relevant here:
@@ -1,306 +0,0 @@ (deleted file: standalone single-GPU Gradio demo for the i2v-14B model, 480P/720P; former contents below)
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import gc
import os.path as osp
import os
import sys
import warnings

import gradio as gr

warnings.filterwarnings('ignore')

# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video

# Global Var
prompt_expander = None
wan_i2v_480P = None
wan_i2v_720P = None


# Button Func
def load_i2v_model(value):
    global wan_i2v_480P, wan_i2v_720P
    from mmgp import offload

    if value == '------':
        print("No model loaded")
        return '------'

    if value == '720P':
        if args.ckpt_dir_720p is None:
            print("Please specify the checkpoint directory for 720P model")
            return '------'
        if wan_i2v_720P is not None:
            pass
        else:
            del wan_i2v_480P
            gc.collect()
            wan_i2v_480P = None

            print("load 14B-720P i2v model...", end='', flush=True)
            cfg = WAN_CONFIGS['i2v-14B']
            wan_i2v_720P = wan.WanI2V(
                config=cfg,
                checkpoint_dir=args.ckpt_dir_720p,
                device_id=0,
                rank=0,
                t5_fsdp=False,
                dit_fsdp=False,
                use_usp=False,
                i2v720p=True
            )
            print("done", flush=True)
            pipe = {"transformer": wan_i2v_720P.model, "text_encoder": wan_i2v_720P.text_encoder.model, "text_encoder_2": wan_i2v_720P.clip.model, "vae": wan_i2v_720P.vae.model}
            offload.profile(pipe, profile_no=4, budgets={"transformer": 100, "*": 3000}, verboseLevel=2, compile="transformer", quantizeTransformer=False, pinnedMemory=False)
        return '720P'

    if value == '480P':
        if args.ckpt_dir_480p is None:
            print("Please specify the checkpoint directory for 480P model")
            return '------'
        if wan_i2v_480P is not None:
            pass
        else:
            del wan_i2v_720P
            gc.collect()
            wan_i2v_720P = None

            print("load 14B-480P i2v model...", end='', flush=True)
            cfg = WAN_CONFIGS['i2v-14B']
            wan_i2v_480P = wan.WanI2V(
                config=cfg,
                checkpoint_dir=args.ckpt_dir_480p,
                device_id=0,
                rank=0,
                t5_fsdp=False,
                dit_fsdp=False,
                use_usp=False,
                i2v720p=False
            )
            print("done", flush=True)
            pipe = {"transformer": wan_i2v_480P.model, "text_encoder": wan_i2v_480P.text_encoder.model, "text_encoder_2": wan_i2v_480P.clip.model, "vae": wan_i2v_480P.vae.model}
            offload.profile(pipe, profile_no=4, budgets={"model": 100, "*": 3000}, verboseLevel=2, compile="transformer")

        return '480P'


def prompt_enc(prompt, img, tar_lang):
    print('prompt extend...')
    if img is None:
        print('Please upload an image')
        return prompt
    global prompt_expander
    prompt_output = prompt_expander(
        prompt, image=img, tar_lang=tar_lang.lower())
    if prompt_output.status == False:
        return prompt
    else:
        return prompt_output.prompt


def i2v_generation(img2vid_prompt, img2vid_image, res, sd_steps,
                   guide_scale, shift_scale, seed, n_prompt):
    # print(f"{img2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
    global resolution
    from PIL import Image
    img2vid_image = Image.open("d:\mammoth2.jpg")
    if resolution == '------':
        print(
            'Please specify at least one resolution ckpt dir or specify the resolution'
        )
        return None

    else:
        if resolution == '720P':
            global wan_i2v_720P
            video = wan_i2v_720P.generate(
                img2vid_prompt,
                img2vid_image,
                max_area=MAX_AREA_CONFIGS['720*1280'],
                shift=shift_scale,
                sampling_steps=sd_steps,
                guide_scale=guide_scale,
                n_prompt=n_prompt,
                seed=seed,
                offload_model=False)
        else:
            global wan_i2v_480P
            video = wan_i2v_480P.generate(
                img2vid_prompt,
                img2vid_image,
                max_area=MAX_AREA_CONFIGS['480*832'],
                shift=3.0,  # shift_scale
                sampling_steps=sd_steps,
                guide_scale=guide_scale,
                n_prompt=n_prompt,
                seed=seed,
                offload_model=False)

        cache_video(
            tensor=video[None],
            save_file="example.mp4",
            fps=16,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))

        return "example.mp4"


# Interface
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
                    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
                        Wan2.1 (I2V-14B)
                    </div>
                    <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
                        Wan: Open and Advanced Large-Scale Video Generative Models.
                    </div>
                    """)

        with gr.Row():
            with gr.Column():
                resolution = gr.Dropdown(
                    label='Resolution',
                    choices=['------', '720P', '480P'],
                    value='------')

                img2vid_image = gr.Image(
                    type="pil",
                    label="Upload Input Image",
                    elem_id="image_upload",
                )
                img2vid_prompt = gr.Textbox(
                    label="Prompt",
                    value="Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.",
                    placeholder="Describe the video you want to generate",
                )
                tar_lang = gr.Radio(
                    choices=["CH", "EN"],
                    label="Target language of prompt enhance",
                    value="CH")
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion steps",
                            minimum=1,
                            maximum=1000,
                            value=50,
                            step=1)
                        guide_scale = gr.Slider(
                            label="Guide scale",
                            minimum=0,
                            maximum=20,
                            value=5.0,
                            step=1)
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift scale",
                            minimum=0,
                            maximum=10,
                            value=5.0,
                            step=1)
                        seed = gr.Slider(
                            label="Seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1)
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add"
                    )

                run_i2v_button = gr.Button("Generate Video")

            with gr.Column():
                result_gallery = gr.Video(
                    label='Generated Video', interactive=False, height=600)

        resolution.input(
            fn=load_model, inputs=[resolution], outputs=[resolution])

        run_p_button.click(
            fn=prompt_enc,
            inputs=[img2vid_prompt, img2vid_image, tar_lang],
            outputs=[img2vid_prompt])

        run_i2v_button.click(
            fn=i2v_generation,
            inputs=[
                img2vid_prompt, img2vid_image, resolution, sd_steps,
                guide_scale, shift_scale, seed, n_prompt
            ],
            outputs=[result_gallery],
        )

    return demo


# Main
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt or image using Gradio")
    parser.add_argument(
        "--ckpt_dir_720p",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--ckpt_dir_480p",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")

    args = parser.parse_args()
    args.ckpt_dir_720p = "../ckpts"  # os.path.join("ckpt")
    args.ckpt_dir_480p = "../ckpts"  # os.path.join("ckpt")
    assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory."

    return args


if __name__ == '__main__':
    args = _parse_args()
    global resolution
    # load_model('720P')
    # resolution = '720P'
    resolution = '480P'

    load_i2v_model(resolution)

    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=True)
    elif args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=True, device=0)
    else:
        raise NotImplementedError(
            f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
@@ -1,206 +0,0 @@ (deleted file: standalone single-GPU Gradio demo for the t2i-14B text-to-image model; former contents below)
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import sys
import warnings

import gradio as gr

warnings.filterwarnings('ignore')

# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_image

# Global Var
prompt_expander = None
wan_t2i = None


# Button Func
def prompt_enc(prompt, tar_lang):
    global prompt_expander
    prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
    if prompt_output.status == False:
        return prompt
    else:
        return prompt_output.prompt


def t2i_generation(txt2img_prompt, resolution, sd_steps, guide_scale,
                   shift_scale, seed, n_prompt):
    global wan_t2i
    # print(f"{txt2img_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")

    W = int(resolution.split("*")[0])
    H = int(resolution.split("*")[1])
    video = wan_t2i.generate(
        txt2img_prompt,
        size=(W, H),
        frame_num=1,
        shift=shift_scale,
        sampling_steps=sd_steps,
        guide_scale=guide_scale,
        n_prompt=n_prompt,
        seed=seed,
        offload_model=True)

    cache_image(
        tensor=video.squeeze(1)[None],
        save_file="example.png",
        nrow=1,
        normalize=True,
        value_range=(-1, 1))

    return "example.png"


# Interface
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
                    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
                        Wan2.1 (T2I-14B)
                    </div>
                    <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
                        Wan: Open and Advanced Large-Scale Video Generative Models.
                    </div>
                    """)

        with gr.Row():
            with gr.Column():
                txt2img_prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the image you want to generate",
                )
                tar_lang = gr.Radio(
                    choices=["CH", "EN"],
                    label="Target language of prompt enhance",
                    value="CH")
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    resolution = gr.Dropdown(
                        label='Resolution(Width*Height)',
                        choices=[
                            '720*1280', '1280*720', '960*960', '1088*832',
                            '832*1088', '480*832', '832*480', '624*624',
                            '704*544', '544*704'
                        ],
                        value='720*1280')

                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion steps",
                            minimum=1,
                            maximum=1000,
                            value=50,
                            step=1)
                        guide_scale = gr.Slider(
                            label="Guide scale",
                            minimum=0,
                            maximum=20,
                            value=5.0,
                            step=1)
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift scale",
                            minimum=0,
                            maximum=10,
                            value=5.0,
                            step=1)
                        seed = gr.Slider(
                            label="Seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1)
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add"
                    )

                run_t2i_button = gr.Button("Generate Image")

            with gr.Column():
                result_gallery = gr.Image(
                    label='Generated Image', interactive=False, height=600)

        run_p_button.click(
            fn=prompt_enc,
            inputs=[txt2img_prompt, tar_lang],
            outputs=[txt2img_prompt])

        run_t2i_button.click(
            fn=t2i_generation,
            inputs=[
                txt2img_prompt, resolution, sd_steps, guide_scale, shift_scale,
                seed, n_prompt
            ],
            outputs=[result_gallery],
        )

    return demo


# Main
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a image from a text prompt or image using Gradio")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default="cache",
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")

    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = _parse_args()

    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=False)
    elif args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=False, device=0)
    else:
        raise NotImplementedError(
            f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    print("Step2: Init 14B t2i model...", end='', flush=True)
    cfg = WAN_CONFIGS['t2i-14B']
    # cfg = WAN_CONFIGS['t2v-1.3B']
    wan_t2i = wan.WanT2V(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
    )
    print("done", flush=True)

    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
@@ -1,207 +0,0 @@ (deleted file: standalone single-GPU Gradio demo for the t2v-1.3B text-to-video model; former contents below)
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import sys
import warnings

import gradio as gr

warnings.filterwarnings('ignore')

# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video

# Global Var
prompt_expander = None
wan_t2v = None


# Button Func
def prompt_enc(prompt, tar_lang):
    global prompt_expander
    prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
    if prompt_output.status == False:
        return prompt
    else:
        return prompt_output.prompt


def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
                   shift_scale, seed, n_prompt):
    global wan_t2v
    # print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")

    W = int(resolution.split("*")[0])
    H = int(resolution.split("*")[1])
    video = wan_t2v.generate(
        txt2vid_prompt,
        size=(W, H),
        shift=shift_scale,
        sampling_steps=sd_steps,
        guide_scale=guide_scale,
        n_prompt=n_prompt,
        seed=seed,
        offload_model=True)

    cache_video(
        tensor=video[None],
        save_file="example.mp4",
        fps=16,
        nrow=1,
        normalize=True,
        value_range=(-1, 1))

    return "example.mp4"


# Interface
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
                    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
                        Wan2.1 (T2V-1.3B)
                    </div>
                    <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
                        Wan: Open and Advanced Large-Scale Video Generative Models.
                    </div>
                    """)

        with gr.Row():
            with gr.Column():
                txt2vid_prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate",
                )
                tar_lang = gr.Radio(
                    choices=["CH", "EN"],
                    label="Target language of prompt enhance",
                    value="CH")
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    resolution = gr.Dropdown(
                        label='Resolution(Width*Height)',
                        choices=[
                            '480*832',
                            '832*480',
                            '624*624',
                            '704*544',
                            '544*704',
                        ],
                        value='480*832')

                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion steps",
                            minimum=1,
                            maximum=1000,
                            value=50,
                            step=1)
                        guide_scale = gr.Slider(
                            label="Guide scale",
                            minimum=0,
                            maximum=20,
                            value=6.0,
                            step=1)
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift scale",
                            minimum=0,
                            maximum=20,
                            value=8.0,
                            step=1)
                        seed = gr.Slider(
                            label="Seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1)
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add"
                    )

                run_t2v_button = gr.Button("Generate Video")

            with gr.Column():
                result_gallery = gr.Video(
                    label='Generated Video', interactive=False, height=600)

        run_p_button.click(
            fn=prompt_enc,
            inputs=[txt2vid_prompt, tar_lang],
            outputs=[txt2vid_prompt])

        run_t2v_button.click(
            fn=t2v_generation,
            inputs=[
                txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
                seed, n_prompt
            ],
            outputs=[result_gallery],
        )

    return demo


# Main
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt or image using Gradio")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default="cache",
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")

    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = _parse_args()

    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=False)
    elif args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=False, device=0)
    else:
        raise NotImplementedError(
            f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    print("Step2: Init 1.3B t2v model...", end='', flush=True)
    cfg = WAN_CONFIGS['t2v-1.3B']
    wan_t2v = wan.WanT2V(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
    )
    print("done", flush=True)

    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
@@ -1,216 +0,0 @@ (deleted file: standalone single-GPU Gradio demo for the t2v-14B text-to-video model with mmgp offloading; former contents below)
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import sys
import warnings

import gradio as gr

warnings.filterwarnings('ignore')

# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video

# Global Var
prompt_expander = None
wan_t2v = None


# Button Func
def prompt_enc(prompt, tar_lang):
    global prompt_expander
    prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
    if prompt_output.status == False:
        return prompt
    else:
        return prompt_output.prompt


def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
                   shift_scale, seed, n_prompt):
    global wan_t2v
    # print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")

    W = int(resolution.split("*")[0])
    H = int(resolution.split("*")[1])
    video = wan_t2v.generate(
        txt2vid_prompt,
        size=(W, H),
        shift=shift_scale,
        sampling_steps=sd_steps,
        guide_scale=guide_scale,
        n_prompt=n_prompt,
        seed=seed,
        offload_model=False)

    cache_video(
        tensor=video[None],
        save_file="example.mp4",
        fps=16,
        nrow=1,
        normalize=True,
        value_range=(-1, 1))

    return "example.mp4"


# Interface
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
                    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
                        Wan2.1 (T2V-14B)
                    </div>
                    <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
                        Wan: Open and Advanced Large-Scale Video Generative Models.
                    </div>
                    """)

        with gr.Row():
            with gr.Column():
                txt2vid_prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate",
                )
                tar_lang = gr.Radio(
                    choices=["CH", "EN"],
                    label="Target language of prompt enhance",
                    value="CH")
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    resolution = gr.Dropdown(
                        label='Resolution(Width*Height)',
                        choices=[
                            '720*1280', '1280*720', '960*960', '1088*832',
                            '832*1088', '480*832', '832*480', '624*624',
                            '704*544', '544*704'
                        ],
                        value='720*1280')

                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion steps",
                            minimum=1,
                            maximum=1000,
                            value=50,
                            step=1)
                        guide_scale = gr.Slider(
                            label="Guide scale",
                            minimum=0,
                            maximum=20,
                            value=5.0,
                            step=1)
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift scale",
                            minimum=0,
                            maximum=10,
                            value=5.0,
                            step=1)
                        seed = gr.Slider(
                            label="Seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1)
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add"
                    )

                run_t2v_button = gr.Button("Generate Video")

            with gr.Column():
                result_gallery = gr.Video(
                    label='Generated Video', interactive=False, height=600)

        run_p_button.click(
            fn=prompt_enc,
            inputs=[txt2vid_prompt, tar_lang],
            outputs=[txt2vid_prompt])

        run_t2v_button.click(
            fn=t2v_generation,
            inputs=[
                txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
                seed, n_prompt
            ],
            outputs=[result_gallery],
        )

    return demo


# Main
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt or image using Gradio")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default="cache",
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")

    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = _parse_args()

    print("Step1: Init prompt_expander...", end='', flush=True)
    prompt_expander = None
    # if args.prompt_extend_method == "dashscope":
    #     prompt_expander = DashScopePromptExpander(
    #         model_name=args.prompt_extend_model, is_vl=False)
    # elif args.prompt_extend_method == "local_qwen":
    #     prompt_expander = QwenPromptExpander(
    #         model_name=args.prompt_extend_model, is_vl=False, device=0)
    # else:
    #     raise NotImplementedError(
    #         f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
    # print("done", flush=True)

    from mmgp import offload

    print("Step2: Init 14B t2v model...", end='', flush=True)
    cfg = WAN_CONFIGS['t2v-14B']
    # cfg = WAN_CONFIGS['t2v-1.3B']

    wan_t2v = wan.WanT2V(
        config=cfg,
        checkpoint_dir="../ckpts",
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
    )

    pipe = {"transformer": wan_t2v.model, "text_encoder": wan_t2v.text_encoder.model, "vae": wan_t2v.vae.model}
    # offload.profile(pipe, profile_no=4, budgets = {"transformer":100, "*":3000}, verboseLevel=2, quantizeTransformer = False, compile = "transformer")
    offload.profile(pipe, profile_no=4, budgets = {"transformer":100, "*":3000}, verboseLevel=2, quantizeTransformer = False)
    # offload.profile(pipe, profile_no=4, budgets = {"transformer":3000, "*":3000}, verboseLevel=2, quantizeTransformer = False)

    print("done", flush=True)

    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
@@ -24,6 +24,7 @@ from .matanyone_wrapper import matanyone
 arg_device = "cuda"
 arg_sam_model_type="vit_h"
 arg_mask_save = False
+model_loaded = False
 model = None
 matanyone_model = None

@@ -409,36 +410,42 @@ def restart():
         gr.update(visible=False), gr.update(visible=False, choices=[], value=[]), "", gr.update(visible=False)

 def load_unload_models(selected):
+    global model_loaded
     global model
     global matanyone_model
     if selected:
-        # args, defined in track_anything.py
-        sam_checkpoint_url_dict = {
-            'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
-            'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
-            'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
-        }
-        # os.path.join('.')
-
-        from mmgp import offload
-        # sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[arg_sam_model_type], ".")
-        sam_checkpoint = None
-
-        transfer_stream = torch.cuda.Stream()
-        with torch.cuda.stream(transfer_stream):
-            # initialize sams
-            model = MaskGenerator(sam_checkpoint, "cuda")
-            from .matanyone.model.matanyone import MatAnyone
-            matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone")
-            # pipe ={"mat" : matanyone_model, "sam" :model.samcontroler.sam_controler.model }
-            # offload.profile(pipe)
-            matanyone_model = matanyone_model.to(arg_device).eval()
-            matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
+        if model_loaded:
+            model.samcontroler.sam_controler.model.to(arg_device)
+            matanyone_model.to(arg_device)
+        else:
+            # args, defined in track_anything.py
+            sam_checkpoint_url_dict = {
+                'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+                'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
+                'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
+            }
+            # os.path.join('.')
+
+            from mmgp import offload
+            # sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[arg_sam_model_type], ".")
+            sam_checkpoint = None
+
+            transfer_stream = torch.cuda.Stream()
+            with torch.cuda.stream(transfer_stream):
+                # initialize sams
+                model = MaskGenerator(sam_checkpoint, arg_device)
+                from .matanyone.model.matanyone import MatAnyone
+                matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone")
+                # pipe ={"mat" : matanyone_model, "sam" :model.samcontroler.sam_controler.model }
+                # offload.profile(pipe)
+                matanyone_model = matanyone_model.to(arg_device).eval()
+                matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
+            model_loaded = True
     else:
         import gc
-        model = None
-        matanyone_model = None
+        model.samcontroler.sam_controler.model.to("cpu")
+        matanyone_model.to("cpu")
         gc.collect()
         torch.cuda.empty_cache()

@@ -451,10 +458,13 @@ def export_to_vace_video_input(foreground_video_output):
     return "V#" + str(time.time()), foreground_video_output

 def export_to_vace_video_mask(foreground_video_output, alpha_video_output):
-    gr.Info("Masked Video Input and Full Mask transferred to Vace For Stronger Inpainting")
+    gr.Info("Masked Video Input and Full Mask transferred to Vace For Inpainting")
     return "MV#" + str(time.time()), foreground_video_output, alpha_video_output

-def display(vace_video_input, vace_video_mask, video_prompt_video_guide_trigger):
+def teleport_to_vace():
+    return gr.Tabs(selected="video_gen"), gr.Dropdown(value="vace_1.3B")
+
+def display(tabs, model_choice, vace_video_input, vace_video_mask, video_prompt_video_guide_trigger):
     # my_tab.select(fn=load_unload_models, inputs=[], outputs=[])

     media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/"

@@ -576,18 +586,23 @@ def display(vace_video_input, vace_video_mask, video_prompt_video_guide_trigger)
         gr.Markdown("")

         # output video
-        with gr.Row(equal_height=True) as output_row:
-            with gr.Column(scale=2):
-                foreground_video_output = gr.Video(label="Masked Video Output", visible=False, elem_classes="video")
-                foreground_output_button = gr.Button(value="Black & White Video Output", visible=False, elem_classes="new_button")
-                export_to_vace_video_input_btn = gr.Button("Export to Vace Video Input Video For Inpainting", visible= False)
-            with gr.Column(scale=2):
-                alpha_video_output = gr.Video(label="B & W Mask Video Output", visible=False, elem_classes="video")
-                alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
-                export_to_vace_video_mask_btn = gr.Button("Export to Vace Video Input and Video Mask for stronger Inpainting", visible= False)
+        with gr.Column() as output_row: #equal_height=True
+            with gr.Row():
+                with gr.Column(scale=2):
+                    foreground_video_output = gr.Video(label="Masked Video Output", visible=False, elem_classes="video")
+                    foreground_output_button = gr.Button(value="Black & White Video Output", visible=False, elem_classes="new_button")
+                with gr.Column(scale=2):
+                    alpha_video_output = gr.Video(label="B & W Mask Video Output", visible=False, elem_classes="video")
+                    alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
+            with gr.Row():
+                with gr.Row(visible= False):
+                    export_to_vace_video_input_btn = gr.Button("Export to Vace Video Input Video For Inpainting", visible= False)
+                with gr.Row(visible= True):
+                    export_to_vace_video_mask_btn = gr.Button("Export to Vace Video Input and Video Mask", visible= False)

         export_to_vace_video_input_btn.click(fn=export_to_vace_video_input, inputs= [foreground_video_output], outputs= [video_prompt_video_guide_trigger, vace_video_input])
-        export_to_vace_video_mask_btn.click(fn=export_to_vace_video_mask, inputs= [foreground_video_output, alpha_video_output], outputs= [video_prompt_video_guide_trigger, vace_video_input, vace_video_mask])
+        export_to_vace_video_mask_btn.click(fn=export_to_vace_video_mask, inputs= [foreground_video_output, alpha_video_output], outputs= [video_prompt_video_guide_trigger, vace_video_input, vace_video_mask]).then(
+            fn=teleport_to_vace, inputs=[], outputs=[tabs, model_choice])
         # first step: get the video information
         extract_frames_button.click(
             fn=get_frames_from_video,
@@ -16,7 +16,7 @@ gradio>=5.0.0 (requirements file)
 numpy>=1.23.5,<2
 einops
 moviepy==1.0.3
-mmgp==3.3.4
+mmgp==3.4.0
 peft==0.14.0
 mutagen
 decord

@@ -25,7 +25,6 @@ rembg[gpu]==2.0.65
 matplotlib
 timm
 segment-anything
-ffmpeg-python
 omegaconf
 hydra-core
 # rembg==2.0.65
@@ -48,7 +48,6 @@ class WanI2V:
         self,
         config,
         checkpoint_dir,
-        device_id=0,
         rank=0,
         t5_fsdp=False,
         dit_fsdp=False,

@@ -58,6 +57,8 @@ class WanI2V:
         i2v720p= True,
         model_filename ="",
         text_encoder_filename="",
+        quantizeTransformer = False,
+        dtype = torch.bfloat16
     ):
         r"""
         Initializes the image-to-video generation model components.

@@ -82,22 +83,22 @@ class WanI2V:
                 Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
             init_on_cpu (`bool`, *optional*, defaults to True):
         """
-        self.device = torch.device(f"cuda:{device_id}")
+        self.device = torch.device(f"cuda")
         self.config = config
         self.rank = rank
         self.use_usp = use_usp
         self.t5_cpu = t5_cpu
+        self.dtype = dtype
         self.num_train_timesteps = config.num_train_timesteps
         self.param_dtype = config.param_dtype
-        shard_fn = partial(shard_model, device_id=device_id)
+        # shard_fn = partial(shard_model, device_id=device_id)
         self.text_encoder = T5EncoderModel(
             text_len=config.text_len,
             dtype=config.t5_dtype,
             device=torch.device('cpu'),
             checkpoint_path=text_encoder_filename,
             tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
-            shard_fn=shard_fn if t5_fsdp else None,
+            shard_fn=None,
         )

         self.vae_stride = config.vae_stride

@@ -116,34 +117,16 @@ class WanI2V:
         logging.info(f"Creating WanModel from {model_filename}")
         from mmgp import offload

-        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False) #forcedConfigPath= "ckpts/config2.json",
+        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, do_quantize= quantizeTransformer, writable_tensors= False)
+        if self.dtype == torch.float16 and not "fp16" in model_filename:
+            self.model.to(self.dtype)
+        # offload.save_model(self.model, "i2v_720p_fp16.safetensors", do_quantize=True)
+        if self.dtype == torch.float16:
+            self.vae.model.to(self.dtype)
+
         # offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
         self.model.eval().requires_grad_(False)

-        if t5_fsdp or dit_fsdp or use_usp:
-            init_on_cpu = False
-
-        if use_usp:
-            from xfuser.core.distributed import \
-                get_sequence_parallel_world_size
-
-            from .distributed.xdit_context_parallel import (usp_attn_forward,
-                                                            usp_dit_forward)
-            for block in self.model.blocks:
-                block.self_attn.forward = types.MethodType(
-                    usp_attn_forward, block.self_attn)
-            self.model.forward = types.MethodType(usp_dit_forward, self.model)
-            self.sp_size = get_sequence_parallel_world_size()
-        else:
-            self.sp_size = 1
-
-        # if dist.is_initialized():
-        #     dist.barrier()
-        # if dit_fsdp:
-        #     self.model = shard_fn(self.model)
-        # else:
-        #     if not init_on_cpu:
-        #         self.model.to(self.device)
-
         self.sample_neg_prompt = config.sample_neg_prompt

@@ -229,16 +212,15 @@ class WanI2V:
         w = lat_w * self.vae_stride[2]

         clip_image_size = self.clip.model.image_size
-        img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device)
+        img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device, self.dtype)
         img = resize_lanczos(img, clip_image_size, clip_image_size)
-        img = img.sub_(0.5).div_(0.5).to(self.device)
+        img = img.sub_(0.5).div_(0.5).to(self.device, self.dtype)
         if img2!= None:
-            img_interpolated2 = resize_lanczos(img2, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device)
+            img_interpolated2 = resize_lanczos(img2, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device, self.dtype)
             img2 = resize_lanczos(img2, clip_image_size, clip_image_size)
-            img2 = img2.sub_(0.5).div_(0.5).to(self.device)
+            img2 = img2.sub_(0.5).div_(0.5).to(self.device, self.dtype)

         max_seq_len = lat_frames * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2])
-        max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size

         seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
         seed_g = torch.Generator(device=self.device)

@@ -275,6 +257,9 @@ class WanI2V:
         context = [t.to(self.device) for t in context]
         context_null = [t.to(self.device) for t in context_null]

+        context = [u.to(self.dtype) for u in context]
+        context_null = [u.to(self.dtype) for u in context_null]
+
         clip_context = self.clip.visual([img[:, None, :, :]])
         if offload_model:
             self.clip.model.cpu()

@@ -285,13 +270,13 @@ class WanI2V:
             mean2 = 0
             enc= torch.concat([
                     img_interpolated,
-                    torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= torch.bfloat16),
+                    torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= self.dtype),
                     img_interpolated2,
             ], dim=1).to(self.device)
         else:
             enc= torch.concat([
                     img_interpolated,
-                    torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= torch.bfloat16)
+                    torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= self.dtype)
             ], dim=1).to(self.device)

         lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]

@@ -447,7 +432,7 @@ class WanI2V:
                 callback(i, False)

-        x0 = [latent.to(self.device, dtype=torch.bfloat16)]
+        x0 = [latent.to(self.device, dtype=self.dtype)]

         if offload_model:
             self.model.cpu()
@@ -5,6 +5,11 @@ from mmgp import offload
 import torch.nn.functional as F
 
 
+try:
+    from xformers.ops import memory_efficient_attention
+except ImportError:
+    memory_efficient_attention = None
+
 try:
     import flash_attn_interface
     FLASH_ATTN_3_AVAILABLE = True
@@ -123,13 +128,13 @@ def get_attention_modes():
     ret = ["sdpa", "auto"]
     if flash_attn != None:
         ret.append("flash")
-    # if memory_efficient_attention != None:
-    #     ret.append("xformers")
+    if memory_efficient_attention != None:
+        ret.append("xformers")
     if sageattn_varlen_wrapper != None:
         ret.append("sage")
     if sageattn != None and version("sageattention").startswith("2") :
         ret.append("sage2")
 
     return ret
 
 def get_supported_attention_modes():
@@ -338,6 +343,14 @@ def pay_attention(
             deterministic=deterministic).unflatten(0, (b, lq))
 
     # output
+
+    elif attn=="xformers":
+        x = memory_efficient_attention(
+            q.unsqueeze(0),
+            k.unsqueeze(0),
+            v.unsqueeze(0),
+        ) #.unsqueeze(0)
+
     return x.type(out_dtype)
 
 
@@ -77,73 +77,6 @@ def rope_params_riflex(max_seq_len, dim, theta=10000, L_test=30, k=6):
 
 
 
-
-def rope_apply_(x, grid_sizes, freqs):
-    assert x.shape[0]==1
-
-    n, c = x.size(2), x.size(3) // 2
-
-    # split freqs
-    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-
-    f, h, w = grid_sizes[0]
-    seq_len = f * h * w
-    x_i = x[0, :seq_len, :, :]
-
-    x_i = x_i.to(torch.float32)
-    x_i = x_i.reshape(seq_len, n, -1, 2)
-    x_i = torch.view_as_complex(x_i)
-    freqs_i = torch.cat([
-        freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
-        freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
-        freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
-    ], dim=-1)
-    freqs_i= freqs_i.reshape(seq_len, 1, -1)
-
-    # apply rotary embedding
-    x_i *= freqs_i
-    x_i = torch.view_as_real(x_i).flatten(2)
-    x[0, :seq_len, :, :] = x_i.to(torch.bfloat16)
-    # x_i = torch.cat([x_i, x[0, seq_len:]])
-    return x
-
-# @amp.autocast(enabled=False)
-def rope_apply(x, grid_sizes, freqs):
-    n, c = x.size(2), x.size(3) // 2
-
-    # split freqs
-    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-
-    # loop over samples
-    output = []
-    for i, (f, h, w) in enumerate(grid_sizes):
-        seq_len = f * h * w
-
-        # precompute multipliers
-        # x_i = x[i, :seq_len]
-        x_i = x[i]
-        x_i = x_i[:seq_len, :, :]
-
-        x_i = x_i.to(torch.float32)
-        x_i = x_i.reshape(seq_len, n, -1, 2)
-        x_i = torch.view_as_complex(x_i)
-        freqs_i = torch.cat([
-            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
-            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
-            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
-        ],
-        dim=-1).reshape(seq_len, 1, -1)
-
-        # apply rotary embedding
-        x_i *= freqs_i
-        x_i = torch.view_as_real(x_i).flatten(2)
-        x_i = x_i.to(torch.bfloat16)
-        x_i = torch.cat([x_i, x[i, seq_len:]])
-
-        # append to collection
-        output.append(x_i)
-    return torch.stack(output) #.float()
 
 def relative_l1_distance(last_tensor, current_tensor):
     l1_distance = torch.abs(last_tensor - current_tensor).mean()
     norm = torch.abs(last_tensor).mean()
@@ -256,8 +189,6 @@ class WanSelfAttention(nn.Module):
         k = k.view(b, s, n, d)
         v = self.v(x).view(b, s, n, d)
         del x
-        # rope_apply_(q, grid_sizes, freqs)
-        # rope_apply_(k, grid_sizes, freqs)
         qklist = [q,k]
         del q,k
         q,k = apply_rotary_emb(qklist, freqs, head_first=False)
@@ -568,9 +499,9 @@ class Head(nn.Module):
             e(Tensor): Shape [B, C]
         """
         # assert e.dtype == torch.float32
+        dtype = x.dtype
         e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
-        x = self.norm(x).to(torch.bfloat16)
+        x = self.norm(x).to(dtype)
         x *= (1 + e[1])
         x += e[0]
         x = self.head(x)
@@ -857,7 +788,7 @@ class WanModel(ModelMixin, ConfigMixin):
         # time embeddings
         e = self.time_embedding(
             sinusoidal_embedding_1d(self.freq_dim, t))
-        e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(torch.bfloat16)
+        e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
 
         # context
         context_lens = None
@@ -51,10 +51,11 @@ class RMS_norm(nn.Module):
         self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
 
     def forward(self, x):
+        dtype = x.dtype
         x = F.normalize(
             x, dim=(1 if self.channel_first else
                     -1)) * self.scale * self.gamma + self.bias
-        x = x.to(torch.bfloat16)
+        x = x.to(dtype)
         return x
 
 class Upsample(nn.Upsample):
@@ -208,6 +209,7 @@ class ResidualBlock(nn.Module):
 
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         h = self.shortcut(x)
+        dtype = x.dtype
         for layer in self.residual:
             if isinstance(layer, CausalConv3d) and feat_cache is not None:
                 idx = feat_idx[0]
@@ -219,11 +221,11 @@
                         cache_x.device), cache_x
                 ],
                 dim=2)
-                x = layer(x, feat_cache[idx]).to(torch.bfloat16)
+                x = layer(x, feat_cache[idx]).to(dtype)
                 feat_cache[idx] = cache_x#.to("cpu")
                 feat_idx[0] += 1
             else:
-                x = layer(x).to(torch.bfloat16)
+                x = layer(x).to(dtype)
         return x + h
 
 
@@ -323,6 +325,7 @@ class Encoder3d(nn.Module):
             CausalConv3d(out_dim, z_dim, 3, padding=1))
 
     def forward(self, x, feat_cache=None, feat_idx=[0]):
+        dtype = x.dtype
         if feat_cache is not None:
             idx = feat_idx[0]
             cache_x = x[:, :, -CACHE_T:, :, :].clone()
@@ -333,7 +336,7 @@
                         cache_x.device), cache_x
                 ],
                 dim=2)
-            x = self.conv1(x, feat_cache[idx]).to(torch.bfloat16)
+            x = self.conv1(x, feat_cache[idx]).to(dtype)
             feat_cache[idx] = cache_x
             del cache_x
             feat_idx[0] += 1
|||||||
@ -47,14 +47,15 @@ class WanT2V:
|
|||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
checkpoint_dir,
|
checkpoint_dir,
|
||||||
device_id=0,
|
|
||||||
rank=0,
|
rank=0,
|
||||||
t5_fsdp=False,
|
t5_fsdp=False,
|
||||||
dit_fsdp=False,
|
dit_fsdp=False,
|
||||||
use_usp=False,
|
use_usp=False,
|
||||||
t5_cpu=False,
|
t5_cpu=False,
|
||||||
model_filename = None,
|
model_filename = None,
|
||||||
text_encoder_filename = None
|
text_encoder_filename = None,
|
||||||
|
quantizeTransformer = False,
|
||||||
|
dtype = torch.bfloat16
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Initializes the Wan text-to-video generation model components.
|
Initializes the Wan text-to-video generation model components.
|
||||||
@ -77,25 +78,24 @@ class WanT2V:
|
|||||||
t5_cpu (`bool`, *optional*, defaults to False):
|
t5_cpu (`bool`, *optional*, defaults to False):
|
||||||
Whether to place T5 model on CPU. Only works without t5_fsdp.
|
Whether to place T5 model on CPU. Only works without t5_fsdp.
|
||||||
"""
|
"""
|
||||||
self.device = torch.device(f"cuda:{device_id}")
|
self.device = torch.device(f"cuda")
|
||||||
self.config = config
|
self.config = config
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.t5_cpu = t5_cpu
|
self.t5_cpu = t5_cpu
|
||||||
|
self.dtype = dtype
|
||||||
self.num_train_timesteps = config.num_train_timesteps
|
self.num_train_timesteps = config.num_train_timesteps
|
||||||
self.param_dtype = config.param_dtype
|
self.param_dtype = config.param_dtype
|
||||||
|
|
||||||
shard_fn = partial(shard_model, device_id=device_id)
|
|
||||||
self.text_encoder = T5EncoderModel(
|
self.text_encoder = T5EncoderModel(
|
||||||
text_len=config.text_len,
|
text_len=config.text_len,
|
||||||
dtype=config.t5_dtype,
|
dtype=config.t5_dtype,
|
||||||
device=torch.device('cpu'),
|
device=torch.device('cpu'),
|
||||||
checkpoint_path=text_encoder_filename,
|
checkpoint_path=text_encoder_filename,
|
||||||
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
||||||
shard_fn=shard_fn if t5_fsdp else None)
|
shard_fn= None)
|
||||||
|
|
||||||
self.vae_stride = config.vae_stride
|
self.vae_stride = config.vae_stride
|
||||||
self.patch_size = config.patch_size
|
self.patch_size = config.patch_size
|
||||||
|
|
||||||
|
|
||||||
self.vae = WanVAE(
|
self.vae = WanVAE(
|
||||||
@ -105,31 +105,14 @@ class WanT2V:
|
|||||||
logging.info(f"Creating WanModel from {model_filename}")
|
logging.info(f"Creating WanModel from {model_filename}")
|
||||||
from mmgp import offload
|
from mmgp import offload
|
||||||
|
|
||||||
|
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
|
||||||
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False)
|
if self.dtype == torch.float16 and not "fp16" in model_filename:
|
||||||
|
self.model.to(self.dtype)
|
||||||
|
# offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True)
|
||||||
|
if self.dtype == torch.float16:
|
||||||
|
self.vae.model.to(self.dtype)
|
||||||
self.model.eval().requires_grad_(False)
|
self.model.eval().requires_grad_(False)
|
||||||
|
|
||||||
if use_usp:
|
|
||||||
from xfuser.core.distributed import \
|
|
||||||
get_sequence_parallel_world_size
|
|
||||||
|
|
||||||
from .distributed.xdit_context_parallel import (usp_attn_forward,
|
|
||||||
usp_dit_forward)
|
|
||||||
for block in self.model.blocks:
|
|
||||||
block.self_attn.forward = types.MethodType(
|
|
||||||
usp_attn_forward, block.self_attn)
|
|
||||||
self.model.forward = types.MethodType(usp_dit_forward, self.model)
|
|
||||||
self.sp_size = get_sequence_parallel_world_size()
|
|
||||||
else:
|
|
||||||
self.sp_size = 1
|
|
||||||
|
|
||||||
# if dist.is_initialized():
|
|
||||||
# dist.barrier()
|
|
||||||
# if dit_fsdp:
|
|
||||||
# self.model = shard_fn(self.model)
|
|
||||||
# else:
|
|
||||||
# self.model.to(self.device)
|
|
||||||
|
|
||||||
self.sample_neg_prompt = config.sample_neg_prompt
|
self.sample_neg_prompt = config.sample_neg_prompt
|
||||||
|
|
||||||
@ -389,8 +372,10 @@ class WanT2V:
|
|||||||
|
|
||||||
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
|
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
|
||||||
(self.patch_size[1] * self.patch_size[2]) *
|
(self.patch_size[1] * self.patch_size[2]) *
|
||||||
target_shape[1] / self.sp_size) * self.sp_size
|
target_shape[1])
|
||||||
|
|
||||||
|
context = [u.to(self.dtype) for u in context]
|
||||||
|
context_null = [u.to(self.dtype) for u in context_null]
|
||||||
|
|
||||||
noise = [
|
noise = [
|
||||||
torch.randn(
|
torch.randn(
|
||||||
|
|||||||
wgp.py (227 changed lines)
@@ -20,7 +20,7 @@ from wan.modules.attention import get_attention_modes, get_supported_attention_m
 import torch
 import gc
 import traceback
 import math
 import typing
 import asyncio
 import inspect
@@ -32,6 +32,8 @@ import zipfile
 import tempfile
 import atexit
 import shutil
+import glob
+
 global_queue_ref = []
 AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10
@@ -203,6 +205,7 @@ def process_prompt_and_add_tasks(state, model_choice):
         if isinstance(image_refs, list):
             image_refs = [ convert_image(tup[0]) for tup in image_refs ]
 
+            os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")
             from wan.utils.utils import resize_and_remove_background
             image_refs = resize_and_remove_background(image_refs, width, height, inputs["remove_background_image_ref"] ==1)
 
@@ -921,7 +924,7 @@ def autoload_queue(state):
        update_global_queue_ref(original_queue)
        dataframe_update = update_queue_data(original_queue)
    else:
-        print(f"Autoload skipped: {AUTOSAVE_FILENAME} not found.")
+        # print(f"Autoload skipped: {AUTOSAVE_FILENAME} not found.")
        update_global_queue_ref([])
        dataframe_update = update_queue_data([])
 
@@ -1090,19 +1093,13 @@ def _parse_args():
        help="Lora preset to preload"
    )
 
-    # parser.add_argument(
-    #     "--i2v-settings",
-    #     type=str,
-    #     default="i2v_settings.json",
-    #     help="Path to settings file for i2v"
-    # )
+    parser.add_argument(
+        "--settings",
+        type=str,
+        default="settings",
+        help="Path to settings folder"
+    )
 
-    # parser.add_argument(
-    #     "--t2v-settings",
-    #     type=str,
-    #     default="t2v_settings.json",
-    #     help="Path to settings file for t2v"
-    # )
-
    # parser.add_argument(
    #     "--lora-preset-i2v",
@@ -1152,6 +1149,12 @@ def _parse_args():
        help="Access advanced options by default"
    )
 
+    parser.add_argument(
+        "--fp16",
+        action="store_true",
+        help="For using fp16 transformer model"
+    )
+
    parser.add_argument(
        "--server-port",
        type=str,
@@ -1159,6 +1162,22 @@
        help="Server port"
    )
 
+    parser.add_argument(
+        "--theme",
+        type=str,
+        default="",
+        help="set UI Theme"
+    )
+
+    parser.add_argument(
+        "--perc-reserved-mem-max",
+        type=float,
+        default=0,
+        help="% of RAM allocated to Reserved RAM"
+    )
+
+
    parser.add_argument(
        "--server-name",
        type=str,
@@ -1307,6 +1326,12 @@ transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "
 transformer_choices = transformer_choices_t2v + transformer_choices_i2v
 text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
 server_config_filename = "wgp_config.json"
+if not os.path.isdir("settings"):
+    os.mkdir("settings")
+if os.path.isfile("t2v_settings.json"):
+    for f in glob.glob(os.path.join(".", "*_settings.json*")):
+        target_file = os.path.join("settings", Path(f).parts[-1] )
+        shutil.move(f, target_file)
 
 if not os.path.isfile(server_config_filename) and os.path.isfile("gradio_config.json"):
     shutil.move("gradio_config.json", server_config_filename)
@@ -1321,10 +1346,11 @@ if not Path(server_config_filename).is_file():
        "metadata_type": "metadata",
        "default_ui": "t2v",
        "boost" : 1,
-        "clear_file_list" : 0,
+        "clear_file_list" : 5,
        "vae_config": 0,
        "profile" : profile_type.LowRAM_LowVRAM,
-        "preload_model_policy": [] }
+        "preload_model_policy": [],
+        "UI_theme": "default" }
 
    with open(server_config_filename, "w", encoding="utf-8") as writer:
        writer.write(json.dumps(server_config))
@@ -1380,7 +1406,7 @@ def get_model_filename(model_type, quantization):
    return choices[0]
 
def get_settings_file_name(model_filename):
-    return get_model_type(model_filename) + "_settings.json"
+    return os.path.join(args.settings, get_model_type(model_filename) + "_settings.json")
 
def get_default_settings(filename):
    def get_default_prompt(i2v):
@@ -1388,11 +1414,11 @@ def get_default_settings(filename):
            return "Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field."
        else:
            return "A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect."
-    i2v = "image2video" in file_name
+    i2v = "image2video" in filename or "Fun_InP" in filename
    defaults_filename = get_settings_file_name(filename)
    if not Path(defaults_filename).is_file():
        ui_defaults = {
-            "prompts": get_default_prompt(i2v),
+            "prompt": get_default_prompt(i2v),
            "resolution": "832x480",
            "video_length": 81,
            "num_inference_steps": 30,
@@ -1651,7 +1677,6 @@ def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset,
 
 
    if lora_dir != None:
-        import glob
        dir_loras = glob.glob( os.path.join(lora_dir , "*.sft") ) + glob.glob( os.path.join(lora_dir , "*.safetensors") )
        dir_loras.sort()
        loras += [element for element in dir_loras if element not in loras ]
@@ -1676,7 +1701,7 @@ def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset,
    return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
 
 
-def load_t2v_model(model_filename, value):
+def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16):
 
    cfg = WAN_CONFIGS['t2v-14B']
    # cfg = WAN_CONFIGS['t2v-1.3B']
@@ -1685,20 +1710,21 @@ def load_t2v_model(model_filename, value):
    wan_model = wan.WanT2V(
        config=cfg,
        checkpoint_dir="ckpts",
-        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        model_filename=model_filename,
-        text_encoder_filename= text_encoder_filename
+        text_encoder_filename= text_encoder_filename,
+        quantizeTransformer = quantizeTransformer,
+        dtype = dtype
    )
 
    pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model }
 
    return wan_model, pipe
 
-def load_i2v_model(model_filename, value):
+def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16):
 
    print(f"Loading '{model_filename}' model...")
 
@@ -1707,14 +1733,15 @@ def load_i2v_model(model_filename, value):
        wan_model = wan.WanI2V(
            config=cfg,
            checkpoint_dir="ckpts",
-            device_id=0,
            rank=0,
            t5_fsdp=False,
            dit_fsdp=False,
            use_usp=False,
            i2v720p= True,
            model_filename=model_filename,
-            text_encoder_filename=text_encoder_filename
+            text_encoder_filename=text_encoder_filename,
+            quantizeTransformer = quantizeTransformer,
+            dtype = dtype
        )
        pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } #
 
@@ -1723,15 +1750,15 @@ def load_i2v_model(model_filename, value):
        wan_model = wan.WanI2V(
            config=cfg,
            checkpoint_dir="ckpts",
-            device_id=0,
            rank=0,
            t5_fsdp=False,
            dit_fsdp=False,
            use_usp=False,
            i2v720p= False,
            model_filename=model_filename,
-            text_encoder_filename=text_encoder_filename
+            text_encoder_filename=text_encoder_filename,
+            quantizeTransformer = quantizeTransformer,
+            dtype = dtype
        )
        pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } #
    else:
@@ -1744,12 +1771,20 @@ def load_models(model_filename):
    global transformer_filename
 
    transformer_filename = model_filename
+    perc_reserved_mem_max = args.perc_reserved_mem_max
+
+    major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
+    default_dtype = torch.float16 if major < 8 else torch.bfloat16
+    if default_dtype == torch.float16 or args.fp16:
+        print("Switching to f16 model as GPU architecture doesn't support bf16")
+        if "quanto" in model_filename:
+            model_filename = model_filename.replace("quanto_int8", "quanto_fp16_int8")
    download_models(model_filename, text_encoder_filename)
    if test_class_i2v(model_filename):
        res720P = "720p" in model_filename
-        wan_model, pipe = load_i2v_model(model_filename, "720P" if res720P else "480P")
+        wan_model, pipe = load_i2v_model(model_filename, "720P" if res720P else "480P", quantizeTransformer = quantizeTransformer, dtype = default_dtype )
    else:
-        wan_model, pipe = load_t2v_model(model_filename, "")
+        wan_model, pipe = load_t2v_model(model_filename, "", quantizeTransformer = quantizeTransformer, dtype = default_dtype)
    wan_model._model_file_name = model_filename
    kwargs = { "extraModelsToQuantize": None}
    if profile == 2 or profile == 4:
@@ -1758,7 +1793,7 @@ def load_models(model_filename):
        # kwargs["partialPinning"] = True
    elif profile == 3:
        kwargs["budgets"] = { "*" : "70%" }
-    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, **kwargs)
+    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = default_dtype, **kwargs)
    if len(args.gpu) > 0:
        torch.set_default_device(args.gpu)
 
@@ -1834,6 +1869,7 @@ def apply_changes( state,
                    boost_choice = 1,
                    clear_file_list = 0,
                    preload_model_policy_choice = 1,
+                    UI_theme_choice = "default"
):
    if args.lock_config:
        return
@@ -1852,6 +1888,7 @@ def apply_changes( state,
        "boost" : boost_choice,
        "clear_file_list" : clear_file_list,
        "preload_model_policy" : preload_model_policy_choice,
+        "UI_theme" : UI_theme_choice
    }
 
    if Path(server_config_filename).is_file():
@@ -1874,7 +1911,7 @@ def apply_changes( state,
        if v != v_old:
            changes.append(k)
 
-    global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, preload_model_policy, transformer_quantization, transformer_types
+    global attention_mode, profile, compile, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, preload_model_policy, transformer_quantization, transformer_types
    attention_mode = server_config["attention_mode"]
    profile = server_config["profile"]
    compile = server_config["compile"]
@@ -1884,10 +1921,13 @@ def apply_changes( state,
    preload_model_policy = server_config["preload_model_policy"]
    transformer_quantization = server_config["transformer_quantization"]
    transformer_types = server_config["transformer_types"]
-    transformer_type = get_model_type(transformer_filename)
-    if not transformer_type in transformer_types:
-        transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
-    transformer_filename = get_model_filename(transformer_type, transformer_quantization)
+    model_filename = state["model_filename"]
+    model_transformer_type = get_model_type(model_filename)
+    if not model_transformer_type in transformer_types:
+        model_transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
+    model_filename = get_model_filename(model_transformer_type, transformer_quantization)
+    state["model_filename"] = model_filename
    if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ):
        model_choice = gr.Dropdown()
    else:
@@ -1990,6 +2030,15 @@ def refresh_gallery(state, msg):
            start_img_md = ""
            end_img_md = ""
            prompt = task["prompt"]
+            params = task["params"]
+            if "\n" in prompt and params.get("sliding_window_repeat", 0) > 0:
+                prompts = prompt.split("\n")
+                repeat_no= gen.get("repeat_no",1)
+                if repeat_no > len(prompts):
+                    repeat_no = len(prompts)
+                repeat_no -= 1
+                prompts[repeat_no]="<B>" + prompts[repeat_no] + "</B>"
+                prompt = "<BR>".join(prompts)
 
            start_img_uri = task.get('start_image_data_base64')
            start_img_uri = start_img_uri[0] if start_img_uri !=None else None
@@ -2463,15 +2512,7 @@ def generate_video(
 
    try:
        start_time = time.time()
-        # with tracker_lock:
-        #     progress_tracker[task_id] = {
-        #         'current_step': 0,
-        #         'total_steps': num_inference_steps,
-        #         'start_time': start_time,
-        #         'last_update': start_time,
-        #         'repeats': repeat_generation, # f"{video_no}/{repeat_generation}",
-        #         'status': "Encoding Prompt"
-        #     }
        if trans.enable_teacache:
            trans.teacache_counter = 0
            trans.num_steps = num_inference_steps
@@ -2542,20 +2583,17 @@ def generate_video(
        gc.collect()
        torch.cuda.empty_cache()
        s = str(e)
-        keyword_list = ["vram", "VRAM", "memory","allocat"]
-        VRAM_crash= False
-        if any( keyword in s for keyword in keyword_list):
-            VRAM_crash = True
-        else:
-            stack = traceback.extract_stack(f=None, limit=5)
-            for frame in stack:
-                if any( keyword in frame.name for keyword in keyword_list):
-                    VRAM_crash = True
-                    break
+        keyword_list = {"CUDA out of memory" : "VRAM", "Tried to allocate":"VRAM", "CUDA error: out of memory": "RAM", "CUDA error: too many resources requested": "RAM"}
+        crash_type = ""
+        for keyword, tp in keyword_list.items():
+            if keyword in s:
+                crash_type = tp
+                break
 
        state["prompt"] = ""
-        if VRAM_crash:
+        if crash_type == "VRAM":
            new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames."
+        elif crash_type == "RAM":
+            new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient RAM and / or Reserved RAM allocation should be reduced using 'perc_reserved_mem_max' or using a different Profile."
        else:
            new_error = gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
        tb = traceback.format_exc().split('\n')[:-1]
@@ -2929,12 +2967,13 @@ def refresh_lora_list(state, lset_name, loras_choices):
        pos = len(loras_presets)
        lset_name =""
 
-    errors = getattr(wan_model.model, "_loras_errors", "")
-    if errors !=None and len(errors) > 0:
-        error_files = [path for path, _ in errors]
-        gr.Info("Error while refreshing Lora List, invalid Lora files: " + ", ".join(error_files))
-    else:
-        gr.Info("Lora List has been refreshed")
+    if wan_model != None:
+        errors = getattr(wan_model.model, "_loras_errors", "")
+        if errors !=None and len(errors) > 0:
+            error_files = [path for path, _ in errors]
+            gr.Info("Error while refreshing Lora List, invalid Lora files: " + ", ".join(error_files))
+        else:
+            gr.Info("Lora List has been refreshed")
 
 
    return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Dropdown(choices=new_loras_choices, value= lora_names_selected)
@@ -3210,7 +3249,7 @@ def save_inputs(
def download_loras():
    from huggingface_hub import snapshot_download
    yield gr.Row(visible=True), "<B><FONT SIZE=3>Please wait while the Loras are being downloaded</B></FONT>", *[gr.Column(visible=False)] * 2
-    lora_dir = get_lora_dir(get_model_filename("i2v"), quantizeTransformer)
+    lora_dir = get_lora_dir(get_model_filename("i2v", transformer_quantization))
    log_path = os.path.join(lora_dir, "log.txt")
    if not os.path.isfile(log_path):
        tmp_path = os.path.join(lora_dir, "tmp_lora_dowload")
@@ -4047,7 +4086,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
        outputs=[modal_container]
    )
 
-    return (
+    return ( state,
        loras_choices, lset_name, state, queue_df, current_gen_column,
        gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
        gen_info, queue_accordion, video_guide, video_mask, video_prompt_video_guide_trigger
@@ -4068,9 +4107,7 @@ def generate_download_tab(lset_name,loras_choices, state):
    download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
 
 
-def generate_configuration_tab(header, model_choice):
-    state_dict = {}
-    state = gr.State(state_dict)
+def generate_configuration_tab(state, blocks, header, model_choice):
    gr.Markdown("Please click Apply Changes at the bottom so that the changes are effective. Some choices below may be locked if the app has been launched by specifying a config preset.")
    with gr.Column():
        model_list = []
@@ -4090,7 +4127,7 @@ def generate_configuration_tab(header, model_choice):
        quantization_choice = gr.Dropdown(
            choices=[
                ("Int8 Quantization (recommended)", "int8"),
-                ("BF16 (no quantization)", "bf16"),
+                ("16 bits (no quantization)", "bf16"),
            ],
            value= transformer_quantization,
            label="Wan Transformer Model Quantization Type (if available)",
@@ -4122,7 +4159,7 @@ def generate_configuration_tab(header, model_choice):
                ("Auto : pick sage2 > sage > sdpa depending on what is installed", "auto"),
                ("Scale Dot Product Attention: default, always available", "sdpa"),
                ("Flash" + check("flash")+ ": good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"),
-                # ("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
+                ("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
                ("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
                ("Sage2" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"),
            ],
@@ -4201,10 +4238,19 @@ def generate_configuration_tab(header, model_choice):
                ("Keep the last 20 videos", 20),
                ("Keep the last 30 videos", 30),
            ],
-            value=server_config.get("clear_file_list", 0),
+            value=server_config.get("clear_file_list", 5),
            label="Keep Previously Generated Videos when starting a Generation Batch"
        )
 
+        UI_theme_choice = gr.Dropdown(
+            choices=[
+                ("Blue Sky", "default"),
+                ("Classic Gradio", "gradio"),
+            ],
+            value=server_config.get("UI_theme_choice", "default"),
+            label="User Interface Theme. You will need to restart the App the see new Theme."
+        )
+
        msg = gr.Markdown()
        apply_btn = gr.Button("Apply Changes")
@@ -4224,6 +4270,7 @@ def generate_configuration_tab(header, model_choice):
                boost_choice,
                clear_file_list_choice,
                preload_model_policy_choice,
+                UI_theme_choice
            ],
            outputs= [msg , header, model_choice]
        )
@@ -4286,20 +4333,15 @@ def select_tab(tab_state, evt:gr.SelectData):
    elif new_tab_no == tab_video_mask_creator:
        if gen_in_progress:
            gr.Info("Unable to access this Tab while a Generation is in Progress. Please come back later")
-            tab_state["tab_auto"]=old_tab_no
+            tab_state["tab_no"] = 0
+            return gr.Tabs(selected="video_gen")
        else:
            vmc_event_handler(True)
    tab_state["tab_no"] = new_tab_no
-def select_tab_auto(tab_state):
-    old_tab_no = tab_state.pop("tab_auto", -1)
-    if old_tab_no>= 0:
-        tab_state["tab_auto"]=old_tab_no
-        return gr.Tabs(selected=old_tab_no) # !! doesnt work !!
-    return gr.Tab()
+    return gr.Tabs()
 
 
def create_demo():
    global vmc_event_handler
    css = """
    #model_list{
        background-color:black;
@@ -4532,14 +4574,21 @@ def create_demo():
        pointer-events: none;
    }
    """
-    with gr.Blocks(css=css, theme=gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
+    UI_theme = server_config.get("UI_theme", "default")
+    UI_theme = args.theme if len(args.theme) > 0 else UI_theme
+    if UI_theme == "gradio":
+        theme = None
+    else:
+        theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
+
+    with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as demo:
        gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
        global model_list
 
        tab_state = gr.State({ "tab_no":0 })
 
        with gr.Tabs(selected="video_gen", ) as main_tabs:
-            with gr.Tab("Video Generator", id="video_gen") as t2v_tab:
+            with gr.Tab("Video Generator", id="video_gen"):
                with gr.Row():
                    if args.lock_model:
                        gr.Markdown("<div class='title-with-lines'><div class=line></div><h2>" + get_model_name(transformer_filename) + "</h2><div class=line></div>")
@@ -4551,23 +4600,23 @@ def create_demo():
                with gr.Row():
                    header = gr.Markdown(generate_header(transformer_filename, compile, attention_mode), visible= True)
                with gr.Row():
-                    (
+                    ( state,
                        loras_choices, lset_name, state, queue_df, current_gen_column,
                        gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
                        gen_info, queue_accordion, video_guide, video_mask, video_prompt_type_video_trigger
                    ) = generate_video_tab(model_choice=model_choice, header=header)
-            with gr.Tab("Informations"):
+            with gr.Tab("Informations", id="info"):
                generate_info_tab()
            with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator:
                from preprocessing.matanyone import app as matanyone_app
                vmc_event_handler = matanyone_app.get_vmc_event_handler()
 
-                matanyone_app.display(video_guide, video_mask, video_prompt_type_video_trigger)
+                matanyone_app.display(main_tabs, model_choice, video_guide, video_mask, video_prompt_type_video_trigger)
            if not args.lock_config:
                with gr.Tab("Downloads", id="downloads") as downloads_tab:
                    generate_download_tab(lset_name, loras_choices, state)
-            with gr.Tab("Configuration"):
-                generate_configuration_tab(header, model_choice)
+            with gr.Tab("Configuration", id="configuration"):
+                generate_configuration_tab(state, demo, header, model_choice)
            with gr.Tab("About"):
                generate_about_tab()
 
@@ -4589,7 +4638,7 @@ def create_demo():
            trigger_mode="always_last"
        )
 
-        main_tabs.select(fn=select_tab, inputs= [tab_state], outputs= None).then(fn=select_tab_auto, inputs= [tab_state], outputs=[main_tabs])
+        main_tabs.select(fn=select_tab, inputs= [tab_state], outputs= main_tabs)
        return demo
 
if __name__ == "__main__":