diff --git a/README.md b/README.md
index c69848b..abd82fd 100644
--- a/README.md
+++ b/README.md
@@ -255,8 +255,9 @@ You can define multiple lines of macros. If there is only one macro line, the ap
--check-loras : filter loras that are incompatible (will take a few seconds while refreshing the lora list or while starting the app)\
--advanced : turn on the advanced mode while launching the app\
--i2v-settings : path to launch settings for i2v\
---t2v-settings : path to launch settings for t2v
---listen : make server accessible on network
+--t2v-settings : path to launch settings for t2v\
+--listen : make server accessible on network\
+--gpu device : run Wan on the specified device, for instance "cuda:1"
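
The new `--gpu` switch takes a standard PyTorch device string. If you are unsure which values are valid on a given machine, a quick standalone check with plain PyTorch (not part of the app) lists them:

```python
# List the device strings that can be passed to --gpu (requires a CUDA build of PyTorch).
import torch

if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(f"cuda:{i}  ->  {torch.cuda.get_device_name(i)}")
else:
    print("No CUDA device detected")
```
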
### Profiles (for power users only)
You can choose between 5 profiles, but two are really relevant here :
diff --git a/gradio_server.py b/gradio_server.py
index 907620f..2da0769 100644
--- a/gradio_server.py
+++ b/gradio_server.py
@@ -14,7 +14,7 @@ import json
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES
from wan.utils.utils import cache_video
-from wan.modules.attention import get_attention_modes
+from wan.modules.attention import get_attention_modes, get_supported_attention_modes
import torch
import gc
import traceback
@@ -23,7 +23,7 @@ import asyncio
from wan.utils import prompt_parser
PROMPT_VARS_MAX = 10
-target_mmgp_version = "3.3.1"
+target_mmgp_version = "3.3.2"
from importlib.metadata import version
mmgp_version = version("mmgp")
if mmgp_version != target_mmgp_version:
@@ -169,6 +169,12 @@ def _parse_args():
default="",
help="Server name"
)
+ parser.add_argument(
+ "--gpu",
+ type=str,
+ default="",
+ help="Default GPU Device"
+ )
parser.add_argument(
"--open-browser",
@@ -262,12 +268,11 @@ def get_lora_dir(i2v):
return lora_dir_14B
return root_lora_dir
-attention_modes_supported = get_attention_modes()
-
+attention_modes_installed = get_attention_modes()
+attention_modes_supported = get_supported_attention_modes()
args = _parse_args()
args.flow_reverse = True
-
lock_ui_attention = False
lock_ui_transformer = False
lock_ui_compile = False
@@ -375,6 +380,7 @@ vae_config = server_config.get("vae_config", 0)
if len(args.vae_config) > 0:
vae_config = int(args.vae_config)
+reload_needed = False
default_ui = server_config.get("default_ui", "t2v")
metadata = server_config.get("metadata_type", "metadata")
save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
@@ -474,7 +480,7 @@ def download_models(transformer_filename, text_encoder_filename):
from huggingface_hub import hf_hub_download, snapshot_download
repoId = "DeepBeepMeep/Wan2.1"
sourceFolderList = ["xlm-roberta-large", "", ]
- fileList = [ [], ["Wan2.1_VAE.pth", "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
+ fileList = [ [], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
targetRoot = "ckpts/"
for sourceFolder, files in zip(sourceFolderList,fileList ):
if len(files)==0:
@@ -491,6 +497,14 @@ def download_models(transformer_filename, text_encoder_filename):
offload.default_verboseLevel = verbose_level
+to_remove = ["models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", "Wan2.1_VAE.pth"]
+for file_name in to_remove:
+ file_name = os.path.join("ckpts",file_name)
+ if os.path.isfile(file_name):
+ try:
+ os.remove(file_name)
+ except:
+ pass
download_models(transformer_filename_i2v if use_image2video else transformer_filename_t2v, text_encoder_filename)
@@ -663,6 +677,8 @@ def load_models(i2v):
elif profile == 3:
kwargs["budgets"] = { "*" : "70%" }
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", **kwargs)
+ if len(args.gpu) > 0:
+ torch.set_default_device(args.gpu)
return wan_model, offloadobj, pipe["transformer"]
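
Taken together with the argparse hunk above, the flag is just a string that gets handed to `torch.set_default_device` once `offload.profile` has been set up. A condensed, standalone sketch of that flow (illustrative only, not the app's actual startup path):

```python
# Condensed sketch of how --gpu flows through to PyTorch (standalone illustration).
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--gpu", type=str, default="", help="Default GPU Device")
args = parser.parse_args(["--gpu", "cuda:0"] if torch.cuda.is_available() else [])

if len(args.gpu) > 0:
    # In gradio_server.py this happens right after offload.profile(...) in load_models().
    torch.set_default_device(args.gpu)
```
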
@@ -702,8 +718,10 @@ def generate_header(model_filename, compile, attention_mode):
header += model_name
header += " (attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
- if attention_mode not in attention_modes_supported:
+ if attention_mode not in attention_modes_installed:
header += " -NOT INSTALLED-"
+ elif attention_mode not in attention_modes_supported:
+ header += " -NOT SUPPORTED-"
if compile:
header += ", pytorch compilation ON"
@@ -765,11 +783,7 @@ def apply_changes( state,
if v != v_old:
changes.append(k)
- state["config_changes"] = changes
- state["config_new"] = server_config
- state["config_old"] = old_server_config
-
- global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost, lora_dir
+ global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost, lora_dir, reload_needed
attention_mode = server_config["attention_mode"]
profile = server_config["profile"]
compile = server_config["compile"]
@@ -781,7 +795,7 @@ def apply_changes( state,
if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice"] for change in changes ):
pass
else:
- state["_reload_needed"] = True
+ reload_needed = True
yield "
The new configuration has been succesfully applied
"
@@ -931,8 +945,7 @@ def generate_video(
):
- global wan_model, offloadobj
- reload_needed = state.get("_reload_needed", False)
+ global wan_model, offloadobj, reload_needed
file_model_needed = model_needed(image2video)
if file_model_needed != model_filename or reload_needed:
del wan_model
@@ -943,7 +956,7 @@ def generate_video(
yield f"Loading model {get_model_name(file_model_needed)}..."
wan_model, offloadobj, trans = load_models(image2video)
yield f"Model loaded"
- state["_reload_needed"] = False
+        reload_needed = False
from PIL import Image
import numpy as np
@@ -956,11 +969,12 @@ def generate_video(
elif attention_mode in attention_modes_supported:
attn = attention_mode
else:
- gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed on your system. You should either install it or switch to the default 'sdpa' attention.")
+            gr.Info(f"You have selected attention mode '{attention_mode}'. However, it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
return
if state.get("validate_success",0) != 1:
return
+ raw_resolution = resolution
width, height = resolution.split("x")
width, height = int(width), int(height)
@@ -1310,15 +1324,9 @@ def generate_video(
nrow=1,
normalize=True,
value_range=(-1, 1))
-
- configs = {
- 'prompt': prompt,
- 'negative_prompt': negative_prompt,
- 'resolution': resolution,
- 'video_length': video_length,
- 'seed': seed,
- 'num_inference_steps': num_inference_steps,
- }
+
+ configs = get_settings_dict(state, use_image2video, prompt, 0 if image_to_end == None else 1 , video_length, raw_resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
+ loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end)
metadata_choice = server_config.get("metadata_choice","metadata")
if metadata_choice == "json":
@@ -1632,19 +1640,15 @@ def switch_advanced(state, new_advanced, lset_name):
else:
return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
-def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
+
+def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
- if state.get("validate_success",0) != 1:
- return
-
- loras_choices
loras = state["loras"]
activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ]
- ui_defaults = {
+ ui_settings = {
"prompts": prompt,
- "image_prompt_type": image_prompt_type,
"resolution": resolution,
"video_length": video_length,
"num_inference_steps": num_inference_steps,
@@ -1664,6 +1668,24 @@ def save_settings(state, prompt, image_prompt_type, video_length, resolution, nu
"slg_start_perc": slg_start_perc,
"slg_end_perc": slg_end_perc
}
+
+ if i2v:
+ ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - image2video"
+        ui_settings["image_prompt_type"] = image_prompt_type
+ else:
+ ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video"
+
+ return ui_settings
+
+def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
+ loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
+
+ if state.get("validate_success",0) != 1:
+ return
+
+ ui_defaults = get_settings_dict(state, use_image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
+ loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc)
+
defaults_filename = get_settings_file_name(use_image2video)
with open(defaults_filename, "w", encoding="utf-8") as f:
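
`save_settings` is now a thin wrapper: it validates the prompt state, builds the dict through the shared `get_settings_dict`, and writes it to the per-mode settings file. The hunk ends at the `open()` call; the serialization step is presumably a plain JSON dump, roughly as in this sketch (the file name and the `json.dump` arguments are assumptions, not shown in the diff):

```python
# Hedged sketch of persisting the settings dict; file name and dump options are assumed.
import json

ui_defaults = {
    "type": "Wan2.1GP by DeepBeepMeep - text2video",
    "prompts": "a cat walking on the beach",     # example values only
    "resolution": "832x480",
    "video_length": 81,
    "num_inference_steps": 30,
    "seed": -1,
}

defaults_filename = "t2v_settings.json"          # assumed name; the app derives it per mode
with open(defaults_filename, "w", encoding="utf-8") as f:
    json.dump(ui_defaults, f, indent=4)
```
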
@@ -1780,8 +1802,8 @@ def generate_video_tab(image2video=False):
delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
video_to_continue = gr.Video(label= "Video to continue", visible= image2video and False) #######
-
- image_prompt_type = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =ui_defaults["image_prompt_type"], label="Location", show_label= False, scale= 3, visible=image2video)
+ image_prompt_type= ui_defaults.get("image_prompt_type",0)
+ image_prompt_type_radio = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =image_prompt_type, label="Location", show_label= False, scale= 3, visible=image2video)
if args.multiple_images:
image_to_continue = gr.Gallery(
@@ -1793,9 +1815,9 @@ def generate_video_tab(image2video=False):
if args.multiple_images:
image_to_end = gr.Gallery(
label="Images as ending points for new videos", type ="pil", #file_types= "image",
- columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=False)
+ columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image_prompt_type==1)
else:
- image_to_end = gr.Image(label= "Last Image for a new video", type ="pil", visible= False)
+ image_to_end = gr.Image(label= "Last Image for a new video", type ="pil", visible=image_prompt_type==1)
def switch_image_prompt_type_radio(image_prompt_type_radio):
if args.multiple_images:
@@ -1803,7 +1825,7 @@ def generate_video_tab(image2video=False):
else:
return gr.Image(visible = (image_prompt_type_radio == 1) )
- image_prompt_type.change(fn=switch_image_prompt_type_radio, inputs=[image_prompt_type], outputs=[image_to_end])
+ image_prompt_type_radio.change(fn=switch_image_prompt_type_radio, inputs=[image_prompt_type_radio], outputs=[image_to_end])
advanced_prompt = advanced
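
The radio component was renamed to `image_prompt_type_radio` so the plain integer `image_prompt_type` can seed the initial visibility of the end-image input, while the `.change` handler keeps it in sync afterwards. A self-contained Gradio sketch of that wiring (labels and defaults are placeholders):

```python
# Minimal Gradio sketch: a radio button toggling the visibility of an image input,
# mirroring the image_prompt_type_radio -> image_to_end wiring (labels are placeholders).
import gradio as gr

def switch_image_prompt_type_radio(image_prompt_type_radio):
    # Returning a component with the new visibility updates the existing one in place.
    return gr.Image(visible=(image_prompt_type_radio == 1))

with gr.Blocks() as demo:
    image_prompt_type = 0  # e.g. restored from ui_defaults.get("image_prompt_type", 0)
    image_prompt_type_radio = gr.Radio(
        [("Use only a Start Image", 0), ("Use both a Start and an End Image", 1)],
        value=image_prompt_type,
    )
    image_to_end = gr.Image(label="Last Image for a new video",
                            visible=(image_prompt_type == 1))
    image_prompt_type_radio.change(fn=switch_image_prompt_type_radio,
                                   inputs=[image_prompt_type_radio],
                                   outputs=[image_to_end])

# demo.launch()
```
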
@@ -1973,7 +1995,7 @@ def generate_video_tab(image2video=False):
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
- save_settings, inputs = [state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
+ save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
slg_start_perc, slg_end_perc ], outputs = [])
save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
@@ -2080,8 +2102,10 @@ def generate_configuration_tab():
value=server_config.get("save_path", save_path)
)
def check(mode):
- if not mode in attention_modes_supported:
+ if not mode in attention_modes_installed:
return " (NOT INSTALLED)"
+ elif not mode in attention_modes_supported:
+ return " (NOT SUPPORTED)"
else:
return ""
attention_choice = gr.Dropdown(
@@ -2238,7 +2262,7 @@ def create_demo():
}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
-        gr.Markdown("Wan 2.1GP v3.0 by DeepBeepMeep (Updates)")
+        gr.Markdown("Wan 2.1GP v3.1 by DeepBeepMeep (Updates)")
gr.Markdown("Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !")
with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
diff --git a/requirements.txt b/requirements.txt
index 8a6f888..764015d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,6 +16,6 @@ gradio>=5.0.0
numpy>=1.23.5,<2
einops
moviepy==1.0.3
-mmgp==3.3.1
+mmgp==3.3.2
peft==0.14.0
mutagen
\ No newline at end of file
diff --git a/wan/image2video.py b/wan/image2video.py
index 506fe90..a853665 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -177,7 +177,7 @@ class WanI2V:
logging.info(f"Creating WanModel from {model_filename}")
from mmgp import offload
- self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel)
+ self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False)
self.model.eval().requires_grad_(False)
if t5_fsdp or dit_fsdp or use_usp:
diff --git a/wan/modules/attention.py b/wan/modules/attention.py
index e40ac63..cbb578f 100644
--- a/wan/modules/attention.py
+++ b/wan/modules/attention.py
@@ -30,6 +30,7 @@ try:
max_seqlen_kv,
):
return sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
+
except ImportError:
sageattn_varlen_wrapper = None
@@ -38,11 +39,12 @@ import warnings
try:
from sageattention import sageattn
- from .sage2_core import sageattn as alt_sageattn
+ from .sage2_core import sageattn as alt_sageattn, is_sage_supported
+ sage_supported = is_sage_supported()
except ImportError:
sageattn = None
alt_sageattn = None
-
+ sage_supported = False
# @torch.compiler.disable()
def sageattn_wrapper(
qkv_list,
@@ -129,6 +131,14 @@ def get_attention_modes():
return ret
+def get_supported_attention_modes():
+ ret = get_attention_modes()
+ if not sage_supported:
+ if "sage" in ret:
+ del ret["sage"]
+ if "sage2" in ret:
+ del ret["sage2"]
+ return ret
__all__ = [
'pay_attention',
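
With both lists available, the UI can tell a backend that is not installed apart from one that is installed but unusable on the current GPU. The header and configuration dropdown reduce to this membership check (standalone sketch; the example mode lists are placeholders):

```python
# Sketch of the installed-vs-supported distinction used for the header and the
# configuration dropdown. The example lists below are placeholders.
def attention_status(mode, installed, supported):
    if mode not in installed:
        return " -NOT INSTALLED-"
    elif mode not in supported:
        return " -NOT SUPPORTED-"
    return ""

attention_modes_installed = ["sdpa", "sage", "sage2"]   # e.g. get_attention_modes()
attention_modes_supported = ["sdpa"]                    # e.g. get_supported_attention_modes()

for mode in ["sdpa", "sage2", "flash"]:
    print(mode, attention_status(mode, attention_modes_installed, attention_modes_supported))
```
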
diff --git a/wan/modules/clip.py b/wan/modules/clip.py
index 53b76c1..fc41d85 100644
--- a/wan/modules/clip.py
+++ b/wan/modules/clip.py
@@ -519,8 +519,11 @@ class CLIPModel:
device=device)
self.model = self.model.eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}')
- self.model.load_state_dict(
- torch.load(checkpoint_path, map_location='cpu'), assign= True)
+ from mmgp import offload
+ # self.model.load_state_dict(
+ # torch.load(checkpoint_path, map_location='cpu'), assign= True)
+
+ offload.load_model_data(self.model, checkpoint_path.replace(".pth", "-bf16.safetensors"), writable_tensors= False)
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(
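
`clip.py` and `vae.py` replace the `torch.load` of an fp32 `.pth` state dict with mmgp's `offload.load_model_data` on a pre-converted bf16 safetensors file, while `t5.py`, `image2video.py` and `text2video.py` add the same `writable_tensors=False` flag to their existing mmgp loaders. A generic sketch of the pattern (`load_bf16_checkpoint` is a hypothetical helper, not repo code; note the CLIP file uses a `-bf16` suffix while the VAE uses `_bf16`):

```python
# Generic sketch of the .pth -> bf16 safetensors loading pattern used in clip.py and vae.py.
# load_bf16_checkpoint is a hypothetical helper, not part of the repository.
import torch
from mmgp import offload

def load_bf16_checkpoint(model: torch.nn.Module, pth_path: str, suffix: str = "_bf16") -> torch.nn.Module:
    safetensors_path = pth_path.replace(".pth", f"{suffix}.safetensors")
    # writable_tensors=False keeps the loaded tensors read-only, as in the calls above.
    offload.load_model_data(model, safetensors_path, writable_tensors=False)
    return model.eval().requires_grad_(False)

# e.g. load_bf16_checkpoint(clip_model,
#          "ckpts/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", suffix="-bf16")
```
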
diff --git a/wan/modules/sage2_core.py b/wan/modules/sage2_core.py
index d83d255..de94a60 100644
--- a/wan/modules/sage2_core.py
+++ b/wan/modules/sage2_core.py
@@ -51,6 +51,15 @@ from sageattention.quant import per_channel_fp8
from typing import Any, List, Literal, Optional, Tuple, Union
import warnings
+import os
+
+def is_sage_supported():
+ device_count = torch.cuda.device_count()
+ for i in range(device_count):
+ major, minor = torch.cuda.get_device_capability(i)
+ if major < 8:
+ return False
+ return True
def get_cuda_arch_versions():
cuda_archs = []
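
`is_sage_supported` requires every visible CUDA device to report compute capability 8.0 or newer before the sage modes are considered supported. A standalone check with plain PyTorch shows whether a given machine clears that bar:

```python
# Standalone check mirroring is_sage_supported(): every CUDA device must report
# compute capability >= 8.0 for the sage/sage2 modes to stay in the supported list.
import torch

for i in range(torch.cuda.device_count()):
    major, minor = torch.cuda.get_device_capability(i)
    print(f"cuda:{i}: sm_{major}{minor}", "OK" if major >= 8 else "too old for Sage attention")
```
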
diff --git a/wan/modules/t5.py b/wan/modules/t5.py
index 5cccbbf..110e358 100644
--- a/wan/modules/t5.py
+++ b/wan/modules/t5.py
@@ -496,7 +496,7 @@ class T5EncoderModel:
device=device).eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}')
from mmgp import offload
- offload.load_model_data(model,checkpoint_path )
+ offload.load_model_data(model,checkpoint_path, writable_tensors= False )
self.model = model
if shard_fn is not None:
diff --git a/wan/modules/vae.py b/wan/modules/vae.py
index e47e74c..67dcd9a 100644
--- a/wan/modules/vae.py
+++ b/wan/modules/vae.py
@@ -744,11 +744,12 @@ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
with torch.device('meta'):
model = WanVAE_(**cfg)
+ from mmgp import offload
# load checkpoint
logging.info(f'loading {pretrained_path}')
- model.load_state_dict(
- torch.load(pretrained_path, map_location=device), assign=True)
-
+ # model.load_state_dict(
+ # torch.load(pretrained_path, map_location=device), assign=True)
+ offload.load_model_data(model, pretrained_path.replace(".pth", "_bf16.safetensors"), writable_tensors= False)
return model
@@ -778,7 +779,7 @@ class WanVAE:
self.model = _video_vae(
pretrained_path=vae_pth,
z_dim=z_dim,
- ).eval().requires_grad_(False).to(device)
+ ).eval() #.requires_grad_(False).to(device)
def encode(self, videos, tile_size = 256, any_end_frame = False):
"""
diff --git a/wan/text2video.py b/wan/text2video.py
index 4682a4a..88046db 100644
--- a/wan/text2video.py
+++ b/wan/text2video.py
@@ -89,7 +89,7 @@ class WanT2V:
from mmgp import offload
- self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel)
+ self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False)