mirror of https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00

commit 7c1ed43189 (parent 32473e3d11): Added LTX Video Distilled support
@@ -21,6 +21,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
 
 
 ## 🔥 Latest News!!
+* May 18 2025: 👋 Wan 2.1GP v5.1 : Bonus Day, added LTX Video 13B Distilled: generate in less than one minute, very high quality Videos !\
 * May 17 2025: 👋 Wan 2.1GP v5.0 : One App to Rule Them All !\
 Added support for the other great open source architectures:
 - Hunyuan Video : text 2 video (one of the best, if not the best t2v) ,image 2 video and the recently released Hunyuan Custom (very good identify preservation when injecting a person into a video)
@@ -148,6 +148,7 @@ class LTXV:
     def __init__(
         self,
         model_filepath: str,
+        loras_filepath: str,
         text_encoder_filepath: str,
         dtype = torch.bfloat16,
         VAE_dtype = torch.bfloat16,
@@ -155,7 +156,8 @@ class LTXV:
     ):
 
         self.mixed_precision_transformer = mixed_precision_transformer
-        self.distilled = "distilled" in model_filepath[0]
+        self.distilled = loras_filepath != None
 
         # with safe_open(ckpt_path, framework="pt") as f:
         #     metadata = f.metadata()
         #     config_str = metadata.get("config")
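The `self.distilled` change above is the heart of the commit: the distilled variant is no longer recognized by checkpoint filename but by whether lora weights were handed to the constructor at all, because the distilled model now ships as a lora on top of the dev checkpoint. A minimal sketch of the two detection rules (the helper names are hypothetical, not from the repo; note that `loras_filepath is not None` would be the more idiomatic spelling of the new test):

```python
from typing import Optional, Sequence

def is_distilled_old(model_filepath: Sequence[str]) -> bool:
    # Before: the distilled model was a separate full checkpoint,
    # recognized by its filename.
    return "distilled" in model_filepath[0]

def is_distilled_new(loras_filepath: Optional[Sequence[str]]) -> bool:
    # After: the distilled variant is a lora stacked on the dev
    # checkpoint, so receiving any lora path at all is the signal.
    return loras_filepath is not None

print(is_distilled_old(["ckpts/ltxv_0.9.7_13B_distilled_bf16.safetensors"]))          # True
print(is_distilled_new(["ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"]))  # True
```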
wgp.py (42 changed lines)
@@ -1528,7 +1528,7 @@ wan_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_mbf16.safetensors", "ckpts/w
 "ckpts/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors",
 "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors",
 "ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"]
-ltxv_choices= ["ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors", "ckpts/ltxv_0.9.7_13B_dev_quanto_bf16_int8.safetensors", "ckpts/ltxv_0.9.7_13B_distilled_bf16.safetensors", "ckpts/ltxv_0.9.7_13B_distilled_quanto_bf16_int8.safetensors"]
+ltxv_choices= ["ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors", "ckpts/ltxv_0.9.7_13B_dev_quanto_bf16_int8.safetensors", "ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"]
 
 hunyuan_choices= ["ckpts/hunyuan_video_720_bf16.safetensors", "ckpts/hunyuan_video_720_quanto_int8.safetensors", "ckpts/hunyuan_video_i2v_720_bf16.safetensors", "ckpts/hunyuan_video_i2v_720_quanto_int8v2.safetensors",
 "ckpts/hunyuan_video_custom_720_bf16.safetensors", "ckpts/hunyuan_video_custom_720_quanto_bf16_int8.safetensors" ]
@@ -1537,10 +1537,11 @@ transformer_choices = wan_choices_t2v + wan_choices_i2v + ltxv_choices + hunyuan
 def get_dependent_models(model_filename, quantization, dtype_policy ):
     if "fantasy" in model_filename:
         return [get_model_filename("i2v_720p", quantization, dtype_policy)]
+    elif "ltxv_0.9.7_13B_distilled_lora128" in model_filename:
+        return [get_model_filename("ltxv_13B", quantization, dtype_policy)]
     else:
         return []
-# model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "vace_14B", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B", "fantasy", "ltxv_13B", "ltxv_13B_distilled", "hunyuan", "hunyuan_i2v", "hunyuan_custom"]
-model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "vace_14B", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B", "fantasy", "ltxv_13B", "hunyuan", "hunyuan_i2v", "hunyuan_custom"]
+model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "vace_14B", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B", "fantasy", "ltxv_13B", "ltxv_13B_distilled", "hunyuan", "hunyuan_i2v", "hunyuan_custom"]
 model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
                     "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "vace_14B" : "Vace_14B","recam_1.3B": "recammaster_1.3B",
                     "flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B",
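Because the distilled checkpoint is now only a lora, it cannot be loaded on its own: `get_dependent_models` returns the full dev transformer as a prerequisite. A self-contained restatement of that resolution, with a stubbed `get_model_filename` (the stub's lookup table is illustrative, not the repo's real mapping, which also honors quantization and dtype policy):

```python
from typing import List

def get_model_filename(model_type: str, quantization: str = "", dtype_policy: str = "") -> str:
    # Illustrative stub for the real wgp.py lookup.
    table = {
        "i2v_720p": "ckpts/wan2.1_image2video_720p_14B_mbf16.safetensors",
        "ltxv_13B": "ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors",
    }
    return table[model_type]

def get_dependent_models(model_filename: str, quantization: str = "", dtype_policy: str = "") -> List[str]:
    if "fantasy" in model_filename:
        return [get_model_filename("i2v_720p", quantization, dtype_policy)]
    elif "ltxv_0.9.7_13B_distilled_lora128" in model_filename:
        # The distilled lora depends on the full dev transformer,
        # which must be downloaded and loaded first.
        return [get_model_filename("ltxv_13B", quantization, dtype_policy)]
    else:
        return []

print(get_dependent_models("ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"))
# -> ['ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors']
```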
@@ -1609,10 +1610,10 @@ def get_model_name(model_filename, description_container = [""]):
         description = "The Fantasy Speaking model corresponds to the original Wan image 2 video model combined with the Fantasy Speaking extension to process an audio Input."
     elif "ltxv_0.9.7_13B_dev" in model_filename:
         model_name = "LTX Video 0.9.7"
-        description = "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).It is recommended to keep the number of steps to 30 or you will need to update the file 'ltxv_video/configs/ltxv-13b-0.9.7-dev.yaml'.The LTX Video model expects very long prompt, so don't hesitate to use the Prompt Enhancer."
+        description = "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).It is recommended to keep the number of steps to 30 or you will need to update the file 'ltxv_video/configs/ltxv-13b-0.9.7-dev.yaml'.The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer."
     elif "ltxv_0.9.7_13B_distilled" in model_filename:
-        model_name = "LTX Video 0.9.7 distilled"
-        description = "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).This is the distilled / fast version. The LTX Video model expects very long prompt, so don't hesitate to use the Prompt Enhancer."
+        model_name = "LTX Video 0.9.7 Distilled"
+        description = "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).This distilled version is a very fast version and retains a high level of quality. The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer."
     elif "hunyuan_video_720" in model_filename:
         model_name = "Hunyuan Video text2video 720p"
         description = "Probably the best text 2 video model available."
@@ -2065,13 +2066,14 @@ def load_wan_model(model_filename, quantizeTransformer = False, dtype = torch.bf
         pipe["text_encoder_2"] = wan_model.clip.model
     return wan_model, pipe
 
-def load_ltxv_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
+def load_ltxv_model(model_filename, loras_filenames, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
     filename = model_filename[-1]
     print(f"Loading '{filename}' model...")
     from ltx_video.ltxv import LTXV
 
     ltxv_model = LTXV(
         model_filepath = model_filename,
+        loras_filepath = loras_filenames,
         text_encoder_filepath = get_ltxv_text_encoder_filename(text_encoder_quantization),
         dtype = dtype,
         # quantizeTransformer = quantizeTransformer,
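The new `loras_filenames` parameter simply threads through to the `LTXV` constructor shown earlier, where its presence flips the distilled flag. A runnable sketch of that threading, with a stub class standing in for `ltx_video.ltxv.LTXV` (the text encoder path and the simplified return value are illustrative, not the repo's):

```python
from typing import List, Optional

class LTXV:  # stub mirroring the constructor signature from the diff above
    def __init__(self, model_filepath: List[str],
                 loras_filepath: Optional[List[str]],
                 text_encoder_filepath: str, **kwargs):
        self.distilled = loras_filepath is not None  # same test as the diff

def load_ltxv_model(model_filename: List[str], loras_filenames: Optional[List[str]]) -> LTXV:
    return LTXV(model_filepath=model_filename,
                loras_filepath=loras_filenames,
                text_encoder_filepath="ckpts/T5_xxl_enc_bf16.safetensors")  # illustrative path

m = load_ltxv_model(["ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors"],
                    ["ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"])
print(m.distilled)  # True
```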
@@ -2119,21 +2121,28 @@ def get_transformer_model(model):
 
 
 def load_models(model_filename):
-    global transformer_filename
+    global transformer_filename, transformer_loras_filenames
     model_family = get_model_family(model_filename)
     perc_reserved_mem_max = args.perc_reserved_mem_max
-    model_filelist = get_dependent_models(model_filename, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy) + [model_filename]
+    new_transformer_loras_filenames = None
+    dependent_models = get_dependent_models(model_filename, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy)
+    if "_lora" in model_filename:
+        new_transformer_loras_filenames = [model_filename]
+        model_filelist = dependent_models
+    else:
+        model_filelist = dependent_models + [model_filename]
     for filename in model_filelist:
         download_models(filename)
     transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
     VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
     mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
     transformer_filename = None
+    transformer_loras_filenames = None
     new_transformer_filename = model_filelist[-1]
     if model_family == "wan" :
         wan_model, pipe = load_wan_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     elif model_family == "ltxv":
-        wan_model, pipe = load_ltxv_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
+        wan_model, pipe = load_ltxv_model(model_filelist, new_transformer_loras_filenames, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     elif model_family == "hunyuan":
         wan_model, pipe = load_hunyuan_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     else:
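`load_models` now splits the selected filename into the full models to load and an optional list of bundled loras, keyed on the `_lora` substring; note that `model_filelist[-1]` then makes the base dev checkpoint the active `transformer_filename` while the lora file is kept aside. A standalone sketch of just that split (the helper name is hypothetical; inputs are illustrative):

```python
from typing import List, Optional, Tuple

def split_model_and_loras(model_filename: str,
                          dependent_models: List[str]) -> Tuple[List[str], Optional[List[str]]]:
    """Decide what to load as full models vs. what to inject later as a lora."""
    if "_lora" in model_filename:
        # The selected checkpoint is only a lora: load the base models and
        # remember the lora for injection at generation time.
        return dependent_models, [model_filename]
    # Normal case: the selected checkpoint is itself the transformer.
    return dependent_models + [model_filename], None

model_filelist, loras = split_model_and_loras(
    "ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors",
    ["ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors"])
print(model_filelist)  # ['ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors']
print(loras)           # ['ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors']
```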
@@ -2171,6 +2180,7 @@ def load_models(model_filename):
     if len(args.gpu) > 0:
         torch.set_default_device(args.gpu)
     transformer_filename = new_transformer_filename
+    transformer_loras_filenames = new_transformer_loras_filenames
     return wan_model, offloadobj, pipe["transformer"]
 
 if not "P" in preload_model_policy:
@@ -2735,7 +2745,7 @@ def generate_video(
 
 
     loras = state["loras"]
-    if len(loras) > 0:
+    if len(loras) > 0 or transformer_loras_filenames != None:
         def is_float(element: any) -> bool:
             if element is None:
                 return False
@@ -2770,6 +2780,9 @@ def generate_video(
         loras_selected = [ lora for lora in loras if os.path.basename(lora) in activated_loras]
         pinnedLora = profile !=5 #False # # #
         split_linear_modules_map = getattr(trans,"split_linear_modules_map", None)
+        if transformer_loras_filenames != None:
+            loras_selected += transformer_loras_filenames
+            list_mult_choices_nums.append(1.)
         offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=get_loras_preprocessor(trans, model_filename), pinnedLora=pinnedLora, split_linear_modules_map = split_linear_modules_map)
         errors = trans._loras_errors
         if len(errors) > 0:
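At generation time, the model-bundled lora is appended after the user-activated loras and pinned at a multiplier of 1.0, so `offload.load_loras_into_model` treats it like any other lora. A runnable sketch of just the merge step (the helper name is hypothetical, and the single appended multiplier suggests the code assumes one bundled lora):

```python
from typing import List, Optional

def merge_loras(loras_selected: List[str],
                list_mult_choices_nums: List[float],
                transformer_loras_filenames: Optional[List[str]]) -> None:
    # Bundled model loras ride along after the user's picks,
    # pinned at full strength (multiplier 1.0).
    if transformer_loras_filenames is not None:
        loras_selected += transformer_loras_filenames
        list_mult_choices_nums.append(1.0)

selected = ["loras/my_style_lora.safetensors"]  # user-activated (illustrative)
mults = [0.8]
merge_loras(selected, mults, ["ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"])
print(selected, mults)
# ['loras/my_style_lora.safetensors', 'ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors'] [0.8, 1.0]
```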
@@ -4335,6 +4348,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
     trigger_refresh_input_type = gr.Text(interactive= False, visible= False)
     diffusion_forcing = "diffusion_forcing" in model_filename
     ltxv = "ltxv" in model_filename
+    ltxv_distilled = "ltxv" in model_filename and "distilled" in model_filename
     recammaster = "recam" in model_filename
     vace = "Vace" in model_filename
     phantom = "phantom" in model_filename
@@ -4550,7 +4564,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
             video_length = gr.Slider(5, 337, value=ui_defaults.get("video_length", 97), step=4, label="Number of frames (24 = 1s)", interactive= True)
         else:
             video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
-        with gr.Row():
+        with gr.Row(visible = not ltxv_distilled) as inference_steps_row:
             num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps")
 
 
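The steps row is created hidden whenever the distilled model is active, presumably because the distilled variant runs its own fixed, much shorter step schedule (hence "generate in less than one minute") and a user-adjustable slider would be misleading. A minimal Gradio sketch of the visibility pattern (requires gradio; slider bounds and default mirror the diff):

```python
import gradio as gr

with gr.Blocks() as demo:
    ltxv_distilled = True  # would come from the selected model_filename
    # Row starts hidden for the distilled model; update_form logic can
    # re-show it via gr.update when a non-distilled model is selected.
    with gr.Row(visible=not ltxv_distilled) as inference_steps_row:
        num_inference_steps = gr.Slider(
            1, 100, value=30, step=1, label="Number of Inference Steps")
# demo.launch()
```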
@@ -4764,7 +4778,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
 
         extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
                         prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row, speed_tab, quality_tab,
-                        sliding_window_tab, misc_tab, prompt_enhancer_row,
+                        sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row,
                         video_prompt_type_video_guide, video_prompt_type_image_refs] # show_advanced presets_column,
         if update_form:
             locals_dict = locals()
@@ -5626,7 +5640,7 @@ def create_demo():
     theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
 
     with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as main:
-        gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v5.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
+        gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v5.1 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
         global model_list
 
         tab_state = gr.State({ "tab_no":0 })