From 71697bd7c56f5b3109a6f23c675c9d12bc814ea2 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Thu, 27 Mar 2025 16:49:05 +0100 Subject: [PATCH] Added Fun InP models support --- README.md | 11 +++- gradio_server.py | 152 ++++++++++++++++++++++++++++++--------------- wan/image2video.py | 21 ++++--- 3 files changed, 124 insertions(+), 60 deletions(-) diff --git a/README.md b/README.md index 298d9aa..2b96746 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,9 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid ## 🔥 Latest News!! -* Mar 20 2025: 👋 Good news ! Official support for RTX 50xx please check the installation instructions below. -* Mar 19 2025: 👋 Wan2.1GP v3.2: +* Mar 27 2025: 👋 Added support for the new Wan Fun InP models (image2video). The 14B Fun InP has probably better end image support but unfortunately existing loras do not work so well with it. The great novelty is the Fun InP image2 1.3B model : Image 2 Video is now accessible to even lower hardware configuration. It is not as good as the 14B models but very impressive for its size. You can choose any of those models in the Configuration tab. Many thanks to the VideoX-Fun team (https://github.com/aigc-apps/VideoX-Fun) +* Mar 26 2025: 👋 Good news ! Official support for RTX 50xx please check the installation instructions below. +* Mar 24 2025: 👋 Wan2.1GP v3.2: - Added Classifier-Free Guidance Zero Star. The video should match better the text prompt (especially with text2video) at no performance cost: many thanks to the **CFG Zero * Team:**\ Dont hesitate to give them a star if you appreciate the results: https://github.com/WeichenFan/CFG-Zero-star - Added back support for Pytorch compilation with Loras. It seems it had been broken for some time @@ -188,6 +189,10 @@ To run the image to video generator (in Low VRAM mode): ```bash python gradio_server.py --i2v ``` +To run the 1.3B Fun InP image to video generator (in Low VRAM mode): +```bash +python gradio_server.py --i2v-1-3B +``` To be able to input multiple images with the image to video generator: ```bash @@ -271,6 +276,8 @@ You can define multiple lines of macros. 
If there is only one macro line, the ap --t2v : launch the text to video generator (default defined in the configuration)\ --t2v-14B : launch the 14B model text to video generator\ --t2v-1-3B : launch the 1.3B model text to video generator\ +--i2v-14B : launch the 14B model image to video generator\ +--i2v-1-3B : launch the Fun InP 1.3B model image to video generator\ --quantize-transformer bool: (default True) : enable / disable on the fly transformer quantization\ --lora-dir path : Path of directory that contains Loras in diffusers / safetensor format\ --lora-preset preset : name of preset gile (without the extension) to preload diff --git a/gradio_server.py b/gradio_server.py index 37a65a1..47c9f75 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -205,6 +205,18 @@ def _parse_args(): action="store_true", help="text to video mode 1.3B model" ) + parser.add_argument( + "--i2v-1-3B", + action="store_true", + help="Fun InP image to video mode 1.3B model" + ) + + parser.add_argument( + "--i2v-14B", + action="store_true", + help="image to video mode 14B model" + ) + parser.add_argument( "--compile", @@ -257,22 +269,21 @@ def get_lora_dir(i2v): root_lora_dir = "loras_i2v" if i2v else "loras" - if not i2v: - if "1.3B" in transformer_filename_t2v: - lora_dir_1_3B = os.path.join(root_lora_dir, "1.3B") - if os.path.isdir(lora_dir_1_3B ): - return lora_dir_1_3B - else: - lora_dir_14B = os.path.join(root_lora_dir, "14B") - if os.path.isdir(lora_dir_14B ): - return lora_dir_14B + if "1.3B" in (transformer_filename_i2v if i2v else transformer_filename_t2v) : + lora_dir_1_3B = os.path.join(root_lora_dir, "1.3B") + if os.path.isdir(lora_dir_1_3B ): + return lora_dir_1_3B + else: + lora_dir_14B = os.path.join(root_lora_dir, "14B") + if os.path.isdir(lora_dir_14B ): + return lora_dir_14B return root_lora_dir attention_modes_installed = get_attention_modes() attention_modes_supported = get_supported_attention_modes() args = _parse_args() args.flow_reverse = True - +# torch.backends.cuda.matmul.allow_fp16_accumulation = True lock_ui_attention = False lock_ui_transformer = False lock_ui_compile = False @@ -285,7 +296,7 @@ check_loras = args.check_loras ==1 advanced = args.advanced transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors"] -transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors"] +transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ] text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"] server_config_filename = "gradio_config.json" @@ -293,7 +304,7 @@ server_config_filename = "gradio_config.json" if not Path(server_config_filename).is_file(): server_config = {"attention_mode" : "auto", "transformer_filename": transformer_choices_t2v[0], - "transformer_filename_i2v": transformer_choices_i2v[1], ######## + "transformer_filename_i2v": transformer_choices_i2v[1], 
"text_encoder_filename" : text_encoder_choices[1], "save_path": os.path.join(os.getcwd(), "gradio_outputs"), "compile" : "", @@ -363,7 +374,7 @@ def get_default_settings(filename, i2v): return ui_defaults transformer_filename_t2v = server_config["transformer_filename"] -transformer_filename_i2v = server_config.get("transformer_filename_i2v", transformer_choices_i2v[1]) ######## +transformer_filename_i2v = server_config.get("transformer_filename_i2v", transformer_choices_i2v[1]) text_encoder_filename = server_config["text_encoder_filename"] attention_mode = server_config["attention_mode"] @@ -395,11 +406,22 @@ if args.t2v_14B: transformer_filename_t2v = transformer_choices_t2v[2] lock_ui_transformer = False +if args.i2v_14B: + use_image2video = True + if not "14B" in transformer_filename_i2v: + transformer_filename_i2v = transformer_choices_t2v[3] + lock_ui_transformer = False + if args.t2v_1_3B: transformer_filename_t2v = transformer_choices_t2v[0] use_image2video = False lock_ui_transformer = False +if args.i2v_1_3B: + transformer_filename_i2v = transformer_choices_i2v[4] + use_image2video = True + lock_ui_transformer = False + only_allow_edit_in_advanced = False lora_preselected_preset = args.lora_preset lora_preselected_preset_for_i2v = use_image2video @@ -701,7 +723,10 @@ def get_default_flow(filename, i2v): def get_model_name(model_filename): - if "image" in model_filename: + if "Fun" in model_filename: + model_name = "Fun InP image2video" + model_name += " 14B" if "14B" in model_filename else " 1.3B" + elif "image" in model_filename: model_name = "Wan2.1 image2video" model_name += " 720p" if "720p" in model_filename else " 480p" else: @@ -752,7 +777,7 @@ def apply_changes( state, global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets server_config = {"attention_mode" : attention_choice, "transformer_filename": transformer_choices_t2v[transformer_t2v_choice], - "transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice], ########## + "transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice], "text_encoder_filename" : text_encoder_choices[text_encoder_choice], "save_path" : save_path_choice, "compile" : compile_choice, @@ -859,9 +884,9 @@ def refresh_gallery(state, txt): prompts_max = state.get("prompts_max",0) prompt_no = state.get("prompt_no",0) if prompts_max >1 : - label = f"Current Prompt ({prompt_no+1}/{prompts_max})" + label = f"Prompt ({prompt_no+1}/{prompts_max}) of Video being Generated" else: - label = f"Current Prompt" + label = f"Prompt of Video being Generated" return gr.Gallery(selected_index=choice, value = file_list), gr.Text(visible= True, value=prompt, label=label) @@ -998,7 +1023,7 @@ def generate_video( if slg_switch == 0: slg_layers = None if image2video: - if "480p" in transformer_filename_i2v and width * height > 848*480: + if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480: gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P") return @@ -1007,10 +1032,9 @@ def generate_video( gr.Info(f"Resolution {resolution} not supported by image 2 video") return - else: - if "1.3B" in transformer_filename_t2v and width * height > 848*480: - gr.Info("You must use the 14B text to video model to generate videos with a resolution equivalent to 720P") - return + if "1.3B" in model_filename and width * height > 848*480: + gr.Info("You must use the 
14B model to generate videos with a resolution equivalent to 720P") + return offload.shared_state["_attention"] = attn @@ -1142,7 +1166,7 @@ def generate_video( if len(list_mult_choices_nums ) < len(loras_choices): list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) ) loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices] - pinnedLora = False #profile !=5 #False # # # + pinnedLora = profile !=5 #False # # # offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None) errors = trans._loras_errors if len(errors) > 0: @@ -1273,7 +1297,8 @@ def generate_video( slg_start = slg_start/100, slg_end = slg_end/100, cfg_star_switch = cfg_star_switch, - cfg_zero_step = cfg_zero_step, + cfg_zero_step = cfg_zero_step, + add_frames_for_end_image = not "Fun" in transformer_filename_i2v ) else: @@ -2111,7 +2136,7 @@ def generate_video_tab(image2video=False): [state], [output , abort_btn, generate_btn, onemore_btn, gen_info] ) - return loras_choices, lset_name, header, state + return loras_column, loras_choices, presets_column, lset_name, header, state def generate_configuration_tab(): state_dict = {} @@ -2138,7 +2163,10 @@ def generate_configuration_tab(): ("WAN 2.1 - 480p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 0), ("WAN 2.1 - 480p 14B Image to Video quantized to 8 bits (recommended) - the default engine but quantized", 1), ("WAN 2.1 - 720p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 2), - ("WAN 2.1 - 720p 14B Image to Video quantized to 8 bits (recommended) - the default engine but quantized", 3), + ("WAN 2.1 - 720p 14B Image to Video quantized to 8 bits - the default engine but quantized", 3), + ("WAN 2.1 - Fun InP 1.3B 16 bits - the small model for fast generations with low VRAM requirements", 4), + ("WAN 2.1 - Fun InP 14B 16 bits - Fun InP version in its original glory, offers a slightly better image quality but slower and requires more RAM", 5), + ("WAN 2.1 - Fun InP 14B quantized to 8 bits - quantized Fun InP version", 6), ], value= index, label="Transformer model for Image to Video", @@ -2274,7 +2302,7 @@ def generate_configuration_tab(): ) def generate_about_tab(): - gr.Markdown("
Waw2.1GP - Wan 2.1 model for the GPU Poor by DeepBeepMeep (GitHub)")
+    gr.Markdown("Wan2.1GP - Wan 2.1 model for the GPU Poor by DeepBeepMeep (GitHub)
") gr.Markdown("Original Wan 2.1 Model by Alibaba (GitHub)") gr.Markdown("Many thanks to:") gr.Markdown("- Cocktail Peanuts : QA and simple installation via Pinokio.computer") @@ -2293,29 +2321,51 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData): new_i2v = evt.index == 1 use_image2video = new_i2v - if new_t2v: - lora_model_filename = t2v_state["loras_model"] - if ("1.3B" in transformer_filename_t2v and not "1.3B" in lora_model_filename or "14B" in transformer_filename_t2v and not "14B" in lora_model_filename): + if new_t2v or new_i2v: + state = i2v_state if new_i2v else t2v_state + lora_model_filename = state["loras_model"] + model_filename = model_needed(new_i2v) + if ("1.3B" in model_filename and not "1.3B" in lora_model_filename or "14B" in model_filename and not "14B" in lora_model_filename): lora_dir = get_lora_dir(new_i2v) loras, loras_names, loras_presets, _, _, _, _ = setup_loras(new_i2v, None, lora_dir, lora_preselected_preset, None) - t2v_state["loras"] = loras - t2v_state["loras_names"] = loras_names - t2v_state["loras_presets"] = loras_presets - t2v_state["loras_model"] = transformer_filename_t2v + state["loras"] = loras + state["loras_names"] = loras_names + state["loras_presets"] = loras_presets + state["loras_model"] = model_filename - t2v_advanced = t2v_state["advanced"] + advanced = state["advanced"] new_loras_choices = [(name, str(i)) for i, name in enumerate(loras_names)] - lset_choices = [(preset, preset) for preset in loras_presets] + [(get_new_preset_msg(t2v_advanced), "")] - return [ - gr.Dropdown(choices=new_loras_choices, visible=len(loras_names)>0, value=[]), - gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(t2v_advanced), visible=len(loras_names)>0), - t2v_header, - gr.Dropdown(), - gr.Dropdown(), - i2v_header, - ] - return [gr.Dropdown(), gr.Dropdown(), t2v_header, - gr.Dropdown(), gr.Dropdown(), i2v_header] + lset_choices = [(preset, preset) for preset in loras_presets] + [(get_new_preset_msg(advanced), "")] + visible = len(loras_names)>0 + if new_t2v: + return [ + gr.Column(visible= visible), + gr.Dropdown(choices=new_loras_choices, visible=visible, value=[]), + gr.Column(visible= visible), + gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible), + t2v_header, + gr.Column(), + gr.Dropdown(), + gr.Column(), + gr.Dropdown(), + i2v_header, + ] + else: + return [ + gr.Column(), + gr.Dropdown(), + gr.Column(), + gr.Dropdown(), + t2v_header, + gr.Column(visible= visible), + gr.Dropdown(choices=new_loras_choices, visible=visible, value=[]), + gr.Column(visible= visible), + gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible), + i2v_header, + ] + + return [gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), t2v_header, + gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), i2v_header] def create_demo(): @@ -2336,7 +2386,7 @@ def create_demo(): } """ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo: - gr.Markdown("
Wan 2.1GP v3.2 by DeepBeepMeep (Updates)")
+        gr.Markdown("Wan 2.1GP v3.3 by DeepBeepMeep (Updates)
") gr.Markdown("Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !") with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False): @@ -2350,9 +2400,9 @@ def create_demo(): with gr.Tabs(selected="i2v" if use_image2video else "t2v") as main_tabs: with gr.Tab("Text To Video", id="t2v") as t2v_tab: - t2v_loras_choices, t2v_lset_name, t2v_header, t2v_state = generate_video_tab() + t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_state = generate_video_tab() with gr.Tab("Image To Video", id="i2v") as i2v_tab: - i2v_loras_choices, i2v_lset_name, i2v_header, i2v_state = generate_video_tab(True) + i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_state = generate_video_tab(True) if not args.lock_config: with gr.Tab("Configuration"): generate_configuration_tab() @@ -2362,8 +2412,8 @@ def create_demo(): fn=on_tab_select, inputs=[t2v_state, i2v_state], outputs=[ - t2v_loras_choices, t2v_lset_name, t2v_header, - i2v_loras_choices, i2v_lset_name, i2v_header + t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, + i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header ] ) return demo diff --git a/wan/image2video.py b/wan/image2video.py index 7a9285d..2ea4310 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -117,7 +117,8 @@ class WanI2V: logging.info(f"Creating WanModel from {model_filename}") from mmgp import offload - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False) + self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False) #forcedConfigPath= "ckpts/config2.json", + # offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors") self.model.eval().requires_grad_(False) if t5_fsdp or dit_fsdp or use_usp: @@ -169,6 +170,7 @@ class WanI2V: slg_end = 1.0, cfg_star_switch = True, cfg_zero_step = 5, + add_frames_for_end_image = True ): r""" Generates video frames from input image and text prompt using diffusion process. 
@@ -212,8 +214,9 @@ class WanI2V: if any_end_frame: any_end_frame = True img2 = TF.to_tensor(img2).sub_(0.5).div_(0.5).to(self.device) - frame_num +=1 - lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2) + if add_frames_for_end_image: + frame_num +=1 + lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2) h, w = img.shape[1:] aspect_ratio = h / w @@ -237,7 +240,11 @@ class WanI2V: msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device) if any_end_frame: msk[:, 1: -1] = 0 - msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:-1], torch.repeat_interleave(msk[:, -1:], repeats=4, dim=1) ], dim=1) + if add_frames_for_end_image: + msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:-1], torch.repeat_interleave(msk[:, -1:], repeats=4, dim=1) ], dim=1) + else: + msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1) + else: msk[:, 1:] = 0 msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1) @@ -283,7 +290,7 @@ class WanI2V: torch.zeros(3, frame_num-1, h, w, device="cpu", dtype= torch.bfloat16) ], dim=1).to(self.device) - lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame)[0] + lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0] y = torch.concat([msk, lat_y]) @contextmanager @@ -441,9 +448,9 @@ class WanI2V: if self.rank == 0: # x0 = [lat_y] - video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame)[0] + video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0] - if any_end_frame: + if any_end_frame and add_frames_for_end_image: # video[:, -1:] = img2_interpolated video = video[:, :-1]
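
Note for readers of this patch: the sketch below is not part of the diff. It condenses, in one place, the end-image handling that the new `add_frames_for_end_image` flag toggles across several hunks of `wan/image2video.py`: the stock Wan image2video checkpoints reserve an extra frame and widen both boundary frames of the temporal mask, while the Fun InP checkpoints keep the requested frame count and only widen the first frame. The helper name `build_i2v_frame_mask`, the constant `VAE_STRIDE_T`, the default `lat_frames` formula, and the toy sizes are assumptions for illustration, not code taken from the repository.

```python
# Illustrative sketch only: mirrors the mask / frame-count logic gated by
# add_frames_for_end_image in WanI2V.generate(). Names and sizes are assumed.
import torch

VAE_STRIDE_T = 4  # assumed temporal stride of the Wan VAE (vae_stride[0])


def build_i2v_frame_mask(frame_num, lat_h, lat_w, any_end_frame, add_frames_for_end_image):
    """Return (frame_num, lat_frames, msk) as prepared before VAE encoding."""
    if any_end_frame and add_frames_for_end_image:
        # Stock Wan i2v checkpoints: reserve an extra frame for the end image.
        frame_num += 1
        lat_frames = int((frame_num - 2) // VAE_STRIDE_T + 2)
    else:
        # Default latent frame count as computed elsewhere in image2video.py (assumed).
        lat_frames = int((frame_num - 1) // VAE_STRIDE_T + 1)

    # 1 marks conditioning frames (kept), 0 marks frames to be generated.
    msk = torch.ones(1, frame_num, lat_h, lat_w)
    if any_end_frame:
        msk[:, 1:-1] = 0
        if add_frames_for_end_image:
            # Both boundary frames are expanded to a full latent frame (x4).
            msk = torch.concat([
                torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1),
                msk[:, 1:-1],
                torch.repeat_interleave(msk[:, -1:], repeats=4, dim=1),
            ], dim=1)
        else:
            # Fun InP: the end image is just the regular last frame,
            # so only the first frame is expanded.
            msk = torch.concat([
                torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1),
                msk[:, 1:],
            ], dim=1)
    else:
        msk[:, 1:] = 0
        msk = torch.concat([
            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1),
            msk[:, 1:],
        ], dim=1)
    return frame_num, lat_frames, msk


if __name__ == "__main__":
    # 81 requested frames, tiny latent grid, with an end image supplied.
    for fun_inp in (False, True):
        n, lat_n, msk = build_i2v_frame_mask(81, 8, 8, any_end_frame=True,
                                             add_frames_for_end_image=not fun_inp)
        # The expanded mask always spans 4 * lat_frames temporal entries.
        assert msk.shape[1] == 4 * lat_n
        print(f"Fun InP={fun_inp}: frame_num={n}, lat_frames={lat_n}, mask frames={msk.shape[1]}")
```

In the patch itself the flag is derived directly from the checkpoint name in gradio_server.py, via `add_frames_for_end_image = not "Fun" in transformer_filename_i2v`, so the Fun InP models automatically take the branch without the extra frame.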