diff --git a/README.md b/README.md
index 82156c4..53af645 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
 ## 🔥 Latest News!!
+* May 18 2025: 👋 Wan 2.1GP v5.1 : Bonus Day, added LTX Video 13B Distilled: generate very high quality videos in less than one minute!\
 * May 17 2025: 👋 Wan 2.1GP v5.0 : One App to Rule Them All !\
 Added support for the other great open source architectures:
 - Hunyuan Video : text 2 video (one of the best, if not the best t2v) ,image 2 video and the recently released Hunyuan Custom (very good identify preservation when injecting a person into a video)
diff --git a/ltx_video/ltxv.py b/ltx_video/ltxv.py
index c96bf59..e9b7d8f 100644
--- a/ltx_video/ltxv.py
+++ b/ltx_video/ltxv.py
@@ -148,6 +148,7 @@ class LTXV:
     def __init__(
         self,
         model_filepath: str,
+        loras_filepath: str,
         text_encoder_filepath: str,
         dtype = torch.bfloat16,
         VAE_dtype = torch.bfloat16,
@@ -155,7 +156,8 @@
     ):
         self.mixed_precision_transformer = mixed_precision_transformer
-        self.distilled = "distilled" in model_filepath[0]
+        self.distilled = loras_filepath != None
+
         # with safe_open(ckpt_path, framework="pt") as f:
         #     metadata = f.metadata()
         #     config_str = metadata.get("config")
diff --git a/wgp.py b/wgp.py
index 5cb139b..ac8d4ca 100644
--- a/wgp.py
+++ b/wgp.py
@@ -1528,7 +1528,7 @@ wan_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_mbf16.safetensors", "ckpts/w
     "ckpts/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors",
     "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"]
-ltxv_choices= ["ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors", "ckpts/ltxv_0.9.7_13B_dev_quanto_bf16_int8.safetensors", "ckpts/ltxv_0.9.7_13B_distilled_bf16.safetensors", "ckpts/ltxv_0.9.7_13B_distilled_quanto_bf16_int8.safetensors"]
+ltxv_choices= ["ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors", "ckpts/ltxv_0.9.7_13B_dev_quanto_bf16_int8.safetensors", "ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"]
 
 hunyuan_choices= ["ckpts/hunyuan_video_720_bf16.safetensors", "ckpts/hunyuan_video_720_quanto_int8.safetensors", "ckpts/hunyuan_video_i2v_720_bf16.safetensors", "ckpts/hunyuan_video_i2v_720_quanto_int8v2.safetensors", "ckpts/hunyuan_video_custom_720_bf16.safetensors", "ckpts/hunyuan_video_custom_720_quanto_bf16_int8.safetensors" ]
@@ -1537,10 +1537,11 @@ transformer_choices = wan_choices_t2v + wan_choices_i2v + ltxv_choices + hunyuan
 
 def get_dependent_models(model_filename, quantization, dtype_policy ):
     if "fantasy" in model_filename:
         return [get_model_filename("i2v_720p", quantization, dtype_policy)]
+    elif "ltxv_0.9.7_13B_distilled_lora128" in model_filename:
+        return [get_model_filename("ltxv_13B", quantization, dtype_policy)]
     else:
         return []
 
-# model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "vace_14B", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B", "fantasy", "ltxv_13B", "ltxv_13B_distilled", "hunyuan", "hunyuan_i2v", "hunyuan_custom"]
-model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "vace_14B", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B", "fantasy", "ltxv_13B", "hunyuan", "hunyuan_i2v", "hunyuan_custom"]
+model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "vace_14B", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B", "fantasy", "ltxv_13B", "ltxv_13B_distilled", "hunyuan", "hunyuan_i2v", "hunyuan_custom"]
 model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B", "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "vace_14B" : "Vace_14B","recam_1.3B": "recammaster_1.3B", "flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B",
@@ -1609,10 +1610,10 @@ def get_model_name(model_filename, description_container = [""]):
         description = "The Fantasy Speaking model corresponds to the original Wan image 2 video model combined with the Fantasy Speaking extension to process an audio Input."
     elif "ltxv_0.9.7_13B_dev" in model_filename:
         model_name = "LTX Video 0.9.7"
-        description = "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).It is recommended to keep the number of steps to 30 or you will need to update the file 'ltxv_video/configs/ltxv-13b-0.9.7-dev.yaml'.The LTX Video model expects very long prompt, so don't hesitate to use the Prompt Enhancer."
+        description = "LTX Video is a fast model that can be used to generate long videos (up to 260 frames). It is recommended to keep the number of steps at 30, or you will need to update the file 'ltxv_video/configs/ltxv-13b-0.9.7-dev.yaml'. The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer."
     elif "ltxv_0.9.7_13B_distilled" in model_filename:
-        model_name = "LTX Video 0.9.7 distilled"
-        description = "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).This is the distilled / fast version. The LTX Video model expects very long prompt, so don't hesitate to use the Prompt Enhancer."
+        model_name = "LTX Video 0.9.7 Distilled"
+        description = "LTX Video is a fast model that can be used to generate long videos (up to 260 frames). This distilled version is very fast and retains a high level of quality. The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer."
     elif "hunyuan_video_720" in model_filename:
         model_name = "Hunyuan Video text2video 720p"
         description = "Probably the best text 2 video model available."
@@ -2065,13 +2066,14 @@ def load_wan_model(model_filename, quantizeTransformer = False, dtype = torch.bf
         pipe["text_encoder_2"] = wan_model.clip.model
     return wan_model, pipe
 
-def load_ltxv_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
+def load_ltxv_model(model_filename, loras_filenames, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
    filename = model_filename[-1]
    print(f"Loading '{filename}' model...")
    from ltx_video.ltxv import LTXV
 
    ltxv_model = LTXV(
        model_filepath = model_filename,
+       loras_filepath = loras_filenames,
        text_encoder_filepath = get_ltxv_text_encoder_filename(text_encoder_quantization),
        dtype = dtype,
        # quantizeTransformer = quantizeTransformer,
@@ -2119,21 +2121,28 @@ def get_transformer_model(model):
 
 def load_models(model_filename):
-    global transformer_filename
+    global transformer_filename, transformer_loras_filenames
     model_family = get_model_family(model_filename)
     perc_reserved_mem_max = args.perc_reserved_mem_max
-    model_filelist = get_dependent_models(model_filename, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy) + [model_filename]
+    new_transformer_loras_filenames = None
+    dependent_models = get_dependent_models(model_filename, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy)
+    if "_lora" in model_filename:
+        new_transformer_loras_filenames = [model_filename]
+        model_filelist = dependent_models
+    else:
+        model_filelist = dependent_models + [model_filename]
     for filename in model_filelist:
         download_models(filename)
     transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
     VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
     mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
     transformer_filename = None
+    transformer_loras_filenames = None
     new_transformer_filename = model_filelist[-1]
     if model_family == "wan" :
         wan_model, pipe = load_wan_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     elif model_family == "ltxv":
-        wan_model, pipe = load_ltxv_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
+        wan_model, pipe = load_ltxv_model(model_filelist, new_transformer_loras_filenames, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     elif model_family == "hunyuan":
         wan_model, pipe = load_hunyuan_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     else:
@@ -2171,6 +2180,7 @@ def load_models(model_filename):
     if len(args.gpu) > 0:
         torch.set_default_device(args.gpu)
     transformer_filename = new_transformer_filename
+    transformer_loras_filenames = new_transformer_loras_filenames
     return wan_model, offloadobj, pipe["transformer"]
 
 if not "P" in preload_model_policy:
@@ -2735,7 +2745,7 @@ def generate_video(
 
     loras = state["loras"]
-    if len(loras) > 0:
+    if len(loras) > 0 or transformer_loras_filenames != None:
        def is_float(element: any) -> bool:
            if element is None:
                return False
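The next hunk is where the remembered distillation LoRA actually reaches the transformer: it is appended to whatever LoRAs the user selected, with a fixed multiplier of 1.0, before `offload.load_loras_into_model` is called. Below is a minimal sketch of that merging step under stated assumptions: `load_loras_into_model` here is a stand-in stub (the real call in wgp.py takes further arguments such as `preprocess_sd` and `pinnedLora`), and the example LoRA path is invented.

```python
# Illustrative sketch of the LoRA merging done in generate_video; the loader
# below is a stub standing in for offload.load_loras_into_model.
from typing import List, Optional

def load_loras_into_model(model, lora_paths: List[str], multipliers: List[float]) -> None:
    # Stub: the real loader injects each LoRA into the transformer weights.
    for path, mult in zip(lora_paths, multipliers):
        print(f"applying {path} with strength {mult}")

def apply_loras(model,
                user_loras: List[str], user_multipliers: List[float],
                transformer_loras: Optional[List[str]]) -> None:
    loras_selected = list(user_loras)
    mult_choices = list(user_multipliers)
    if transformer_loras is not None:
        # The distilled-model LoRA is always applied on top of the user's
        # LoRAs, at full strength (multiplier 1.0).
        loras_selected += transformer_loras
        mult_choices += [1.0] * len(transformer_loras)
    load_loras_into_model(model, loras_selected, mult_choices)

apply_loras(model=None,
            user_loras=["loras/my_style.safetensors"], user_multipliers=[0.8],
            transformer_loras=["ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"])
```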
@@ -2770,6 +2780,9 @@ def generate_video(
        loras_selected = [ lora for lora in loras if os.path.basename(lora) in activated_loras]
        pinnedLora = profile !=5  #False # # #
        split_linear_modules_map = getattr(trans,"split_linear_modules_map", None)
+       if transformer_loras_filenames != None:
+           loras_selected += transformer_loras_filenames
+           list_mult_choices_nums.append(1.)
        offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=get_loras_preprocessor(trans, model_filename), pinnedLora=pinnedLora, split_linear_modules_map = split_linear_modules_map)
        errors = trans._loras_errors
        if len(errors) > 0:
@@ -4335,6 +4348,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
    trigger_refresh_input_type = gr.Text(interactive= False, visible= False)
    diffusion_forcing = "diffusion_forcing" in model_filename
    ltxv = "ltxv" in model_filename
+   ltxv_distilled = "ltxv" in model_filename and "distilled" in model_filename
    recammaster = "recam" in model_filename
    vace = "Vace" in model_filename
    phantom = "phantom" in model_filename
@@ -4550,7 +4564,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                    video_length = gr.Slider(5, 337, value=ui_defaults.get("video_length", 97), step=4, label="Number of frames (24 = 1s)", interactive= True)
                else:
                    video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
-           with gr.Row():
+           with gr.Row(visible = not ltxv_distilled) as inference_steps_row:
                num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps")
 
@@ -4764,7 +4778,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
        extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
                        prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row, speed_tab, quality_tab,
-                       sliding_window_tab, misc_tab, prompt_enhancer_row,
+                       sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row,
                        video_prompt_type_video_guide, video_prompt_type_image_refs] # show_advanced presets_column,
        if update_form:
            locals_dict = locals()
@@ -5626,7 +5640,7 @@ def create_demo():
    theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
 
    with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as main:
-       gr.Markdown("