Added LTX Video Distilled support

DeepBeepMeep 2025-05-17 18:01:05 +02:00
parent 32473e3d11
commit 7c1ed43189
3 changed files with 32 additions and 15 deletions

View File

@@ -21,6 +21,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models
## 🔥 Latest News!!
* May 18 2025: 👋 Wan 2.1GP v5.1: Bonus Day, added LTX Video 13B Distilled: generate very high quality videos in less than one minute!\
* May 17 2025: 👋 Wan 2.1GP v5.0: One App to Rule Them All!\
Added support for the other great open source architectures:
- Hunyuan Video: text 2 video (one of the best, if not the best, t2v), image 2 video and the recently released Hunyuan Custom (very good identity preservation when injecting a person into a video)

View File

@@ -148,6 +148,7 @@ class LTXV:
def __init__(
self,
model_filepath: str,
loras_filepath: str,
text_encoder_filepath: str,
dtype = torch.bfloat16,
VAE_dtype = torch.bfloat16,
@@ -155,7 +156,8 @@
):
self.mixed_precision_transformer = mixed_precision_transformer
self.distilled = "distilled" in model_filepath[0]
self.distilled = loras_filepath != None
# with safe_open(ckpt_path, framework="pt") as f:
# metadata = f.metadata()
# config_str = metadata.get("config")
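For reference, the change above replaces filename sniffing with a simpler rule: the model counts as distilled exactly when LoRA weights are supplied. A minimal sketch of the new constructor logic (other parameters and setup omitted; `is not None` used in place of the diff's `!= None`):

```python
import torch

class LTXV:
    def __init__(
        self,
        model_filepath,                # list of transformer checkpoint paths
        loras_filepath,                # distilled LoRA weights, or None for the dev model
        text_encoder_filepath,
        dtype=torch.bfloat16,
        VAE_dtype=torch.bfloat16,
        mixed_precision_transformer=False,
    ):
        self.mixed_precision_transformer = mixed_precision_transformer
        # Old heuristic: self.distilled = "distilled" in model_filepath[0]
        # New rule: the presence of LoRA weights marks the distilled variant.
        self.distilled = loras_filepath is not None
```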

wgp.py
View File

@@ -1528,7 +1528,7 @@ wan_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_mbf16.safetensors", "ckpts/w
"ckpts/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors",
"ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors",
"ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"]
ltxv_choices= ["ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors", "ckpts/ltxv_0.9.7_13B_dev_quanto_bf16_int8.safetensors", "ckpts/ltxv_0.9.7_13B_distilled_bf16.safetensors", "ckpts/ltxv_0.9.7_13B_distilled_quanto_bf16_int8.safetensors"]
ltxv_choices= ["ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors", "ckpts/ltxv_0.9.7_13B_dev_quanto_bf16_int8.safetensors", "ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"]
hunyuan_choices= ["ckpts/hunyuan_video_720_bf16.safetensors", "ckpts/hunyuan_video_720_quanto_int8.safetensors", "ckpts/hunyuan_video_i2v_720_bf16.safetensors", "ckpts/hunyuan_video_i2v_720_quanto_int8v2.safetensors",
"ckpts/hunyuan_video_custom_720_bf16.safetensors", "ckpts/hunyuan_video_custom_720_quanto_bf16_int8.safetensors" ]
@@ -1537,10 +1537,11 @@ transformer_choices = wan_choices_t2v + wan_choices_i2v + ltxv_choices + hunyuan
def get_dependent_models(model_filename, quantization, dtype_policy ):
if "fantasy" in model_filename:
return [get_model_filename("i2v_720p", quantization, dtype_policy)]
elif "ltxv_0.9.7_13B_distilled_lora128" in model_filename:
return [get_model_filename("ltxv_13B", quantization, dtype_policy)]
else:
return []
# model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "vace_14B", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B", "fantasy", "ltxv_13B", "ltxv_13B_distilled", "hunyuan", "hunyuan_i2v", "hunyuan_custom"]
model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "vace_14B", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B", "fantasy", "ltxv_13B", "hunyuan", "hunyuan_i2v", "hunyuan_custom"]
model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "vace_14B", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B", "fantasy", "ltxv_13B", "ltxv_13B_distilled", "hunyuan", "hunyuan_i2v", "hunyuan_custom"]
model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
"i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "vace_14B" : "Vace_14B","recam_1.3B": "recammaster_1.3B",
"flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B",
@@ -1609,10 +1610,10 @@ def get_model_name(model_filename, description_container = [""]):
description = "The Fantasy Speaking model corresponds to the original Wan image 2 video model combined with the Fantasy Speaking extension to process an audio Input."
elif "ltxv_0.9.7_13B_dev" in model_filename:
model_name = "LTX Video 0.9.7"
description = "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).It is recommended to keep the number of steps to 30 or you will need to update the file 'ltxv_video/configs/ltxv-13b-0.9.7-dev.yaml'.The LTX Video model expects very long prompt, so don't hesitate to use the Prompt Enhancer."
description = "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).It is recommended to keep the number of steps to 30 or you will need to update the file 'ltxv_video/configs/ltxv-13b-0.9.7-dev.yaml'.The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer."
elif "ltxv_0.9.7_13B_distilled" in model_filename:
model_name = "LTX Video 0.9.7 distilled"
description = "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).This is the distilled / fast version. The LTX Video model expects very long prompt, so don't hesitate to use the Prompt Enhancer."
model_name = "LTX Video 0.9.7 Distilled"
description = "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).This distilled version is a very fast version and retains a high level of quality. The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer."
elif "hunyuan_video_720" in model_filename:
model_name = "Hunyuan Video text2video 720p"
description = "Probably the best text 2 video model available."
@@ -2065,13 +2066,14 @@ def load_wan_model(model_filename, quantizeTransformer = False, dtype = torch.bf
pipe["text_encoder_2"] = wan_model.clip.model
return wan_model, pipe
def load_ltxv_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
def load_ltxv_model(model_filename, loras_filenames, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
filename = model_filename[-1]
print(f"Loading '{filename}' model...")
from ltx_video.ltxv import LTXV
ltxv_model = LTXV(
model_filepath = model_filename,
loras_filepath = loras_filenames,
text_encoder_filepath = get_ltxv_text_encoder_filename(text_encoder_quantization),
dtype = dtype,
# quantizeTransformer = quantizeTransformer,
@@ -2119,21 +2121,28 @@ def get_transformer_model(model):
def load_models(model_filename):
global transformer_filename
global transformer_filename, transformer_loras_filenames
model_family = get_model_family(model_filename)
perc_reserved_mem_max = args.perc_reserved_mem_max
model_filelist = get_dependent_models(model_filename, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy) + [model_filename]
new_transformer_loras_filenames = None
dependent_models = get_dependent_models(model_filename, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy)
if "_lora" in model_filename:
new_transformer_loras_filenames = [model_filename]
model_filelist = dependent_models
else:
model_filelist = dependent_models + [model_filename]
for filename in model_filelist:
download_models(filename)
transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
transformer_filename = None
transformer_loras_filenames = None
new_transformer_filename = model_filelist[-1]
if model_family == "wan" :
wan_model, pipe = load_wan_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
elif model_family == "ltxv":
wan_model, pipe = load_ltxv_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
wan_model, pipe = load_ltxv_model(model_filelist, new_transformer_loras_filenames, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
elif model_family == "hunyuan":
wan_model, pipe = load_hunyuan_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
else:
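Net effect of the `_lora` branch above: the file the user selected never enters the transformer load list; the dev checkpoint from `get_dependent_models` is what gets instantiated, while the selected file is set aside as LoRA weights. A condensed, runnable sketch (the helper name `split_model_selection` is hypothetical):

```python
def split_model_selection(model_filename, dependent_models):
    """Hypothetical helper mirroring the branch above: returns the
    checkpoints to load as the transformer, plus any LoRA weights."""
    if "_lora" in model_filename:
        # The selected file holds LoRA weights; only the dependent
        # (dev) checkpoints are loaded as the transformer.
        return dependent_models, [model_filename]
    return dependent_models + [model_filename], None

files, loras = split_model_selection(
    "ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors",
    ["ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors"],
)
# files -> ['ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors']
# loras -> ['ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors']
```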
@@ -2171,6 +2180,7 @@ def load_models(model_filename):
if len(args.gpu) > 0:
torch.set_default_device(args.gpu)
transformer_filename = new_transformer_filename
transformer_loras_filenames = new_transformer_loras_filenames
return wan_model, offloadobj, pipe["transformer"]
if not "P" in preload_model_policy:
@@ -2735,7 +2745,7 @@ def generate_video(
loras = state["loras"]
if len(loras) > 0:
if len(loras) > 0 or transformer_loras_filenames != None:
def is_float(element: any) -> bool:
if element is None:
return False
@@ -2770,6 +2780,9 @@ def generate_video(
loras_selected = [ lora for lora in loras if os.path.basename(lora) in activated_loras]
pinnedLora = profile !=5 #False # # #
split_linear_modules_map = getattr(trans,"split_linear_modules_map", None)
if transformer_loras_filenames != None:
loras_selected += transformer_loras_filenames
list_mult_choices_nums.append(1.)
offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=get_loras_preprocessor(trans, model_filename), pinnedLora=pinnedLora, split_linear_modules_map = split_linear_modules_map)
errors = trans._loras_errors
if len(errors) > 0:
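At generation time the distilled weights simply join whatever LoRAs the user activated, pinned at a multiplier of 1.0, so the single `load_loras_into_model` call applies everything at once. A minimal sketch of the merge (the example values are illustrative; the list names follow the diff):

```python
# User-activated LoRAs and their multipliers (example values).
loras_selected = ["loras/my_style.safetensors"]
list_mult_choices_nums = [0.8]

# Distilled transformer LoRA remembered by load_models (example value).
transformer_loras_filenames = ["ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"]

if transformer_loras_filenames is not None:
    # The distilled LoRA is always applied, at full strength.
    loras_selected += transformer_loras_filenames
    list_mult_choices_nums.append(1.0)
```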
@@ -4335,6 +4348,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
trigger_refresh_input_type = gr.Text(interactive= False, visible= False)
diffusion_forcing = "diffusion_forcing" in model_filename
ltxv = "ltxv" in model_filename
ltxv_distilled = "ltxv" in model_filename and "distilled" in model_filename
recammaster = "recam" in model_filename
vace = "Vace" in model_filename
phantom = "phantom" in model_filename
@@ -4550,7 +4564,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_length = gr.Slider(5, 337, value=ui_defaults.get("video_length", 97), step=4, label="Number of frames (24 = 1s)", interactive= True)
else:
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
with gr.Row():
with gr.Row(visible = not ltxv_distilled) as inference_steps_row:
num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps")
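Since the distilled model runs with a fixed, very short step schedule, the steps slider is now wrapped in a row whose visibility depends on the selected model; the row is also registered in `extra_inputs` further below so a form refresh can toggle it. A minimal standalone gradio sketch of the pattern:

```python
import gradio as gr

ltxv_distilled = True  # example: a distilled LTX Video checkpoint is selected

with gr.Blocks() as demo:
    # Hidden for distilled checkpoints, whose step count is fixed.
    with gr.Row(visible=not ltxv_distilled) as inference_steps_row:
        num_inference_steps = gr.Slider(
            1, 100, value=30, step=1, label="Number of Inference Steps"
        )
```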
@@ -4764,7 +4778,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row, speed_tab, quality_tab,
sliding_window_tab, misc_tab, prompt_enhancer_row,
sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row,
video_prompt_type_video_guide, video_prompt_type_image_refs] # show_advanced presets_column,
if update_form:
locals_dict = locals()
@@ -5626,7 +5640,7 @@ def create_demo():
theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as main:
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v5.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v5.1 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
global model_list
tab_state = gr.State({ "tab_no":0 })