added fast wan 5B and fixed pytorch compilation

deepbeepmeep 2025-08-11 01:30:25 +02:00
parent 33514a89cb
commit ae5de88cbf
16 changed files with 105 additions and 35 deletions

View File

@@ -20,6 +20,10 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
+### August 8 2025: WanGP v7.75 - Faster than the VAE ...
+We have a funny one here today: FastWan 2.2 5B, the Fastest Video Generator, only 20s to generate 121 frames at 720p. The snag is that the VAE is twice as slow...
+Thanks to Kijai for extracting the Lora that is used to build the corresponding finetune.
### August 8 2025: WanGP v7.74 - Qwen Rebirth part 2
Added support for Qwen Lightning lora for a 8 steps generation (https://huggingface.co/lightx2v/Qwen-Image-Lightning/blob/main/Qwen-Image-Lightning-8steps-V1.0.safetensors). Lora is not normalized and you can use a multiplier around 0.1.
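As a quick illustration of what that multiplier means (generic LoRA math, not WanGP- or Qwen-specific code; all names below are made up): the multiplier simply scales the low-rank update before it is added to the base weight, which is why an un-normalized lora wants a small value such as 0.1.

```python
import torch

def apply_lora(base_weight, lora_down, lora_up, multiplier=0.1):
    # base_weight: [out, in], lora_down: [rank, in], lora_up: [out, rank]
    # An un-normalized lora has large low-rank factors, so the update is
    # tamed with a small multiplier instead of the usual alpha/rank scaling.
    return base_weight + multiplier * (lora_up @ lora_down)

w = torch.randn(64, 64)
down, up = torch.randn(8, 64), torch.randn(64, 8)
w_patched = apply_lora(w, down, up, multiplier=0.1)
```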

View File

@@ -0,0 +1,15 @@
{
    "_class_name": "WanModel",
    "_diffusers_version": "0.33.0",
    "dim": 5120,
    "eps": 1e-06,
    "ffn_dim": 13824,
    "freq_dim": 256,
    "in_dim": 36,
    "model_type": "i2v2_2",
    "num_heads": 40,
    "num_layers": 40,
    "out_dim": 16,
    "text_len": 512,
    "multitalk_output_dim": 768
}

View File

@@ -0,0 +1,18 @@
{
    "model":
    {
        "name": "Wan2.2 Multitalk 14B",
        "architecture" : "i2v_2_2_multitalk",
        "description": "The Multitalk module of Wan 2.1 has been combined with the Wan 2.2 image 2 video model. It lets up to two people have a conversation.",
        "modules": ["multitalk"],
        "URLs": "i2v_2_2",
        "URLs2": "i2v_2_2",
        "group": "wan2_2",
        "visible": false
    },
    "switch_threshold" : 900,
    "guidance_scale" : 3.5,
    "guidance2_scale" : 3.5,
    "flow_shift" : 5
}

View File

@@ -0,0 +1,15 @@
{
    "model": {
        "name": "Wan2.2 FastWan TextImage2video 5B",
        "architecture": "ti2v_2_2",
        "description": "FastWan2.2-TI2V-5B-Full-Diffusers is built upon Wan-AI/Wan2.2-TI2V-5B-Diffusers. It supports efficient 3-step inference and produces high-quality videos at 121×704×1280 resolution",
        "URLs": "ti2v_2_2",
        "loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"],
        "group": "wan2_2"
    },
    "video_length": 121,
    "guidance_scale": 1,
    "flow_shift": 3,
    "num_inference_steps": 3,
    "resolution": "1280x720"
}

View File

@@ -32,7 +32,7 @@ class family_handler():
        return {}, {}

    @staticmethod
-    def get_rgb_factors(model_type):
+    def get_rgb_factors(base_model_type ):
        from shared.RGB_factors import get_rgb_factors
        latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("flux")
        return latent_rgb_factors, latent_rgb_factors_bias

View File

@@ -72,7 +72,7 @@ class family_handler():
        return {"hunyuan":(20, "Hunyuan Video")}

    @staticmethod
-    def get_rgb_factors(model_type):
+    def get_rgb_factors(base_model_type ):
        from shared.RGB_factors import get_rgb_factors
        latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("hunyuan")
        return latent_rgb_factors, latent_rgb_factors_bias

View File

@@ -38,7 +38,7 @@ class family_handler():
        return {}, {}

    @staticmethod
-    def get_rgb_factors(model_type):
+    def get_rgb_factors(base_model_type ):
        from shared.RGB_factors import get_rgb_factors
        latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("ltxv")
        return latent_rgb_factors, latent_rgb_factors_bias

View File

@@ -469,13 +469,20 @@ class QwenImageTransformer2DModel(nn.Module):

    def preprocess_loras(self, model_type, sd):
-        new_sd = {}
-        for k,v in sd.items():
-            if k.startswith("transformer_blocks"):
-                k = "diffusion_model." + k
-            new_sd[k] = v
-        sd = new_sd
-        return sd
+        first = next(iter(sd), None)
+        if first == None:
+            return sd
+        if first.startswith("transformer_blocks"):
+            new_sd = {}
+            for k,v in sd.items():
+                if k.startswith("transformer_blocks"):
+                    k = "diffusion_model." + k
+                new_sd[k] = v
+            sd = new_sd
+            return sd
+        else:
+            return sd

    def __init__(
        self,
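The reworked preprocess_loras above only prefixes keys when the lora ships in the bare "transformer_blocks..." layout; empty or already-prefixed state dicts are returned unchanged. A tiny hypothetical example of the remapping (toy keys, not a real checkpoint):

```python
sd = {
    "transformer_blocks.0.attn.to_q.lora_A.weight": 1,
    "transformer_blocks.0.attn.to_q.lora_B.weight": 2,
}
remapped = {("diffusion_model." + k if k.startswith("transformer_blocks") else k): v
            for k, v in sd.items()}
# keys become "diffusion_model.transformer_blocks.0.attn.to_q.lora_A.weight", ...
```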

View File

@@ -87,17 +87,17 @@ class WanAny2V:
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=text_encoder_filename,
-            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
+            tokenizer_path=os.path.join(checkpoint_dir, "umt5-xxl"),
            shard_fn= None)

        # base_model_type = "i2v2_2"
-        if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2"]:
+        if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"]:
            self.clip = CLIPModel(
                dtype=config.clip_dtype,
                device=self.device,
                checkpoint_path=os.path.join(checkpoint_dir ,
                    config.clip_checkpoint),
-                tokenizer_path=os.path.join(checkpoint_dir , config.clip_tokenizer))
+                tokenizer_path=os.path.join(checkpoint_dir , "clip_vit_large_patch14"))

        if base_model_type in ["ti2v_2_2"]:

@@ -495,7 +495,7 @@ class WanAny2V:
        vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B"]
        phantom = model_type in ["phantom_1.3B", "phantom_14B"]
        fantasy = model_type in ["fantasy"]
-        multitalk = model_type in ["multitalk", "vace_multitalk_14B"]
+        multitalk = model_type in ["multitalk", "vace_multitalk_14B", "i2v_2_2_multitalk"]
        recam = model_type in ["recam_1.3B"]
        ti2v = model_type in ["ti2v_2_2"]
        start_step_no = 0

@@ -505,7 +505,7 @@ class WanAny2V:
        timestep_injection = False
        lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1
        # image2video
-        if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "flf2v_720p"]:
+        if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "i2v_2_2_multitalk", "flf2v_720p"]:
            any_end_frame = False
            if image_start is None:
                _ , preframes_count, height, width = input_video.shape

View File

@@ -10,7 +10,7 @@ i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
i2v_14B.update(wan_shared_cfg)
i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
-i2v_14B.t5_tokenizer = 'google/umt5-xxl'
+i2v_14B.t5_tokenizer = 'umt5-xxl'
# clip
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'

View File

@@ -10,7 +10,7 @@ t2v_14B.update(wan_shared_cfg)
# t5
t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
-t2v_14B.t5_tokenizer = 'google/umt5-xxl'
+t2v_14B.t5_tokenizer = 'umt5-xxl'
# vae
t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'

View File

@@ -10,7 +10,7 @@ t2v_1_3B.update(wan_shared_cfg)
# t5
t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
-t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
+t2v_1_3B.t5_tokenizer = 'umt5-xxl'
# vae
t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth'

View File

@@ -504,7 +504,6 @@ class T5EncoderModel:
        else:
            self.model.to(self.device)
        # init tokenizer
-        tokenizer_path= "google/umt5-xxl"
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path, seq_len=text_len, clean='whitespace')

View File

@@ -2,7 +2,7 @@ import torch
import numpy as np

def test_class_i2v(base_model_type):
-    return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", ] #"hunyuan_i2v",
+    return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "i2v_2_2_multitalk" ] #"hunyuan_i2v",

def test_class_1_3B(base_model_type):
    return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"]

@@ -79,7 +79,7 @@ class family_handler():
        extra_model_def["no_steps_skipping"] = True
        i2v = test_class_i2v(base_model_type)
        extra_model_def["i2v_class"] = i2v
+        extra_model_def["multitalk_class"] = base_model_type in ["multitalk", "vace_multitalk_14B", "i2v_2_2_multitalk"]
        vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B"]
        extra_model_def["vace_class"] = vace_class

@@ -118,7 +118,7 @@ class family_handler():
        return ["multitalk", "fantasy", "vace_14B", "vace_multitalk_14B",
                "t2v_1.3B", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B",
                "recam_1.3B",
-                "i2v", "i2v_2_2", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]
+                "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]

    @staticmethod

@@ -133,6 +133,7 @@ class family_handler():
            "vace_14B" : [ "vace_multitalk_14B"],
            "t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B"],
            "i2v" : [ "fantasy", "multitalk", "flf2v_720p" ],
+            "i2v_2_2" : ["i2v_2_2_multitalk"],
            "fantasy": ["multitalk"],
        }
        return models_eqv_map, models_comp_map

@@ -150,9 +151,9 @@ class family_handler():
        return 32 if base_model_type == "ti2v_2_2" else 16

    @staticmethod
-    def get_rgb_factors(model_type):
+    def get_rgb_factors(base_model_type ):
        from shared.RGB_factors import get_rgb_factors
-        if model_type == "ti2v_2_2": return None, None
+        if base_model_type == "ti2v_2_2": return None, None
        latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan")
        return latent_rgb_factors, latent_rgb_factors_bias

View File

@@ -17,7 +17,7 @@ gradio==5.23.0
numpy>=1.23.5,<2
einops
moviepy==1.0.3
-mmgp==3.5.7
+mmgp==3.5.8
peft==0.15.0
mutagen
pydantic==2.10.6

wgp.py
View File

@@ -44,13 +44,16 @@ from preprocessing.matanyone import app as matanyone_app
from tqdm import tqdm
import requests

+# import torch._dynamo as dynamo
+# dynamo.config.recompile_limit = 2000 # default is 256
+# dynamo.config.accumulated_recompile_limit = 2000 # or whatever limit you want

global_queue_ref = []
AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10
-target_mmgp_version = "3.5.7"
-WanGP_version = "7.74"
+target_mmgp_version = "3.5.8"
+WanGP_version = "7.75"
settings_version = 2.23
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
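The commented-out dynamo settings kept above are the trace of the "fixed pytorch compilation" part of this commit: torch.compile falls back to eager once a function has been recompiled too many times, so raising the limits keeps it compiled. A minimal sketch of the idea, assuming PyTorch 2.6+ where the options are named recompile_limit / accumulated_recompile_limit (the function and tensor shapes below are invented):

```python
import torch
import torch._dynamo as dynamo

dynamo.config.recompile_limit = 2000              # per-function recompile budget
dynamo.config.accumulated_recompile_limit = 2000  # overall budget across functions

@torch.compile
def scale_latents(x, scale):
    # Python float arguments are baked into the graph, so every new value of
    # `scale` can trigger another recompilation of this function.
    return x * scale

x = torch.randn(2, 16, 8, 8)
for step in range(8):
    scale_latents(x, 1.0 / (step + 1))
```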
@@ -1616,7 +1619,8 @@ def _parse_args():

def get_lora_dir(model_type):
    model_family = get_model_family(model_type)
-    i2v = test_class_i2v(model_type) and not get_base_model_type(model_type) == "i2v_2_2"
+    base_model_type = get_base_model_type(model_type)
+    i2v = test_class_i2v(model_type) and base_model_type == "i2v_2_2"
    if model_family == "wan":
        lora_dir =args.lora_dir
        if i2v and len(lora_dir)==0:

@@ -1629,7 +1633,7 @@ def get_lora_dir(model_type):
        lora_dir_1_3B = os.path.join(root_lora_dir, "1.3B")
        if os.path.isdir(lora_dir_1_3B ):
            return lora_dir_1_3B
-        elif model_type == "ti2v_2_2":
+        elif base_model_type == "ti2v_2_2":
            lora_dir_5B = os.path.join(root_lora_dir, "5B")
            if os.path.isdir(lora_dir_5B ):
                return lora_dir_5B
@@ -3641,7 +3645,7 @@ def perform_spatial_upsampling(sample, spatial_upsampling):

def any_audio_track(model_type):
    base_model_type = get_base_model_type(model_type)
-    return base_model_type in ["fantasy", "multitalk", "hunyuan_avatar", "hunyuan_custom_audio", "vace_multitalk_14B"]
+    return base_model_type in ["fantasy", "hunyuan_avatar", "hunyuan_custom_audio"] or get_model_def(model_type).get("multitalk_class", False)

def get_available_filename(target_path, video_source, suffix = "", force_extension = None):
    name, extension = os.path.splitext(os.path.basename(video_source))
@@ -3950,6 +3954,12 @@ def generate_video(
    model_filename,
    mode,
):
+    # import os
+    # os.environ.pop("TORCH_LOGS", None) # make sure no env var is suppressing/overriding
+    # import torch._logging as tlog
+    # tlog.set_logs(recompiles=True, guards=True, graph_breaks=True)

    def remove_temp_filenames(temp_filenames_list):
        for temp_filename in temp_filenames_list:
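The debug lines kept above as comments use PyTorch's structured logging. For reference, a small sketch (not part of this commit) of the two equivalent ways to surface recompilation noise:

```python
# From Python, before the first compiled call:
import torch._logging as tlog
tlog.set_logs(recompiles=True, guards=True, graph_breaks=True)

# Or via the environment, without touching the code:
#   TORCH_LOGS="recompiles,guards,graph_breaks" python wgp.py
```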
@@ -4094,7 +4104,7 @@ def generate_video(
    hunyuan_custom_edit = hunyuan_custom and "edit" in model_filename
    hunyuan_avatar = "hunyuan_video_avatar" in model_filename
    fantasy = base_model_type in ["fantasy"]
-    multitalk = base_model_type in ["multitalk", "vace_multitalk_14B"]
+    multitalk = model_def.get("multitalk_class", False)
    flux = base_model_type in ["flux"]

    if "B" in audio_prompt_type or "X" in audio_prompt_type:
@@ -4821,8 +4831,9 @@ def generate_preview(model_type, latents):
    import einops
    if latents is None: return None
    model_handler = get_model_handler(model_type)
+    base_model_type = get_base_model_type(model_type)
    if hasattr(model_handler, "get_rgb_factors"):
-        latent_rgb_factors, latent_rgb_factors_bias = model_handler.get_rgb_factors(model_type)
+        latent_rgb_factors, latent_rgb_factors_bias = model_handler.get_rgb_factors(base_model_type )
    else:
        return None
    if latent_rgb_factors is None: return None
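Context for the hunk above: the per-family RGB factors let the UI build a preview straight from the latents, without running the VAE. A rough, hypothetical sketch of that projection (shapes and names assumed, not WanGP's actual implementation):

```python
import torch

def latent_to_rgb_preview(latents, rgb_factors, rgb_bias):
    # latents: [C, T, H, W]; rgb_factors: C rows of 3 floats; rgb_bias: 3 floats
    factors = torch.tensor(rgb_factors, dtype=latents.dtype, device=latents.device)  # [C, 3]
    bias = torch.tensor(rgb_bias, dtype=latents.dtype, device=latents.device)        # [3]
    rgb = torch.einsum("cthw,cr->trhw", latents, factors) + bias[None, :, None, None]
    # map roughly [-1, 1] to uint8 frames for a cheap preview
    return ((rgb.clamp(-1, 1) + 1) * 127.5).to(torch.uint8)  # [T, 3, H, W]
```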
@@ -5520,7 +5531,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
    if not test_any_sliding_window( base_model_type):
        pop += ["sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames", "sliding_window_color_correction_strength"]

-    if not base_model_type in ["fantasy", "multitalk", "vace_multitalk_14B"]:
+    if not (base_model_type in ["fantasy"] or model_def.get("multitalk_class", False)):
        pop += ["audio_guidance_scale", "speakers_locations"]

    if not model_def.get("embedded_guidance", False) or model_def.get("no_guidance", False):
@@ -6505,7 +6516,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
    vace = test_vace_module(base_model_type)
    phantom = base_model_type in ["phantom_1.3B", "phantom_14B"]
    fantasy = base_model_type in ["fantasy"]
-    multitalk = base_model_type in ["multitalk", "vace_multitalk_14B"]
+    multitalk = model_def.get("multitalk_class", False)
    hunyuan_t2v = "hunyuan_video_720" in model_filename
    hunyuan_i2v = "hunyuan_video_i2v" in model_filename
    hunyuan_video_custom = "hunyuan_video_custom" in model_filename