mirror of https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-05 14:33:15 +00:00

merge with latest branch

This commit is contained in:
parent c7a8c82398
commit 7e7c9c5593
README.md
@@ -19,11 +19,13 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
 ## 🔥 Latest News!!
 
+* Mar 19 2025: 👋 Wan2.1GP v3.1: Faster launch and RAM optimizations (should require less RAM to run)\
+You will need one more *pip install -r requirements.txt*
 * Mar 18 2025: 👋 Wan2.1GP v3.0:
 - New Tab based interface, you can switch from i2v to t2v and conversely without restarting the app
 - Experimental Dual Frames mode for i2v, you can also specify an End frame. It doesn't always work, so you will need a few attempts.
 - You can save default settings in the files *i2v_settings.json* and *t2v_settings.json* that will be used when launching the app (you can also specify the path to different settings files)
-- Slight acceleration with loras
+- Slight acceleration with loras\
 You will need one more *pip install -r requirements.txt*
 Many thanks to *Tophness* who created the framework (and did a big part of the work) of the multitabs and saved settings features
 * Mar 18 2025: 👋 Wan2.1GP v2.11: Added more command line parameters to prefill the generation settings + customizable output directory and choice of type of metadata for generated videos. Many thanks to *Tophness* for his contributions. You will need one more *pip install -r requirements.txt* to reflect new dependencies\
@@ -255,8 +257,9 @@ You can define multiple lines of macros. If there is only one macro line, the ap
 --check-loras : filter loras that are incompatible (will take a few seconds while refreshing the lora list or while starting the app)\
 --advanced : turn on the advanced mode while launching the app\
 --i2v-settings : path to launch settings for i2v\
---t2v-settings : path to launch settings for t2v
---listen : make server accessible on network
+--t2v-settings : path to launch settings for t2v\
+--listen : make server accessible on network\
+--gpu device : run Wan on device for instance "cuda:1"
 
 ### Profiles (for power users only)
 You can choose between 5 profiles, but two are really relevant here :
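For reference, the flags documented above can be combined in a single launch command. A hypothetical invocation (the script name gradio_server.py and all flag names come from this commit; the values are illustrative):

    python gradio_server.py --listen --gpu "cuda:1" --i2v-settings i2v_settings.json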
gradio_server.py (110 changed lines)
@@ -15,7 +15,7 @@ import json
 import wan
 from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES
 from wan.utils.utils import cache_video
-from wan.modules.attention import get_attention_modes
+from wan.modules.attention import get_attention_modes, get_supported_attention_modes
 import torch
 import gc
 import traceback
@@ -24,7 +24,7 @@ import asyncio
 from wan.utils import prompt_parser
 PROMPT_VARS_MAX = 10
 
-target_mmgp_version = "3.3.1"
+target_mmgp_version = "3.3.3"
 from importlib.metadata import version
 mmgp_version = version("mmgp")
 if mmgp_version != target_mmgp_version:
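The bumped pin above feeds a startup check; a minimal sketch of that pattern follows. Only the importlib.metadata lookup and the version comparison are taken from the diff, the message and exit behaviour are assumptions (and version() raises PackageNotFoundError if mmgp is not installed at all):

    # Illustrative sketch of the version-pin check around target_mmgp_version.
    from importlib.metadata import version

    target_mmgp_version = "3.3.3"
    if version("mmgp") != target_mmgp_version:
        print(f"mmgp {target_mmgp_version} is required; please run: pip install -r requirements.txt")
        raise SystemExit(1)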
@@ -55,16 +55,16 @@ def runner():
     while True:
         with lock:
             for item in queue:
-                task_id = item['id']
+                task_id_runner = item['id']
                 with tracker_lock:
-                    progress = progress_tracker.get(task_id, {})
+                    progress = progress_tracker.get(task_id_runner, {})
 
                 if item['state'] == "Processing":
                     current_step = progress.get('current_step', 0)
                     total_steps = progress.get('total_steps', 0)
                     elapsed = time.time() - progress.get('start_time', time.time())
                     status = progress.get('status', "")
-                    repeats = progress.get("repeats")
+                    repeats = progress.get("repeats", "0/0")
                     item.update({
                         'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
                         'steps': f"{current_step}/{total_steps}",
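The runner loop above polls a progress_tracker dict that the generation callback fills in under tracker_lock. A self-contained sketch of that producer/consumer pattern, with key names mirroring the diff and everything else (thread setup, timings) purely illustrative:

    # Illustrative sketch of the tracker pattern used by runner(): one thread
    # updates progress_tracker under tracker_lock, another formats it for display.
    import threading, time

    tracker_lock = threading.Lock()
    progress_tracker = {}

    def worker(task_id, total_steps=5):
        for step in range(total_steps):
            time.sleep(0.1)
            with tracker_lock:
                progress_tracker[task_id] = {"current_step": step + 1,
                                             "total_steps": total_steps,
                                             "repeats": "1/1",
                                             "status": "Denoising"}

    t = threading.Thread(target=worker, args=("demo",))
    t.start()
    while t.is_alive():
        with tracker_lock:
            p = progress_tracker.get("demo", {})
        steps = f'{p.get("current_step", 0)}/{p.get("total_steps", 0)}'
        print(steps, p.get("repeats", "0/0"), p.get("status", ""))
        time.sleep(0.1)
    t.join()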
@@ -381,6 +381,13 @@ def _parse_args():
         help="Server name"
     )
 
+    parser.add_argument(
+        "--gpu",
+        type=str,
+        default="",
+        help="Default GPU Device"
+    )
+
     parser.add_argument(
         "--open-browser",
         action="store_true",
@@ -473,7 +480,8 @@ def get_lora_dir(i2v):
         return lora_dir_14B
     return root_lora_dir
 
-attention_modes_supported = get_attention_modes()
+attention_modes_installed = get_attention_modes()
+attention_modes_supported = get_supported_attention_modes()
 
 args = _parse_args()
 args.flow_reverse = True
@@ -587,6 +595,7 @@ vae_config = server_config.get("vae_config", 0)
 if len(args.vae_config) > 0:
     vae_config = int(args.vae_config)
 
+reload_needed = False
 default_ui = server_config.get("default_ui", "t2v")
 metadata = server_config.get("metadata_type", "metadata")
 save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
@@ -686,7 +695,7 @@ def download_models(transformer_filename, text_encoder_filename):
     from huggingface_hub import hf_hub_download, snapshot_download
     repoId = "DeepBeepMeep/Wan2.1"
     sourceFolderList = ["xlm-roberta-large", "", ]
-    fileList = [ [], ["Wan2.1_VAE.pth", "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
+    fileList = [ [], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
     targetRoot = "ckpts/"
     for sourceFolder, files in zip(sourceFolderList,fileList ):
         if len(files)==0:
@@ -703,6 +712,14 @@ def download_models(transformer_filename, text_encoder_filename):
 
 
 offload.default_verboseLevel = verbose_level
+to_remove = ["models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", "Wan2.1_VAE.pth"]
+for file_name in to_remove:
+    file_name = os.path.join("ckpts",file_name)
+    if os.path.isfile(file_name):
+        try:
+            os.remove(file_name)
+        except:
+            pass
 
 download_models(transformer_filename_i2v if use_image2video else transformer_filename_t2v, text_encoder_filename)
 
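The checkpoint list above points at the "DeepBeepMeep/Wan2.1" Hugging Face repo and the ckpts/ target directory, while the stale .pth files are deleted in favour of the new bf16 safetensors. The body of download_models() is not shown in this diff; a minimal sketch of how such a per-file download loop could look (the skip-if-present check is an assumption):

    # Illustrative only: fetch each listed file into ckpts/, skipping files
    # that already exist. Not the actual download_models() implementation.
    import os
    from huggingface_hub import hf_hub_download

    repoId = "DeepBeepMeep/Wan2.1"
    targetRoot = "ckpts/"
    files = ["Wan2.1_VAE_bf16.safetensors",
             "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors"]

    for name in files:
        if not os.path.isfile(os.path.join(targetRoot, name)):
            hf_hub_download(repo_id=repoId, filename=name, local_dir=targetRoot)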
@@ -875,6 +892,8 @@ def load_models(i2v):
     elif profile == 3:
         kwargs["budgets"] = { "*" : "70%" }
     offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", **kwargs)
+    if len(args.gpu) > 0:
+        torch.set_default_device(args.gpu)
 
     return wan_model, offloadobj, pipe["transformer"]
 
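The new --gpu flag is applied by switching PyTorch's default device after the pipeline is profiled. A standalone illustration of that behaviour, using the same argparse definition as the diff (the demo printout and the chosen device string are illustrative, and the device must actually exist):

    # Standalone illustration of the --gpu flag: when set, newly created
    # tensors and modules default to that device.
    import argparse
    import torch

    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", type=str, default="", help="Default GPU Device")
    args = parser.parse_args()          # e.g. python demo.py --gpu "cuda:1"

    if len(args.gpu) > 0:
        torch.set_default_device(args.gpu)

    print(torch.zeros(1).device)        # cuda:1 when launched with --gpu "cuda:1"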
@@ -914,8 +933,10 @@ def generate_header(model_filename, compile, attention_mode):
 
     header += model_name
     header += " (attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
-    if attention_mode not in attention_modes_supported:
+    if attention_mode not in attention_modes_installed:
         header += " -NOT INSTALLED-"
+    elif attention_mode not in attention_modes_supported:
+        header += " -NOT SUPPORTED-"
 
     if compile:
         header += ", pytorch compilation ON"
@@ -979,11 +1000,7 @@ def apply_changes( state,
         if v != v_old:
             changes.append(k)
 
-    state["config_changes"] = changes
-    state["config_new"] = server_config
-    state["config_old"] = old_server_config
-
-    global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost, lora_dir
+    global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost, lora_dir, reload_needed
     attention_mode = server_config["attention_mode"]
     profile = server_config["profile"]
     compile = server_config["compile"]
@@ -995,7 +1012,7 @@ def apply_changes( state,
     if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice"] for change in changes ):
         pass
     else:
-        state["_reload_needed"] = True
+        reload_needed = True
 
 
     yield "<DIV ALIGN=CENTER>The new configuration has been succesfully applied</DIV>"
@@ -1013,7 +1030,7 @@ def save_video(final_frames, output_path, fps=24):
 
 def build_callback(taskid, state, pipe, num_inference_steps, repeats):
     start_time = time.time()
-    def update_progress(step_idx, latents, read_state = False):
+    def update_progress(step_idx, _):
         with tracker_lock:
             step_idx += 1
             if state.get("abort", False):
@@ -1094,8 +1111,7 @@ def generate_video(
     progress=gr.Progress() #track_tqdm= True
 
 ):
-    global wan_model, offloadobj, last_model_type
-    reload_needed = state.get("_reload_needed", False)
+    global wan_model, offloadobj, reload_needed, last_model_type
     file_model_needed = model_needed(image2video)
     with lock:
         queue_not_empty = len(queue) > 0
@@ -1108,7 +1124,7 @@ def generate_video(
         print(f"Loading model {get_model_name(file_model_needed)}...")
         wan_model, offloadobj, trans = load_models(image2video)
         print(f"Model loaded")
-        state["_reload_needed"] = False
+        reload_needed= False
 
     from PIL import Image
     import numpy as np
@@ -1121,11 +1137,12 @@ def generate_video(
     elif attention_mode in attention_modes_supported:
         attn = attention_mode
     else:
-        gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed on your system. You should either install it or switch to the default 'sdpa' attention.")
+        gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
         return
 
     #if state.get("validate_success",0) != 1:
     #    return
+    raw_resolution = resolution
     width, height = resolution.split("x")
     width, height = int(width), int(height)
 
@@ -1289,7 +1306,7 @@ def generate_video(
                 'total_steps': num_inference_steps,
                 'start_time': time.time(),
                 'last_update': time.time(),
-                'repeats': f"0/{repeat_generation}",
+                'repeats': f"{video_no}/{repeat_generation}",
                 'status': "Encoding Prompt"
             }
             video_no += 1
@@ -1401,14 +1418,8 @@ def generate_video(
                     normalize=True,
                     value_range=(-1, 1))
 
-            configs = {
-                'prompt': prompt,
-                'negative_prompt': negative_prompt,
-                'resolution': resolution,
-                'video_length': video_length,
-                'seed': seed,
-                'num_inference_steps': num_inference_steps,
-            }
+            configs = get_settings_dict(state, use_image2video, prompt, 0 if image_to_end == None else 1 , video_length, raw_resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
+                loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end)
 
             metadata_choice = server_config.get("metadata_choice","metadata")
             if metadata_choice == "json":
@@ -1715,19 +1726,14 @@ def switch_advanced(state, new_advanced, lset_name):
     else:
         return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
 
-def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
+def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
                   loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
 
-    if state.get("validate_success",0) != 1:
-        return
-
-    loras_choices
     loras = state["loras"]
     activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ]
 
-    ui_defaults = {
+    ui_settings = {
         "prompts": prompt,
-        "image_prompt_type": image_prompt_type,
         "resolution": resolution,
         "video_length": video_length,
         "num_inference_steps": num_inference_steps,
@@ -1747,10 +1753,25 @@ def save_settings(state, prompt, image_prompt_type, video_length, resolution, nu
         "slg_start_perc": slg_start_perc,
         "slg_end_perc": slg_end_perc
     }
 
+    if i2v:
+        ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - image2video"
+        ui_settings["image_prompt_type"] = image_prompt_type
+    else:
+        ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video"
+
+    return ui_settings
+
+def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
+                  loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
+    if state.get("validate_success",0) != 1:
+        return
+    ui_defaults = get_settings_dict(state, use_image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
+                    loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc)
 
     defaults_filename = get_settings_file_name(use_image2video)
 
     with open(defaults_filename, "w", encoding="utf-8") as f:
-        json.dump(ui_defaults, f, indent=4)
+        json.dump(ui_settings , f, indent=4)
 
     gr.Info("New Default Settings saved")
 
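The settings dict built by get_settings_dict() ends up as a plain JSON file (i2v_settings.json / t2v_settings.json per the README). Purely for illustration, reading those defaults back at launch might look like the following; the file name and keys come from this commit, the fallback values and error handling are assumptions:

    # Illustrative only: reading saved defaults back at launch.
    import json, os

    settings_file = "i2v_settings.json"
    ui_defaults = {}
    if os.path.isfile(settings_file):
        with open(settings_file, "r", encoding="utf-8") as f:
            ui_defaults = json.load(f)

    image_prompt_type = ui_defaults.get("image_prompt_type", 0)
    resolution = ui_defaults.get("resolution", "832x480")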
@@ -1864,7 +1885,8 @@ def generate_video_tab(image2video=False):
                 cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
             video_to_continue = gr.Video(label= "Video to continue", visible= image2video and False) #######
 
-            image_prompt_type = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =ui_defaults["image_prompt_type"], label="Location", show_label= False, scale= 3, visible=image2video)
+            image_prompt_type= ui_defaults.get("image_prompt_type",0)
+            image_prompt_type_radio = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =image_prompt_type, label="Location", show_label= False, scale= 3, visible=image2video)
 
             if args.multiple_images:
                 image_to_continue = gr.Gallery(
@@ -1876,9 +1898,9 @@ def generate_video_tab(image2video=False):
             if args.multiple_images:
                 image_to_end = gr.Gallery(
                         label="Images as ending points for new videos", type ="pil", #file_types= "image",
-                        columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=False)
+                        columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image_prompt_type==1)
             else:
-                image_to_end = gr.Image(label= "Last Image for a new video", type ="pil", visible= False)
+                image_to_end = gr.Image(label= "Last Image for a new video", type ="pil", visible=image_prompt_type==1)
 
             def switch_image_prompt_type_radio(image_prompt_type_radio):
                 if args.multiple_images:
|
|||||||
else:
|
else:
|
||||||
return gr.Image(visible = (image_prompt_type_radio == 1) )
|
return gr.Image(visible = (image_prompt_type_radio == 1) )
|
||||||
|
|
||||||
image_prompt_type.change(fn=switch_image_prompt_type_radio, inputs=[image_prompt_type], outputs=[image_to_end])
|
image_prompt_type_radio.change(fn=switch_image_prompt_type_radio, inputs=[image_prompt_type_radio], outputs=[image_to_end])
|
||||||
|
|
||||||
|
|
||||||
advanced_prompt = advanced
|
advanced_prompt = advanced
|
||||||
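The rewiring above is the usual Gradio visibility-toggle pattern: the radio's .change() handler returns a component update that shows or hides the end-image input. A minimal standalone sketch of the same idea; the labels are taken from the diff, the demo wrapper itself is illustrative and not part of the commit:

    # Minimal illustration of the Radio -> visibility wiring used above.
    import gradio as gr

    with gr.Blocks() as demo:
        image_prompt_type_radio = gr.Radio(
            [("Use only a Start Image", 0), ("Use both a Start and an End Image", 1)],
            value=0, show_label=False)
        image_to_end = gr.Image(label="Last Image for a new video", type="pil", visible=False)

        def switch_image_prompt_type_radio(image_prompt_type_radio):
            return gr.Image(visible=(image_prompt_type_radio == 1))

        image_prompt_type_radio.change(fn=switch_image_prompt_type_radio,
                                       inputs=[image_prompt_type_radio],
                                       outputs=[image_to_end])

    # demo.launch()   # uncomment to try it locally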
@@ -2080,7 +2102,7 @@ def generate_video_tab(image2video=False):
                     outputs=[output]
                 )
                 save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
-                    save_settings, inputs = [state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
+                    save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
                     loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
                     slg_start_perc, slg_end_perc ], outputs = [])
                 save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
@@ -2182,8 +2204,10 @@ def generate_configuration_tab():
                 value=server_config.get("save_path", save_path)
             )
             def check(mode):
-                if not mode in attention_modes_supported:
+                if not mode in attention_modes_installed:
                     return " (NOT INSTALLED)"
+                elif not mode in attention_modes_supported:
+                    return " (NOT SUPPORTED)"
                 else:
                     return ""
             attention_choice = gr.Dropdown(
@@ -2435,7 +2459,7 @@ def create_demo():
     }
     """
     with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
-        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
+        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.1 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
         gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
 
         with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
requirements.txt
@@ -16,6 +16,6 @@ gradio>=5.0.0
 numpy>=1.23.5,<2
 einops
 moviepy==1.0.3
-mmgp==3.3.0
+mmgp==3.3.3
 peft==0.14.0
 mutagen
|
|||||||
logging.info(f"Creating WanModel from {model_filename}")
|
logging.info(f"Creating WanModel from {model_filename}")
|
||||||
from mmgp import offload
|
from mmgp import offload
|
||||||
|
|
||||||
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel)
|
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False)
|
||||||
self.model.eval().requires_grad_(False)
|
self.model.eval().requires_grad_(False)
|
||||||
|
|
||||||
if t5_fsdp or dit_fsdp or use_usp:
|
if t5_fsdp or dit_fsdp or use_usp:
|
||||||
|
|||||||
wan/modules/attention.py
@@ -30,6 +30,7 @@ try:
         max_seqlen_kv,
     ):
         return sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
+
 except ImportError:
     sageattn_varlen_wrapper = None
 
@@ -38,11 +39,12 @@ import warnings
 
 try:
     from sageattention import sageattn
-    from .sage2_core import sageattn as alt_sageattn
+    from .sage2_core import sageattn as alt_sageattn, is_sage_supported
+    sage_supported = is_sage_supported()
 except ImportError:
     sageattn = None
     alt_sageattn = None
+    sage_supported = False
 # @torch.compiler.disable()
 def sageattn_wrapper(
     qkv_list,
@@ -129,6 +131,14 @@ def get_attention_modes():
 
     return ret
 
+def get_supported_attention_modes():
+    ret = get_attention_modes()
+    if not sage_supported:
+        if "sage" in ret:
+            ret.remove("sage")
+        if "sage2" in ret:
+            ret.remove("sage2")
+    return ret
 
 __all__ = [
     'pay_attention',
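The new helper separates "installed" from "supported": Sage attention can be importable on a machine yet unusable on its GPU (is_sage_supported() below requires compute capability 8 or higher). A small sketch of how the two lists are meant to be consumed, mirroring the generate_header() and check() changes earlier in this commit; the loop and prints are illustrative:

    # Illustrative consumer of the two lists introduced by this commit:
    # "installed" = importable on this machine, "supported" = usable on this GPU.
    # Assumes it runs from the repo root so the wan package imports.
    from wan.modules.attention import get_attention_modes, get_supported_attention_modes

    attention_modes_installed = get_attention_modes()
    attention_modes_supported = get_supported_attention_modes()

    for mode in ["sdpa", "sage", "sage2"]:
        if mode not in attention_modes_installed:
            print(mode, "-NOT INSTALLED-")
        elif mode not in attention_modes_supported:
            print(mode, "-NOT SUPPORTED-")
        else:
            print(mode, "ok")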
wan/modules/clip.py
@@ -519,8 +519,11 @@ class CLIPModel:
             device=device)
         self.model = self.model.eval().requires_grad_(False)
         logging.info(f'loading {checkpoint_path}')
-        self.model.load_state_dict(
-            torch.load(checkpoint_path, map_location='cpu'), assign= True)
+        from mmgp import offload
+        # self.model.load_state_dict(
+        #     torch.load(checkpoint_path, map_location='cpu'), assign= True)
+
+        offload.load_model_data(self.model, checkpoint_path.replace(".pth", "-bf16.safetensors"), writable_tensors= False)
 
         # init tokenizer
         self.tokenizer = HuggingfaceTokenizer(
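Here (and in the t5.py and vae.py hunks below) the torch.load of a .pth checkpoint is replaced by mmgp's offload.load_model_data reading a bf16 .safetensors file with writable_tensors=False. As a rough plain-PyTorch analogue of what that swap amounts to, not the mmgp implementation, a hedged sketch:

    # Rough plain-PyTorch analogue of the offload.load_model_data() call above;
    # illustrative only, the helper name and behaviour are assumptions.
    import torch
    from safetensors.torch import load_file

    def load_safetensors_into(model: torch.nn.Module, checkpoint_path: str):
        # .pth checkpoints are replaced by bf16 .safetensors files in this commit
        state_dict = load_file(checkpoint_path.replace(".pth", "-bf16.safetensors"))
        # assign=True reuses the loaded tensors directly instead of copying into
        # existing parameters, which also works for models built on the meta device
        model.load_state_dict(state_dict, assign=True)
        return model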
wan/modules/sage2_core.py
@@ -51,6 +51,15 @@ from sageattention.quant import per_channel_fp8
 
 from typing import Any, List, Literal, Optional, Tuple, Union
 import warnings
+import os
+
+def is_sage_supported():
+    device_count = torch.cuda.device_count()
+    for i in range(device_count):
+        major, minor = torch.cuda.get_device_capability(i)
+        if major < 8:
+            return False
+    return True
 
 def get_cuda_arch_versions():
     cuda_archs = []
wan/modules/t5.py
@@ -496,7 +496,7 @@ class T5EncoderModel:
             device=device).eval().requires_grad_(False)
         logging.info(f'loading {checkpoint_path}')
         from mmgp import offload
-        offload.load_model_data(model,checkpoint_path )
+        offload.load_model_data(model,checkpoint_path, writable_tensors= False )
 
         self.model = model
         if shard_fn is not None:
wan/modules/vae.py
@@ -744,11 +744,12 @@ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
     with torch.device('meta'):
         model = WanVAE_(**cfg)
 
+    from mmgp import offload
     # load checkpoint
     logging.info(f'loading {pretrained_path}')
-    model.load_state_dict(
-        torch.load(pretrained_path, map_location=device), assign=True)
+    # model.load_state_dict(
+    #     torch.load(pretrained_path, map_location=device), assign=True)
+    offload.load_model_data(model, pretrained_path.replace(".pth", "_bf16.safetensors"), writable_tensors= False)
     return model
 
 
@@ -778,7 +779,7 @@ class WanVAE:
         self.model = _video_vae(
             pretrained_path=vae_pth,
             z_dim=z_dim,
-        ).eval().requires_grad_(False).to(device)
+        ).eval() #.requires_grad_(False).to(device)
 
     def encode(self, videos, tile_size = 256, any_end_frame = False):
         """
wan/text2video.py
@@ -89,7 +89,7 @@ class WanT2V:
         from mmgp import offload
 
 
-        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel)
+        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False)
 
 
 