diff --git a/README.md b/README.md
index 7bee91e..50bf85a 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,14 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
 ## πŸ”₯ Latest News!!
-* Mar 18 2022: πŸ‘‹ Wan2.1GP v2.11: Added more command line parameters to prefill the generation settings + customizable output directory and choice of type of metadata for generated videos. Many thanks to Tophness for his contributionsgit. You will need one more *pip install -r requirements.txt* to reflect new dependencies\
+* Mar 18 2022: πŸ‘‹ Wan2.1GP v3.0:
+    - New Tab based interface: you can switch from i2v to t2v and vice versa without restarting the app
+    - Experimental Dual Frames mode for i2v: you can also specify an End frame. It doesn't always work, so you will need a few attempts.
+    - You can save default settings in the files *i2v_settings.json* and *t2v_settings.json* that will be used when launching the app (you can also specify the path to different settings files)
+    - Slight acceleration with loras
+    You will need one more *pip install -r requirements.txt*
+    Many thanks to *Tophness*, who created the framework (and did a big part of the work) for the multi-tab and saved-settings features
+* Mar 18 2022: πŸ‘‹ Wan2.1GP v2.11: Added more command line parameters to prefill the generation settings + customizable output directory and choice of type of metadata for generated videos. Many thanks to *Tophness* for his contributions. You will need one more *pip install -r requirements.txt* to reflect new dependencies\
 * Mar 18 2022: πŸ‘‹ Wan2.1GP v2.1: More Loras !: added support for 'Safetensors' and 'Replicate' Lora formats.\
 You will need to refresh the requirements with a *pip install -r requirements.txt*
 * Mar 17 2022: πŸ‘‹ Wan2.1GP v2.0: The Lora festival continues:
@@ -243,11 +250,12 @@ You can define multiple lines of macros.
If there is only one macro line, the ap --seed no : set default seed value\ --frames no : set the default number of frames to generate\ --steps no : set the default number of denoising steps\ ---res resolution : default resolution, choices=["480p", "720p", "823p", "1024p", "1280p"]\ --teacache speed multiplier: Tea cache speed multiplier, choices=["0", "1.5", "1.75", "2.0", "2.25", "2.5"]\ --slg : turn on skip layer guidance for improved quality\ --check-loras : filter loras that are incompatible (will take a few seconds while refreshing the lora list or while starting the app)\ ---advanced : turn on the advanced mode while launching the app +--advanced : turn on the advanced mode while launching the app\ +--i2v-settings : path to launch settings for i2v\ +--t2v-settings : path to launch settings for t2v ### Profiles (for power users only) You can choose between 5 profiles, but two are really relevant here : diff --git a/gradio_server.py b/gradio_server.py index bf02272..65b14ce 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -23,7 +23,7 @@ import asyncio from wan.utils import prompt_parser PROMPT_VARS_MAX = 10 -target_mmgp_version = "3.2.8" +target_mmgp_version = "3.3.0" from importlib.metadata import version mmgp_version = version("mmgp") if mmgp_version != target_mmgp_version: @@ -94,6 +94,20 @@ def _parse_args(): help="Lora preset to preload" ) + parser.add_argument( + "--i2v-settings", + type=str, + default="i2v_settings.json", + help="Path to settings file for i2v" + ) + + parser.add_argument( + "--t2v-settings", + type=str, + default="t2v_settings.json", + help="Path to settings file for t2v" + ) + # parser.add_argument( # "--lora-preset-i2v", # type=str, @@ -115,6 +129,27 @@ def _parse_args(): help="Verbose level" ) + parser.add_argument( + "--steps", + type=int, + default=0, + help="default denoising steps" + ) + + parser.add_argument( + "--frames", + type=int, + default=0, + help="default number of frames" + ) + + parser.add_argument( + "--seed", + type=int, + default=-1, + help="default generation seed" + ) + parser.add_argument( "--advanced", action="store_true", @@ -201,8 +236,16 @@ def _parse_args(): return args -def get_lora_dir(root_lora_dir): - if not use_image2video: +def get_lora_dir(i2v): + lora_dir =args.lora_dir + if i2v and len(lora_dir)==0: + lora_dir =args.lora_dir_i2v + if len(lora_dir) > 0: + return lora_dir + + root_lora_dir = "loras_i2v" if i2v else "loras" + + if not i2v: if "1.3B" in transformer_filename_t2v: lora_dir_1_3B = os.path.join(root_lora_dir, "1.3B") if os.path.isdir(lora_dir_1_3B ): @@ -256,19 +299,29 @@ else: text = reader.read() server_config = json.loads(text) -def get_defaults(): - global use_image2video, ui_defaults - defaults_filename = "i2v_defaults.json" if use_image2video else "t2v_defaults.json" +def get_settings_file_name(i2v): + return args.i2v_settings if i2v else args.t2v_settings + +def get_default_settings(filename, i2v): + def get_default_prompt(i2v): + if i2v: + return "Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field." + else: + return "A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. 
Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect." + + defaults_filename = get_settings_file_name(i2v) if not Path(defaults_filename).is_file(): ui_defaults = { - "prompts": "", + "prompts": get_default_prompt(i2v), "resolution": "832x480", "video_length": 81, + "image_prompt_type" : 0, "num_inference_steps": 30, "seed": -1, "repeat_generation": 1, + "multi_images_gen_type": 0, "guidance_scale": 5.0, - "flow_shift": 5.0, + "flow_shift": get_default_flow(filename, i2v), "negative_prompt": "", "activated_loras": [], "loras_multipliers": "", @@ -286,6 +339,17 @@ def get_defaults(): with open(defaults_filename, "r", encoding="utf-8") as f: ui_defaults = json.load(f) + default_seed = args.seed + if default_seed > -1: + ui_defaults["seed"] = default_seed + default_number_frames = args.frames + if default_number_frames > 0: + ui_defaults["video_length"] = default_number_frames + default_number_steps = args.steps + if default_number_steps > 0: + ui_defaults["num_inference_steps"] = default_number_steps + return ui_defaults + transformer_filename_t2v = server_config["transformer_filename"] transformer_filename_i2v = server_config.get("transformer_filename_i2v", transformer_choices_i2v[1]) ######## @@ -324,17 +388,9 @@ if args.t2v_1_3B: use_image2video = False lock_ui_transformer = False -get_defaults() only_allow_edit_in_advanced = False -lora_dir =args.lora_dir -if use_image2video and len(lora_dir)==0: - lora_dir =args.lora_dir_i2v -if len(lora_dir) ==0: - root_lora_dir = "loras_i2v" if use_image2video else "loras" -else: - root_lora_dir = lora_dir -lora_dir = get_lora_dir(root_lora_dir) lora_preselected_preset = args.lora_preset +lora_preselected_preset_for_i2v = use_image2video # if args.fast : #or args.fastest # transformer_filename_t2v = transformer_choices_t2v[2] # attention_mode="sage2" if "sage2" in attention_modes_supported else "sage" @@ -346,7 +402,8 @@ if args.compile: #args.fastest or compile="transformer" lock_ui_compile = True - +model_filename = "" +lora_model_filename = "" #attention_mode="sage" #attention_mode="sage2" #attention_mode="flash" @@ -441,6 +498,7 @@ def extract_preset(lset_name, loras): prompt ="" full_prompt ="" lset_name = sanitize_file_name(lset_name) + lora_dir = get_lora_dir(use_image2video) if not lset_name.endswith(".lset"): lset_name_filename = os.path.join(lora_dir, lset_name + ".lset" ) else: @@ -472,24 +530,20 @@ def extract_preset(lset_name, loras): full_prompt = lset.get("full_prompt", False) return loras_choices, loras_mult_choices, prompt, full_prompt, error -def get_default_prompt(i2v): - if i2v: - return "Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field." 
- else: - return "A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect." -def setup_loras(transformer, lora_dir, lora_preselected_preset, split_linear_modules_map = None): +def setup_loras(i2v, transformer, lora_dir, lora_preselected_preset, split_linear_modules_map = None): loras =[] loras_names = [] default_loras_choices = [] default_loras_multis_str = "" loras_presets = [] default_lora_preset = "" - default_prompt = "" + default_lora_preset_prompt = "" from pathlib import Path + lora_dir = get_lora_dir(i2v) if lora_dir != None : if not os.path.isdir(lora_dir): raise Exception("--lora-dir should be a path to a directory that contains Loras") @@ -505,7 +559,7 @@ def setup_loras(transformer, lora_dir, lora_preselected_preset, split_linear_mo dir_presets.sort() loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets] - if check_loras: + if transformer !=None: loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, preprocess_sd=preprocess_loras, split_linear_modules_map = split_linear_modules_map) #lora_multiplier, if len(loras) > 0: @@ -515,19 +569,17 @@ def setup_loras(transformer, lora_dir, lora_preselected_preset, split_linear_mo if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")): raise Exception(f"Unknown preset '{lora_preselected_preset}'") default_lora_preset = lora_preselected_preset - default_loras_choices, default_loras_multis_str, default_prompt, _ , error = extract_preset(default_lora_preset, loras) + default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, _ , error = extract_preset(default_lora_preset, loras) if len(error) > 0: print(error[:200]) - if len(default_prompt) == 0: - default_prompt = get_default_prompt(use_image2video) - return loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets + return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset def load_t2v_model(model_filename, value): cfg = WAN_CONFIGS['t2v-14B'] # cfg = WAN_CONFIGS['t2v-1.3B'] - print("load t2v model...") + print(f"Loading '{model_filename}' model...") wan_model = wan.WanT2V( config=cfg, @@ -547,9 +599,9 @@ def load_t2v_model(model_filename, value): def load_i2v_model(model_filename, value): + print(f"Loading '{model_filename}' model...") if value == '720P': - print("load 14B-720P i2v model...") cfg = WAN_CONFIGS['i2v-14B'] wan_model = wan.WanI2V( config=cfg, @@ -565,8 +617,7 @@ def load_i2v_model(model_filename, value): ) pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } # - if value == '480P': - print("load 14B-480P i2v model...") + elif value == '480P': cfg = WAN_CONFIGS['i2v-14B'] wan_model = 
wan.WanI2V( config=cfg, @@ -582,11 +633,16 @@ def load_i2v_model(model_filename, value): ) pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } # - + else: + raise Exception("Model i2v {value} not supported") return wan_model, pipe -def load_models(i2v, lora_dir, lora_preselected_preset): - model_filename = transformer_filename_i2v if i2v else transformer_filename_t2v +def model_needed(i2v): + return transformer_filename_i2v if i2v else transformer_filename_t2v + +def load_models(i2v): + global model_filename + model_filename = model_needed(i2v) download_models(model_filename, text_encoder_filename) if i2v: res720P = "720p" in model_filename @@ -599,20 +655,14 @@ def load_models(i2v, lora_dir, lora_preselected_preset): elif profile == 3: kwargs["budgets"] = { "*" : "70%" } offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", **kwargs) - loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = setup_loras(pipe["transformer"], lora_dir, lora_preselected_preset, None) - if "activated_loras" in ui_defaults: - lora_filenames = [os.path.basename(lora_path) for lora_path in loras] - activated_indices = [] - for lora_file in ui_defaults["activated_loras"]: - try: - idx = lora_filenames.index(lora_file) - activated_indices.append(str(idx)) - except ValueError: - print(f"Warning: Lora file {lora_file} from config not found in loras directory") - ui_defaults["activated_loras"] = activated_indices - return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets -wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset ) + return wan_model, offloadobj, pipe["transformer"] + +wan_model, offloadobj, transformer = load_models(use_image2video) +if check_loras: + setup_loras(use_image2video, transformer, get_lora_dir(use_image2video), "", None) + exit() + gen_in_progress = False def get_auto_attention(): @@ -621,20 +671,25 @@ def get_auto_attention(): return attn return "sdpa" -def get_default_flow(model_filename): - return 3.0 if "480p" in model_filename else 5.0 +def get_default_flow(filename, i2v): + return 7.0 if "480p" in filename and i2v else 5.0 + + +def get_model_name(model_filename): + if "image" in model_filename: + model_name = "Wan2.1 image2video" + model_name += " 720p" if "720p" in model_filename else " 480p" + else: + model_name = "Wan2.1 text2video" + model_name += " 14B" if "14B" in model_filename else " 1.3B" + + return model_name def generate_header(model_filename, compile, attention_mode): - header = "

" - if "image" in model_filename: - model_name = "Wan2.1 image2video" - model_name += "720p" if "720p" in model_filename else "480p" - else: - model_name = "Wan2.1 text2video" - model_name += "14B" if "14B" in model_filename else "1.3B" + model_name = get_model_name(model_filename) header += model_name header += " (attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() ) @@ -666,7 +721,7 @@ def apply_changes( state, if gen_in_progress: yield "
Unable to change config when a generation is in progress
" return - global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets + global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets server_config = {"attention_mode" : attention_choice, "transformer_filename": transformer_choices_t2v[transformer_t2v_choice], "transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice], ########## @@ -715,51 +770,13 @@ def apply_changes( state, vae_config = server_config["vae_config"] boost = server_config["boost"] if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice"] for change in changes ): - if "attention_mode" in changes: - pass - + pass else: - wan_model = None - offloadobj.release() - offloadobj = None - yield "
Please wait while the new configuration is being applied
" - lora_dir = get_lora_dir(root_lora_dir) - wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset ) + state["_reload_needed"] = True yield "
The new configuration has been successfully applied
" - # return "
New Config file created. Please restart the Gradio Server
" - -def update_defaults(state, num_inference_steps,flow_shift, lset_name , loras_choices): - if "config_changes" not in state: - return get_default_flow("") - changes = state["config_changes"] - server_config = state["config_new"] - old_server_config = state["config_old"] - t2v_changed = False - if not use_image2video: - old_is_14B = "14B" in server_config["transformer_filename"] - new_is_14B = "14B" in old_server_config["transformer_filename"] - - trans_file = server_config["transformer_filename"] - t2v_changed = old_is_14B != new_is_14B - # num_inference_steps, flow_shift = get_default_flow(trans_file) - else: - old_is_720P = "720P" in server_config["transformer_filename_i2v"] - new_is_720P = "720P" in old_server_config["transformer_filename_i2v"] - trans_file = server_config["transformer_filename_i2v"] - if old_is_720P != new_is_720P: - num_inference_steps, flow_shift = get_default_flow(trans_file) - - header = generate_header(trans_file, server_config["compile"], server_config["attention_mode"] ) - new_loras_choices = [ (loras_name, str(i)) for i,loras_name in enumerate(loras_names)] - lset_choices = [ (preset, preset) for preset in loras_presets] - lset_choices.append( (get_new_preset_msg(advanced), "")) - if t2v_changed: - return num_inference_steps, flow_shift, header, gr.Dropdown(choices=lset_choices, value= ""), gr.Dropdown(choices=new_loras_choices, value= []) - else: - return num_inference_steps, flow_shift, header, lset_name , loras_choices from moviepy.editor import ImageSequenceClip @@ -891,6 +908,7 @@ def generate_video( loras_choices, loras_mult_choices, image_to_continue, + image_to_end, video_to_continue, max_frames, RIFLEx_setting, @@ -904,15 +922,18 @@ def generate_video( ): - global use_image2video, wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets - if use_image2video != image2video: + global wan_model, offloadobj + reload_needed = state.get("_reload_needed", False) + file_model_needed = model_needed(image2video) + if file_model_needed != model_filename or reload_needed: if offloadobj is not None: offloadobj.release() offloadobj = None wan_model = None - lora_dir = get_lora_dir(root_lora_dir) - wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(image2video, lora_dir, lora_preselected_preset) - use_image2video = image2video + yield f"Loading model {get_model_name(file_model_needed)}..." 
+ wan_model, offloadobj, trans = load_models(image2video) + yield f"Model loaded" + state["_reload_needed"] = False from PIL import Image import numpy as np @@ -991,23 +1012,41 @@ def generate_video( image_to_continue = [ tup[0] for tup in image_to_continue ] else: image_to_continue = [image_to_continue] + if image_to_end != None: + if isinstance(image_to_end , list): + image_to_end = [ tup[0] for tup in image_to_end ] + else: + image_to_end = [image_to_end ] + if len(image_to_continue) != len(image_to_end): + gr.Info("The number of start and end images should be the same ") + return if multi_images_gen_type == 0: new_prompts = [] new_image_to_continue = [] + new_image_to_end = [] for i in range(len(prompts) * len(image_to_continue) ): new_prompts.append( prompts[ i % len(prompts)] ) new_image_to_continue.append(image_to_continue[i // len(prompts)] ) + if image_to_end != None: + new_image_to_end.append(image_to_end[i // len(prompts)] ) prompts = new_prompts image_to_continue = new_image_to_continue + if image_to_end != None: + image_to_end = new_image_to_end else: if len(prompts) >= len(image_to_continue): if len(prompts) % len(image_to_continue) !=0: raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images") rep = len(prompts) // len(image_to_continue) new_image_to_continue = [] + new_image_to_end = [] for i, _ in enumerate(prompts): new_image_to_continue.append(image_to_continue[i//rep] ) + if image_to_end != None: + new_image_to_end.append(image_to_end[i//rep] ) image_to_continue = new_image_to_continue + if image_to_end != None: + image_to_end = new_image_to_end else: if len(image_to_continue) % len(prompts) !=0: raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts") @@ -1026,7 +1065,7 @@ def generate_video( else: input_image_or_video_path = None - + loras = state["loras"] if len(loras) > 0: def is_float(element: any) -> bool: if element is None: @@ -1060,7 +1099,7 @@ def generate_video( if len(list_mult_choices_nums ) < len(loras_choices): list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) ) loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices] - pinnedLora = False # profile !=5 + pinnedLora = profile !=5 #False # # # offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None) errors = trans._loras_errors if len(errors) > 0: @@ -1078,7 +1117,7 @@ def generate_video( # VAE Tiling device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576 - joint_pass = boost ==1 + joint_pass = boost ==1 #and profile != 1 and profile != 3 # TeaCache trans.enable_teacache = tea_cache > 0 if trans.enable_teacache: @@ -1161,6 +1200,7 @@ def generate_video( samples = wan_model.generate( prompt, image_to_continue[no].convert('RGB'), + image_to_end[no].convert('RGB') if image_to_end != None else None, frame_num=(video_length // 4)* 4 + 1, max_area=MAX_AREA_CONFIGS[resolution], shift=flow_shift, @@ -1321,9 +1361,11 @@ def validate_save_lset(lset_name): def cancel_lset(): return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False) -def save_lset(state, lset_name, loras_choices, 
loras_mult_choices, prompt, save_lset_prompt_cbox): - global loras_presets - + + +def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox): + loras_presets = state["loras_presets"] + loras = state["loras"] if state.get("validate_success",0) == 0: pass if len(lset_name) == 0 or lset_name == get_new_preset_msg(True) or lset_name == get_new_preset_msg(False): @@ -1345,7 +1387,7 @@ def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_ lset_name_filename = lset_name + ".lset" - full_lset_name_filename = os.path.join(lora_dir, lset_name_filename) + full_lset_name_filename = os.path.join(get_lora_dir(use_image2video), lset_name_filename) with open(full_lset_name_filename, "w", encoding="utf-8") as writer: writer.write(json.dumps(lset, indent=4)) @@ -1357,12 +1399,12 @@ def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_ loras_presets.append(Path(Path(lset_name_filename).parts[-1]).stem ) lset_choices = [ ( preset, preset) for preset in loras_presets ] lset_choices.append( (get_new_preset_msg(), "")) - + state["loras_presets"] = loras_presets return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False) -def delete_lset(lset_name): - global loras_presets - lset_name_filename = os.path.join(lora_dir, sanitize_file_name(lset_name) + ".lset" ) +def delete_lset(state, lset_name): + loras_presets = state["loras_presets"] + lset_name_filename = os.path.join( get_lora_dir(use_image2video), sanitize_file_name(lset_name) + ".lset" ) if len(lset_name) > 0 and lset_name != get_new_preset_msg(True) and lset_name != get_new_preset_msg(False): if not os.path.isfile(lset_name_filename): raise gr.Error(f"Preset '{lset_name}' not found ") @@ -1374,15 +1416,21 @@ def delete_lset(lset_name): pos = len(loras_presets) gr.Info(f"Choose a Preset to delete") + state["loras_presets"] = loras_presets + lset_choices = [ (preset, preset) for preset in loras_presets] lset_choices.append((get_new_preset_msg(), "")) return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False) -def refresh_lora_list(lset_name, loras_choices): - global loras,loras_names, loras_presets, wan_model +def refresh_lora_list(state, lset_name, loras_choices): + loras_names = state["loras_names"] prev_lora_names_selected = [ loras_names[int(i)] for i in loras_choices] - loras, loras_names, _, _, _, _, loras_presets = setup_loras(wan_model.model, lora_dir, lora_preselected_preset, None) + loras, loras_names, loras_presets, _, _, _, _ = setup_loras(use_image2video, None, get_lora_dir(use_image2video), lora_preselected_preset, None) + state["loras"] = loras + state["loras_names"] = loras_names + state["loras_presets"] = loras_presets + gc.collect() new_loras_choices = [ (loras_name, str(i)) for i,loras_name in enumerate(loras_names)] new_loras_dict = { loras_name: str(i) for i,loras_name in enumerate(loras_names) } @@ -1393,7 +1441,7 @@ def refresh_lora_list(lset_name, loras_choices): lora_names_selected.append(lora_id) lset_choices = [ (preset, preset) for preset in loras_presets] - lset_choices.append((get_new_preset_msg(advanced), "")) + lset_choices.append((get_new_preset_msg( state["advanced"]), "")) if 
lset_name in loras_presets: pos = loras_presets.index(lset_name) else: @@ -1410,13 +1458,14 @@ def refresh_lora_list(lset_name, loras_choices): return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Dropdown(choices=new_loras_choices, value= lora_names_selected) -def apply_lset(state, lset_name, loras_choices, loras_mult_choices, prompt): +def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_mult_choices, prompt): state["apply_success"] = 0 if len(lset_name) == 0 or lset_name== get_new_preset_msg(True) or lset_name== get_new_preset_msg(False): gr.Info("Please choose a preset in the list or create one") else: + loras = state["loras"] loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(lset_name, loras) if len(error) > 0: gr.Info(error) @@ -1430,12 +1479,12 @@ def apply_lset(state, lset_name, loras_choices, loras_mult_choices, prompt): prompt = preset_prompt + '\n' + prompt gr.Info(f"Lora Preset '{lset_name}' has been applied") state["apply_success"] = 1 - state["wizard_prompt"] = 0 + wizard_prompt_activated = "on" - return loras_choices, loras_mult_choices, prompt + return wizard_prompt_activated, loras_choices, loras_mult_choices, prompt -def extract_prompt_from_wizard(state, prompt, wizard_prompt, allow_null_values, *args): +def extract_prompt_from_wizard(state, variables_names, prompt, wizard_prompt, allow_null_values, *args): prompts = wizard_prompt.replace("\r" ,"").split("\n") @@ -1443,7 +1492,7 @@ def extract_prompt_from_wizard(state, prompt, wizard_prompt, allow_null_values, macro_already_written = False for prompt in prompts: if not macro_already_written and not prompt.startswith("#") and "{" in prompt and "}" in prompt: - variables = state["variables"] + variables = variables_names.split("\n") values = args[:len(variables)] macro = "! 
" for i, (variable, value) in enumerate(zip(variables, values)): @@ -1464,14 +1513,14 @@ def extract_prompt_from_wizard(state, prompt, wizard_prompt, allow_null_values, prompt = "\n".join(new_prompts) return prompt, "" -def validate_wizard_prompt(state, prompt, wizard_prompt, *args): +def validate_wizard_prompt(state, wizard_prompt_activated, wizard_variables_names, prompt, wizard_prompt, *args): state["validate_success"] = 0 - if state.get("wizard_prompt",0) != 1: + if wizard_prompt_activated != "on": state["validate_success"] = 1 return prompt - prompt, errors = extract_prompt_from_wizard(state, prompt, wizard_prompt, False, *args) + prompt, errors = extract_prompt_from_wizard(state, wizard_variables_names, prompt, wizard_prompt, False, *args) if len(errors) > 0: gr.Info(errors) return prompt @@ -1480,16 +1529,16 @@ def validate_wizard_prompt(state, prompt, wizard_prompt, *args): return prompt -def fill_prompt_from_wizard(state, prompt, wizard_prompt, *args): +def fill_prompt_from_wizard(state, wizard_prompt_activated, wizard_variables_names, prompt, wizard_prompt, *args): - if state.get("wizard_prompt",0) == 1: - prompt, errors = extract_prompt_from_wizard(state, prompt, wizard_prompt, True, *args) + if wizard_prompt_activated == "on": + prompt, errors = extract_prompt_from_wizard(state, wizard_variables_names, prompt, wizard_prompt, True, *args) if len(errors) > 0: gr.Info(errors) - state["wizard_prompt"] = 0 + wizard_prompt_activated = "off" - return gr.Textbox(visible= True, value =prompt) , gr.Textbox(visible= False), gr.Column(visible = True), *[gr.Column(visible = False)] * 2, *[gr.Textbox(visible= False)] * PROMPT_VARS_MAX + return wizard_prompt_activated, "", gr.Textbox(visible= True, value =prompt) , gr.Textbox(visible= False), gr.Column(visible = True), *[gr.Column(visible = False)] * 2, *[gr.Textbox(visible= False)] * PROMPT_VARS_MAX def extract_wizard_prompt(prompt): variables = [] @@ -1523,22 +1572,22 @@ def extract_wizard_prompt(prompt): wizard_prompt = "\n".join(new_prompts) return wizard_prompt, variables, values, errors -def fill_wizard_prompt(state, prompt, wizard_prompt): +def fill_wizard_prompt(state, wizard_prompt_activated, prompt, wizard_prompt): def get_hidden_textboxes(num = PROMPT_VARS_MAX ): return [gr.Textbox(value="", visible=False)] * num hidden_column = gr.Column(visible = False) visible_column = gr.Column(visible = True) - if advanced or state.get("apply_success") != 1: - return prompt, wizard_prompt, gr.Column(), gr.Column(), hidden_column, *get_hidden_textboxes() + wizard_prompt_activated = "off" + if state["advanced"] or state.get("apply_success") != 1: + return wizard_prompt_activated, gr.Text(), prompt, wizard_prompt, gr.Column(), gr.Column(), hidden_column, *get_hidden_textboxes() prompt_parts= [] - state["wizard_prompt"] = 0 wizard_prompt, variables, values, errors = extract_wizard_prompt(prompt) if len(errors) > 0: gr.Info( errors ) - return gr.Textbox(prompt, visible=True), gr.Textbox(wizard_prompt, visible=False), visible_column, *[hidden_column] * 2, *get_hidden_textboxes() + return wizard_prompt_activated, "", gr.Textbox(prompt, visible=True), gr.Textbox(wizard_prompt, visible=False), visible_column, *[hidden_column] * 2, *get_hidden_textboxes() for variable in variables: value = values.get(variable, "") @@ -1547,37 +1596,77 @@ def fill_wizard_prompt(state, prompt, wizard_prompt): prompt_parts += get_hidden_textboxes(PROMPT_VARS_MAX-len(prompt_parts)) - state["variables"] = variables - state["wizard_prompt"] = 1 + variables_names= 
"\n".join(variables) + wizard_prompt_activated = "on" - return gr.Textbox(prompt, visible = False), gr.Textbox(wizard_prompt, visible = True), hidden_column, visible_column, visible_column if any_macro else hidden_column, *prompt_parts + return wizard_prompt_activated, variables_names, gr.Textbox(prompt, visible = False), gr.Textbox(wizard_prompt, visible = True), hidden_column, visible_column, visible_column if any_macro else hidden_column, *prompt_parts -def switch_prompt_type(state, prompt, wizard_prompt, *prompt_vars): - if advanced: - return fill_prompt_from_wizard(state, prompt, wizard_prompt, *prompt_vars) +def switch_prompt_type(state, wizard_prompt_activated_var, wizard_variables_names, prompt, wizard_prompt, *prompt_vars): + if state["advanced"]: + return fill_prompt_from_wizard(state, wizard_prompt_activated_var, wizard_variables_names, prompt, wizard_prompt, *prompt_vars) else: state["apply_success"] = 1 - return fill_wizard_prompt(state, prompt, wizard_prompt) + return fill_wizard_prompt(state, wizard_prompt_activated_var, prompt, wizard_prompt) visible= False -def switch_advanced(new_advanced, lset_name): - global advanced - advanced= new_advanced +def switch_advanced(state, new_advanced, lset_name): + state["advanced"] = new_advanced + loras_presets = state["loras_presets"] lset_choices = [ (preset, preset) for preset in loras_presets] - lset_choices.append((get_new_preset_msg(advanced), "")) + lset_choices.append((get_new_preset_msg(new_advanced), "")) if lset_name== get_new_preset_msg(True) or lset_name== get_new_preset_msg(False) or lset_name=="": - lset_name = get_new_preset_msg(advanced) + lset_name = get_new_preset_msg(new_advanced) if only_allow_edit_in_advanced: return gr.Row(visible=new_advanced), gr.Row(visible=new_advanced), gr.Button(visible=new_advanced), gr.Row(visible= not new_advanced), gr.Dropdown(choices=lset_choices, value= lset_name) else: return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name) +def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, + loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc): + + if state.get("validate_success",0) != 1: + return + + loras_choices + loras = state["loras"] + activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ] + + ui_defaults = { + "prompts": prompt, + "image_prompt_type": image_prompt_type, + "resolution": resolution, + "video_length": video_length, + "num_inference_steps": num_inference_steps, + "seed": seed, + "repeat_generation": repeat_generation, + "multi_images_gen_type": multi_images_gen_type, + "guidance_scale": guidance_scale, + "flow_shift": flow_shift, + "negative_prompt": negative_prompt, + "activated_loras": activated_loras, + "loras_multipliers": loras_mult_choices, + "tea_cache": tea_cache_setting, + "tea_cache_start_step_perc": tea_cache_start_step_perc, + "RIFLEx_setting": RIFLEx_setting, + "slg_switch": slg_switch, + "slg_layers": slg_layers, + "slg_start_perc": slg_start_perc, + "slg_end_perc": slg_end_perc + } + defaults_filename = get_settings_file_name(use_image2video) + + with open(defaults_filename, "w", encoding="utf-8") as f: + json.dump(ui_defaults, f, indent=4) + + gr.Info("New Default Settings saved") + def download_loras(): 
from huggingface_hub import snapshot_download - yield "Please wait while the Loras are being downloaded", *[gr.Column(visible=False)] * 2 + yield gr.Row(visible=True), "Please wait while the Loras are being downloaded", *[gr.Column(visible=False)] * 2 + lora_dir = get_lora_dir(True) log_path = os.path.join(lora_dir, "log.txt") - if not os.path.isfile(log_path) or True: + if not os.path.isfile(log_path): import shutil tmp_path = os.path.join(lora_dir, "tmp_lora_dowload") import glob @@ -1592,7 +1681,7 @@ def download_loras(): os.remove(tmp_path) except: pass - yield "Loras have been completely downloaded", *[gr.Column(visible=True)] * 2 + yield gr.Row(visible=True), "Loras have been completely downloaded", *[gr.Column(visible=True)] * 2 from datetime import datetime dt = datetime.today().strftime('%Y-%m-%d') @@ -1601,32 +1690,66 @@ def download_loras(): return def generate_video_tab(image2video=False): + filename = transformer_filename_i2v if image2video else transformer_filename_t2v + ui_defaults= get_default_settings(filename, image2video) + state_dict = {} - gr.Markdown("

Wan 2.1GP v2.1 by DeepBeepMeep (Updates / Original by Alibaba)

") - gr.Markdown("Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !") - - with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False): - gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance :") - gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM") - gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM") - gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM") - gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear") - gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.") + + state_dict["advanced"] = advanced + state_dict["loras_model"] = filename + preset_to_load = lora_preselected_preset if lora_preselected_preset_for_i2v == image2video else "" + + loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset = setup_loras(image2video, None, get_lora_dir(image2video), preset_to_load, None) + + state_dict["loras"] = loras + state_dict["loras_presets"] = loras_presets + state_dict["loras_names"] = loras_names + + launch_prompt = "" + launch_preset = "" + launch_loras = [] + launch_multis_str = "" + + if len(default_lora_preset) > 0 and image2video == lora_preselected_preset_for_i2v: + launch_preset = default_lora_preset + launch_prompt = default_lora_preset_prompt + launch_loras = default_loras_choices + launch_multis_str = default_loras_multis_str + + if len(launch_prompt) == 0: + launch_prompt = ui_defaults["prompts"] + if len(launch_loras) == 0: + activated_loras = ui_defaults["activated_loras"] + launch_multis_str = ui_defaults["loras_multipliers"] + + if len(activated_loras) > 0: + lora_filenames = [os.path.basename(lora_path) for lora_path in loras] + activated_indices = [] + for lora_file in ui_defaults["activated_loras"]: + try: + idx = lora_filenames.index(lora_file) + activated_indices.append(str(idx)) + except ValueError: + print(f"Warning: Lora file {lora_file} from config not found in loras directory") + launch_loras = activated_indices + + + header = gr.Markdown(generate_header(model_filename, compile, attention_mode)) with gr.Row(visible= image2video): - with gr.Row(scale =3): - gr.Markdown("Wan2GP's Lora Festival ! Press the following button to download i2v Remade Loras collection (and bonuses Loras). Dont't forget first to make a backup of your Loras just in case.") + with gr.Row(scale =2): + gr.Markdown("Wan2GP's Lora Festival ! 
Press the following button to download i2v Remade Loras collection (and bonuses Loras).") with gr.Row(scale =1): download_loras_btn = gr.Button("---> Let the Lora's Festival Start !", scale =1) - with gr.Row(visible= image2video): + with gr.Row(scale =1): + gr.Markdown("") + with gr.Row(visible= image2video) as download_status_row: download_status = gr.Markdown() - model_filename = transformer_filename_i2v if image2video else transformer_filename_t2v - header = gr.Markdown(generate_header(model_filename, compile, attention_mode)) with gr.Row(): with gr.Column(): with gr.Row(visible= len(loras)>0) as presets_column: lset_choices = [ (preset, preset) for preset in loras_presets ] + [(get_new_preset_msg(advanced), "")] with gr.Column(scale=6): - lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=True, choices= lset_choices, value=default_lora_preset) + lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=True, choices= lset_choices, value=launch_preset) with gr.Column(scale=1): with gr.Row(height=17): apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1) @@ -1647,35 +1770,62 @@ def generate_video_tab(image2video=False): delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1) cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False) video_to_continue = gr.Video(label= "Video to continue", visible= image2video and False) ####### + + image_prompt_type = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =ui_defaults["image_prompt_type"], label="Location", show_label= False, scale= 3, visible=image2video) + if args.multiple_images: image_to_continue = gr.Gallery( - label="Images as a starting point for new videos", type ="pil", #file_types= "image", + label="Images as starting points for new videos", type ="pil", #file_types= "image", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image2video) else: image_to_continue = gr.Image(label= "Image as a starting point for a new video", type ="pil", visible=image2video) + + if args.multiple_images: + image_to_end = gr.Gallery( + label="Images as ending points for new videos", type ="pil", #file_types= "image", + columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=False) + else: + image_to_end = gr.Image(label= "Last Image for a new video", type ="pil", visible= False) + + def switch_image_prompt_type_radio(image_prompt_type_radio): + if args.multiple_images: + return gr.Gallery(visible = (image_prompt_type_radio == 1) ) + else: + return gr.Image(visible = (image_prompt_type_radio == 1) ) + + image_prompt_type.change(fn=switch_image_prompt_type_radio, inputs=[image_prompt_type], outputs=[image_to_end]) + + advanced_prompt = advanced prompt_vars=[] + if advanced_prompt: default_wizard_prompt, variables, values= None, None, None - else: - default_wizard_prompt, variables, values, errors = extract_wizard_prompt(default_prompt) + else: + default_wizard_prompt, variables, values, errors = extract_wizard_prompt(launch_prompt) advanced_prompt = len(errors) > 0 with gr.Column(visible= advanced_prompt) as prompt_column_advanced: - prompt = gr.Textbox( visible= advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments, ! 
lines = macros)", value=ui_defaults["prompts"], lines=3) + prompt = gr.Textbox( visible= advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments, ! lines = macros)", value=launch_prompt, lines=3) with gr.Column(visible=not advanced_prompt and len(variables) > 0) as prompt_column_wizard_vars: gr.Markdown("Please fill the following input fields to adapt automatically the Prompt:") + wizard_prompt_activated = "off" + wizard_variables = "" with gr.Row(): if not advanced_prompt: for variable in variables: value = values.get(variable, "") prompt_vars.append(gr.Textbox( placeholder=variable, min_width=80, show_label= False, info= variable, visible= True, value= "\n".join(value) )) - state_dict["wizard_prompt"] = 1 - state_dict["variables"] = variables + wizard_prompt_activated = "on" + if len(variables) > 0: + wizard_variables = "\n".join(variables) for _ in range( PROMPT_VARS_MAX - len(prompt_vars)): prompt_vars.append(gr.Textbox(visible= False, min_width=80, show_label= False)) + with gr.Column(not advanced_prompt) as prompt_column_wizard: wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3) + wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False) + wizard_variables_var = gr.Text(wizard_variables, visible = False) state = gr.State(state_dict) with gr.Row(): if image2video: @@ -1721,11 +1871,11 @@ def generate_video_tab(image2video=False): seed = gr.Slider(-1, 999999999, value=ui_defaults["seed"], step=1, label="Seed (-1 for random)") with gr.Row(): repeat_generation = gr.Slider(1, 25.0, value=ui_defaults["repeat_generation"], step=1, label="Default Number of Generated Videos per Prompt") - multi_images_gen_type = gr.Dropdown( + multi_images_gen_type = gr.Dropdown( value=ui_defaults["multi_images_gen_type"], choices=[ - ("Generate every combination of images and texts prompts", 0), + ("Generate every combination of images and texts", 0), ("Match images and text prompts", 1), - ], visible= args.multiple_images, label= "Multiple Images as Prompts" + ], visible= args.multiple_images, label= "Multiple Images as Texts Prompts" ) with gr.Row(): guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults["guidance_scale"], step=0.5, label="Guidance Scale", visible=True) @@ -1739,11 +1889,11 @@ def generate_video_tab(image2video=False): choices=[ (lora_name, str(i) ) for i, lora_name in enumerate(loras_names) ], - value= ui_defaults["activated_loras"], + value= launch_loras, multiselect= True, label="Activated Loras" ) - loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, line that starts with # are ignored", value=ui_defaults["loras_multipliers"]) + loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, line that starts with # are ignored", value=launch_multis_str) with gr.Row(): gr.Markdown("Tea Cache accelerates by skipping intelligently some steps, the more steps are skipped the lower the quality of the video (Tea Cache consumes also VRAM)") with gr.Row(): @@ -1797,8 +1947,10 @@ def generate_video_tab(image2video=False): with gr.Row(): slg_start_perc = gr.Slider(0, 100, value=ui_defaults["slg_start_perc"], step=1, label="Denoising Steps % start") slg_end_perc = gr.Slider(0, 100, value=ui_defaults["slg_end_perc"], step=1, label="Denoising Steps % end") - 
show_advanced.change(fn=switch_advanced, inputs=[show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then( - fn=switch_prompt_type, inputs = [state, prompt, wizard_prompt, *prompt_vars], outputs = [prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]) + with gr.Row(): + save_settings_btn = gr.Button("Set Settings as Default") + show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then( + fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]) with gr.Column(): gen_status = gr.Text(label="Status", interactive= False) output = gr.Gallery( @@ -1808,25 +1960,31 @@ def generate_video_tab(image2video=False): onemore_btn = gr.Button("One More Please !", visible= False) abort_btn = gr.Button("Abort") gen_info = gr.Text(label="Current prompt", visible= False , interactive= False) + + + save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then( + save_settings, inputs = [state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, + loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, + slg_start_perc, slg_end_perc ], outputs = []) save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop]) - confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then( + confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then( save_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop]) delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ]) - confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ]) + confirm_delete_lset_btn.click(delete_lset, inputs=[state, lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ]) cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ]) - apply_lset_btn.click(apply_lset, inputs=[state, lset_name,loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt]).then( 
- fn = fill_wizard_prompt, inputs = [state, prompt, wizard_prompt], outputs = [ prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars] + apply_lset_btn.click(apply_lset, inputs=[state, wizard_prompt_activated_var, lset_name,loras_choices, loras_mult_choices, prompt], outputs=[wizard_prompt_activated_var, loras_choices, loras_mult_choices, prompt]).then( + fn = fill_wizard_prompt, inputs = [state, wizard_prompt_activated_var, prompt, wizard_prompt], outputs = [ wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars] ) - refresh_lora_btn.click(refresh_lora_list, inputs=[lset_name,loras_choices], outputs=[lset_name, loras_choices]) - refresh_lora_btn2.click(refresh_lora_list, inputs=[lset_name,loras_choices], outputs=[lset_name, loras_choices]) - download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status, presets_column, loras_column]).then(fn=refresh_lora_list, inputs=[lset_name,loras_choices], outputs=[lset_name, loras_choices]) + refresh_lora_btn.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) + refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) + download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status, presets_column, loras_column]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) gen_status.change(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info] ) abort_btn.click(abort_generation,state,abort_btn ) output.select(select_video, state, None ) onemore_btn.click(fn=one_more_video,inputs=[state], outputs= [state]) generate_btn.click(fn=prepare_generate_video,inputs=[], outputs= [generate_btn, onemore_btn] ).then( - fn=validate_wizard_prompt, inputs =[state, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt] + fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt] ).then( fn=generate_video, inputs=[ @@ -1846,6 +2004,7 @@ def generate_video_tab(image2video=False): loras_choices, loras_mult_choices, image_to_continue, + image_to_end, video_to_continue, max_frames, RIFLEx_setting, @@ -1862,12 +2021,12 @@ def generate_video_tab(image2video=False): [state], [output , abort_btn, generate_btn, onemore_btn, gen_info] ) - return loras_choices, lset_name + return loras_choices, lset_name, header, state -def generate_settings_tab(): +def generate_configuration_tab(): state_dict = {} state = gr.State(state_dict) - gr.Markdown("For the changes to be effective you will need to restart the gradio_server. Some choices below may be locked if the app has been launched by specifying a config preset.") + gr.Markdown("Please click Apply Changes at the bottom so that the changes are effective. 
Some choices below may be locked if the app has been launched by specifying a config preset.") with gr.Column(): index = transformer_choices_t2v.index(transformer_filename_t2v) index = 0 if index ==0 else index @@ -1880,7 +2039,7 @@ def generate_settings_tab(): value= index, label="Transformer model for Text to Video", interactive= not lock_ui_transformer, - visible=not use_image2video + visible=True #not use_image2video ) index = transformer_choices_i2v.index(transformer_filename_i2v) index = 0 if index ==0 else index @@ -1894,7 +2053,7 @@ def generate_settings_tab(): value= index, label="Transformer model for Image to Video", interactive= not lock_ui_transformer, - visible = use_image2video, + visible = True # use_image2video, ) index = text_encoder_choices.index(text_encoder_filename) index = 0 if index ==0 else index @@ -2004,44 +2163,52 @@ def generate_settings_tab(): boost_choice, ], outputs= msg - ) -def on_tab_select(evt: gr.SelectData): - global use_image2video, wan_model, offloadobj, loras, loras_names, loras_presets, root_lora_dir, lora_dir - new_i2v = evt.index == 1 - if new_i2v == use_image2video: - return [gr.Dropdown(visible=not new_i2v), gr.Dropdown(visible=not new_i2v), - gr.Dropdown(visible=new_i2v), gr.Dropdown(visible=new_i2v), gr.Markdown()] - if wan_model is not None: - if offloadobj is not None: - offloadobj.release() - offloadobj = None - wan_model = None - gc.collect() - torch.cuda.empty_cache() - use_image2video = new_i2v - get_defaults() - lora_dir = args.lora_dir_i2v if new_i2v and args.lora_dir_i2v else args.lora_dir if not new_i2v and args.lora_dir else "loras_i2v" if new_i2v else "loras" - root_lora_dir = lora_dir - lora_dir = get_lora_dir(root_lora_dir) - wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset) - new_loras_choices = [(name, str(i)) for i, name in enumerate(loras_names)] - lset_choices = [(preset, preset) for preset in loras_presets] + [(get_new_preset_msg(advanced), "")] - model_filename = transformer_filename_i2v if use_image2video else transformer_filename_t2v + ) + +def generate_about_tab(): + gr.Markdown("

Wan2.1GP - Wan 2.1 model for the GPU Poor by DeepBeepMeep (GitHub)

") + gr.Markdown("Original Wan 2.1 Model by Alibaba (GitHub)") + gr.Markdown("Many thanks to:") + gr.Markdown("- Cocktail Peanuts : QA and simple installation via Pinokio.computer") + gr.Markdown("- AmericanPresidentJimmyCarter : added original support for Skip Layer Guidance") + gr.Markdown("- Tophness : created multi tabs framework") + gr.Markdown("- Remade_AI : for creating their awesome Loras collection") + + +def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData): + global lora_model_filename, use_image2video + + t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode) + i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode) + + new_t2v = evt.index == 0 + new_i2v = evt.index == 1 + use_image2video = new_i2v + + if new_t2v: + lora_model_filename = t2v_state["loras_model"] + if ("1.3B" in transformer_filename_t2v and not "1.3B" in lora_model_filename or "14B" in transformer_filename_t2v and not "14B" in lora_model_filename): + lora_dir = get_lora_dir(new_i2v) + loras, loras_names, loras_presets, _, _, _, _ = setup_loras(new_i2v, None, lora_dir, lora_preselected_preset, None) + t2v_state["loras"] = loras + t2v_state["loras_names"] = loras_names + t2v_state["loras_presets"] = loras_presets + t2v_state["loras_model"] = transformer_filename_t2v + + t2v_advanced = t2v_state["advanced"] + new_loras_choices = [(name, str(i)) for i, name in enumerate(loras_names)] + lset_choices = [(preset, preset) for preset in loras_presets] + [(get_new_preset_msg(t2v_advanced), "")] + return [ + gr.Dropdown(choices=new_loras_choices, visible=len(loras_names)>0, value=[]), + gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(t2v_advanced), visible=len(loras_names)>0), + t2v_header, + gr.Dropdown(), + gr.Dropdown(), + i2v_header, + ] + return [gr.Dropdown(), gr.Dropdown(), t2v_header, + gr.Dropdown(), gr.Dropdown(), i2v_header] - if new_i2v: - return [ - gr.Dropdown(visible=False), - gr.Dropdown(visible=False), - gr.Dropdown(choices=new_loras_choices, value=ui_defaults['activated_loras'], visible=True), - gr.Dropdown(choices=lset_choices, value=default_lora_preset, visible=True) - ] - else: - return [ - gr.Dropdown(choices=new_loras_choices, value=ui_defaults['activated_loras'], visible=True), - gr.Dropdown(choices=lset_choices, value=default_lora_preset, visible=True), - gr.Dropdown(visible=False), - gr.Dropdown(visible=False) - ] def create_demo(): css = """ @@ -2061,19 +2228,33 @@ def create_demo(): } """ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo: - with gr.Tabs(selected=default_ui) as main_tabs: + gr.Markdown("

Wan 2.1GP v3.0 by DeepBeepMeep (Updates)

") + gr.Markdown("Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !") + + with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False): + gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance :") + gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM") + gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM") + gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM") + gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear") + gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.") + + + with gr.Tabs(selected="i2v" if use_image2video else "t2v") as main_tabs: with gr.Tab("Text To Video", id="t2v") as t2v_tab: - t2v_loras_choices, t2v_lset_name = generate_video_tab() + t2v_loras_choices, t2v_lset_name, t2v_header, t2v_state = generate_video_tab() with gr.Tab("Image To Video", id="i2v") as i2v_tab: - i2v_loras_choices, i2v_lset_name = generate_video_tab(True) - with gr.Tab("Settings"): - generate_settings_tab() + i2v_loras_choices, i2v_lset_name, i2v_header, i2v_state = generate_video_tab(True) + with gr.Tab("Configuration"): + generate_configuration_tab() + with gr.Tab("About"): + generate_about_tab() main_tabs.select( fn=on_tab_select, - inputs=[], + inputs=[t2v_state, i2v_state], outputs=[ - t2v_loras_choices, t2v_lset_name, - i2v_loras_choices, i2v_lset_name + t2v_loras_choices, t2v_lset_name, t2v_header, + i2v_loras_choices, i2v_lset_name, i2v_header ] ) return demo diff --git a/requirements.txt b/requirements.txt index a9c9baf..b97e545 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,6 +16,6 @@ gradio>=5.0.0 numpy>=1.23.5,<2 einops moviepy==1.0.3 -mmgp==3.2.8 +mmgp==3.3.0 peft==0.14.0 mutagen \ No newline at end of file diff --git a/wan/image2video.py b/wan/image2video.py index 088f51e..506fe90 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -26,6 +26,82 @@ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from wan.modules.posemb_layers import get_rotary_pos_embed +from PIL import Image + +def lanczos(samples, width, height): + images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] + images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] + images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images] + result = torch.stack(images) + return result.to(samples.device, samples.dtype) + +def bislerp(samples, width, height): + def slerp(b1, b2, r): + '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. 
NxC''' + + c = b1.shape[-1] + + #norms + b1_norms = torch.norm(b1, dim=-1, keepdim=True) + b2_norms = torch.norm(b2, dim=-1, keepdim=True) + + #normalize + b1_normalized = b1 / b1_norms + b2_normalized = b2 / b2_norms + + #zero when norms are zero + b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0 + b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0 + + #slerp + dot = (b1_normalized*b2_normalized).sum(1) + omega = torch.acos(dot) + so = torch.sin(omega) + + #technically not mathematically correct, but more pleasing? + res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized + res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c) + + #edge cases for same or polar opposites + res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5] + res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1] + return res + + +def common_upscale(samples, width, height, upscale_method, crop): + orig_shape = tuple(samples.shape) + if len(orig_shape) > 4: + samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1]) + samples = samples.movedim(2, 1) + samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1]) + if crop == "center": + old_width = samples.shape[-1] + old_height = samples.shape[-2] + old_aspect = old_width / old_height + new_aspect = width / height + x = 0 + y = 0 + if old_aspect > new_aspect: + x = round((old_width - old_width * (new_aspect / old_aspect)) / 2) + elif old_aspect < new_aspect: + y = round((old_height - old_height * (old_aspect / new_aspect)) / 2) + s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2) + else: + s = samples + + if upscale_method == "bislerp": + out = bislerp(s, width, height) + elif upscale_method == "lanczos": + out = lanczos(s, width, height) + else: + out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) + + if len(orig_shape) == 4: + return out + + out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width)) + return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width)) + class WanI2V: def __init__( @@ -63,8 +139,8 @@ class WanI2V: Enable distribution strategy of USP. t5_cpu (`bool`, *optional*, defaults to False): Whether to place T5 model on CPU. Only works without t5_fsdp. - init_on_cpu (`bool`, *optional*, defaults to True): Enable initializing Transformer Model on CPU. Only works without FSDP or USP. 
+ init_on_cpu (`bool`, *optional*, defaults to True): """ self.device = torch.device(f"cuda:{device_id}") self.config = config @@ -134,6 +210,7 @@ class WanI2V: def generate(self, input_prompt, img, + img2 = None, max_area=720 * 1280, frame_num=81, shift=5.0, @@ -188,8 +265,14 @@ class WanI2V: - W: Frame width from max_area) """ img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) + lat_frames = int((frame_num - 1) // self.vae_stride[0] + 1) + any_end_frame = img2 !=None + if any_end_frame: + any_end_frame = True + img2 = TF.to_tensor(img2).sub_(0.5).div_(0.5).to(self.device) + frame_num +=1 + lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2) - F = frame_num h, w = img.shape[1:] aspect_ratio = h / w lat_h = round( @@ -201,28 +284,21 @@ class WanI2V: h = lat_h * self.vae_stride[1] w = lat_w * self.vae_stride[2] - max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // ( - self.patch_size[1] * self.patch_size[2]) + max_seq_len = lat_frames * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2]) max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size seed = seed if seed >= 0 else random.randint(0, sys.maxsize) seed_g = torch.Generator(device=self.device) seed_g.manual_seed(seed) - noise = torch.randn( - 16, - int((frame_num - 1)/4 + 1), #21, - lat_h, - lat_w, - dtype=torch.float32, - generator=seed_g, - device=self.device) + noise = torch.randn(16, lat_frames, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device) msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device) - msk[:, 1:] = 0 - msk = torch.concat([ - torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] - ], - dim=1) + if any_end_frame: + msk[:, 1: -1] = 0 + msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:-1], torch.repeat_interleave(msk[:, -1:], repeats=4, dim=1) ], dim=1) + else: + msk[:, 1:] = 0 + msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1) msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) msk = msk.transpose(1, 2)[0] @@ -242,7 +318,6 @@ class WanI2V: context = [t.to(self.device) for t in context] context_null = [t.to(self.device) for t in context_null] - # self.clip.model.to(self.device) clip_context = self.clip.visual([img[:, None, :, :]]) if offload_model: self.clip.model.cpu() @@ -250,16 +325,24 @@ class WanI2V: from mmgp import offload offload.last_offload_obj.unload_all() - enc= torch.concat([ - torch.nn.functional.interpolate( - img[None].cpu(), size=(h, w), mode='bicubic').transpose( - 0, 1).to(torch.bfloat16), - torch.zeros(3, frame_num-1, h, w, device="cpu", dtype= torch.bfloat16) - ], dim=1).to(self.device) - # enc = None + if any_end_frame: + img_interpolated = torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16) + img2_interpolated = torch.nn.functional.interpolate(img2[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16) + mean2 = 0 + enc= torch.concat([ + img_interpolated, + torch.full( (3, frame_num-2, h, w), mean2, device="cpu", dtype= torch.bfloat16), + img2_interpolated, + ], dim=1).to(self.device) + else: + enc= torch.concat([ + torch.nn.functional.interpolate( + img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16), + torch.zeros(3, frame_num-1, h, w, device="cpu", dtype= torch.bfloat16) + ], dim=1).to(self.device) - y = self.vae.encode([enc], VAE_tile_size)[0] - y = torch.concat([msk, y]) + lat_y = 
self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame)[0] + y = torch.concat([msk, lat_y]) @contextmanager def noop_no_sync(): @@ -293,7 +376,7 @@ class WanI2V: # sample videos latent = noise - freqs = get_rotary_pos_embed(frame_num, h, w, enable_RIFLEx= enable_RIFLEx ) + freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx) arg_c = { 'context': [context[0]], @@ -344,8 +427,6 @@ class WanI2V: timestep = torch.stack(timestep).to(self.device) if joint_pass: - # if slg_layers is not None: - # raise ValueError('Can not use SLG and joint-pass') noise_pred_cond, noise_pred_uncond = self.model( latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both) if self._interrupt: @@ -393,17 +474,25 @@ class WanI2V: del timestep if callback is not None: - callback(i, latent) + callback(i, latent) - x0 = [latent.to(self.device)] + x0 = [latent.to(self.device, dtype=torch.bfloat16)] if offload_model: self.model.cpu() torch.cuda.empty_cache() if self.rank == 0: - videos = self.vae.decode(x0, VAE_tile_size) + # x0 = [lat_y] + video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame)[0] + + if any_end_frame: + # video[:, -1:] = img2_interpolated + video = video[:, :-1] + + else: + video = None del noise, latent del sample_scheduler @@ -413,4 +502,4 @@ class WanI2V: if dist.is_initialized(): dist.barrier() - return videos[0] if self.rank == 0 else None + return video diff --git a/wan/modules/posemb_layers.py b/wan/modules/posemb_layers.py index 86b8078..3f2444a 100644 --- a/wan/modules/posemb_layers.py +++ b/wan/modules/posemb_layers.py @@ -429,11 +429,10 @@ def get_1d_rotary_pos_embed( ) # complex64 # [S, D/2] return freqs_cis -def get_rotary_pos_embed(video_length, height, width, enable_RIFLEx = False): +def get_rotary_pos_embed(latents_size, enable_RIFLEx = False): target_ndim = 3 ndim = 5 - 2 - latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8] patch_size = [1, 2, 2] if isinstance(patch_size, int): assert all(s % patch_size == 0 for s in latents_size), ( @@ -468,7 +467,7 @@ def get_rotary_pos_embed(video_length, height, width, enable_RIFLEx = False): theta=10000, use_real=True, theta_rescale_factor=1, - L_test = (video_length - 1) // 4 + 1, + L_test = latents_size[0], enable_riflex = enable_RIFLEx ) return (freqs_cos, freqs_sin) \ No newline at end of file diff --git a/wan/modules/vae.py b/wan/modules/vae.py index fac20dd..e47e74c 100644 --- a/wan/modules/vae.py +++ b/wan/modules/vae.py @@ -530,26 +530,37 @@ class WanVAE_(nn.Module): x_recon = self.decode(z) return x_recon, mu, log_var - def encode(self, x, scale = None): + def encode(self, x, scale = None, any_end_frame = False): self.clear_cache() ## cache t = x.shape[2] - iter_ = 1 + (t - 1) // 4 + if any_end_frame: + iter_ = 2 + (t - 2) // 4 + else: + iter_ = 1 + (t - 1) // 4 ## ε―ΉencodeθΎ“ε…₯ηš„xοΌŒζŒ‰ζ—Άι—΄ζ‹†εˆ†δΈΊ1、4、4、4.... 
+ out_list = [] for i in range(iter_): self._enc_conv_idx = [0] if i == 0: - out = self.encoder( + out_list.append(self.encoder( x[:, :, :1, :, :], feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) + feat_idx=self._enc_conv_idx)) + elif any_end_frame and i== iter_ -1: + out_list.append(self.encoder( + x[:, :, -1:, :, :], + feat_cache= None, + feat_idx=self._enc_conv_idx)) else: - out_ = self.encoder( + out_list.append(self.encoder( x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) - out = torch.cat([out, out_], 2) + feat_idx=self._enc_conv_idx)) + self.clear_cache() + out = torch.cat(out_list, 2) + out_list = None mu, log_var = self.conv1(out).chunk(2, dim=1) if scale != None: @@ -558,11 +569,10 @@ class WanVAE_(nn.Module): 1, self.z_dim, 1, 1, 1) else: mu = (mu - scale[0]) * scale[1] - self.clear_cache() return mu - def decode(self, z, scale=None): + def decode(self, z, scale=None, any_end_frame = False): self.clear_cache() # z: [b,c,t,h,w] if scale != None: @@ -573,20 +583,26 @@ class WanVAE_(nn.Module): z = z / scale[1] + scale[0] iter_ = z.shape[2] x = self.conv2(z) + out_list = [] for i in range(iter_): self._conv_idx = [0] if i == 0: - out = self.decoder( + out_list.append(self.decoder( x[:, :, i:i + 1, :, :], feat_cache=self._feat_map, - feat_idx=self._conv_idx) + feat_idx=self._conv_idx)) + elif any_end_frame and i==iter_-1: + out_list.append(self.decoder( + x[:, :, -1:, :, :], + feat_cache=None , + feat_idx=self._conv_idx)) else: - out_ = self.decoder( + out_list.append(self.decoder( x[:, :, i:i + 1, :, :], feat_cache=self._feat_map, - feat_idx=self._conv_idx) - out = torch.cat([out, out_], 2) + feat_idx=self._conv_idx)) self.clear_cache() + out = torch.cat(out_list, 2) return out def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: @@ -601,7 +617,7 @@ class WanVAE_(nn.Module): b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) return b - def spatial_tiled_decode(self, z, scale, tile_size): + def spatial_tiled_decode(self, z, scale, tile_size, any_end_frame= False): tile_sample_min_size = tile_size tile_latent_min_size = int(tile_sample_min_size / 8) tile_overlap_factor = 0.25 @@ -626,7 +642,7 @@ class WanVAE_(nn.Module): row = [] for j in range(0, z.shape[-1], overlap_size): tile = z[:, :, :, i: i + tile_latent_min_size, j: j + tile_latent_min_size] - decoded = self.decode(tile) + decoded = self.decode(tile, any_end_frame= any_end_frame) row.append(decoded) rows.append(row) result_rows = [] @@ -645,7 +661,7 @@ class WanVAE_(nn.Module): return torch.cat(result_rows, dim=-2) - def spatial_tiled_encode(self, x, scale, tile_size) : + def spatial_tiled_encode(self, x, scale, tile_size, any_end_frame = False) : tile_sample_min_size = tile_size tile_latent_min_size = int(tile_sample_min_size / 8) tile_overlap_factor = 0.25 @@ -660,7 +676,7 @@ class WanVAE_(nn.Module): row = [] for j in range(0, x.shape[-1], overlap_size): tile = x[:, :, :, i: i + tile_sample_min_size, j: j + tile_sample_min_size] - tile = self.encode(tile) + tile = self.encode(tile, any_end_frame= any_end_frame) row.append(tile) rows.append(row) result_rows = [] @@ -764,18 +780,18 @@ class WanVAE: z_dim=z_dim, ).eval().requires_grad_(False).to(device) - def encode(self, videos, tile_size = 256): + def encode(self, videos, tile_size = 256, any_end_frame = False): """ videos: A list of videos each with shape [C, T, H, W]. 
""" if tile_size > 0: - return [ self.model.spatial_tiled_encode(u.unsqueeze(0), self.scale, tile_size).float().squeeze(0) for u in videos ] + return [ self.model.spatial_tiled_encode(u.unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] else: - return [ self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos ] + return [ self.model.encode(u.unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] - def decode(self, zs, tile_size): + def decode(self, zs, tile_size, any_end_frame = False): if tile_size > 0: - return [ self.model.spatial_tiled_decode(u.unsqueeze(0), self.scale, tile_size).float().clamp_(-1, 1).squeeze(0) for u in zs ] + return [ self.model.spatial_tiled_decode(u.unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] else: - return [ self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0) for u in zs ] + return [ self.model.decode(u.unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] diff --git a/wan/text2video.py b/wan/text2video.py index 385cdfd..4682a4a 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -241,7 +241,7 @@ class WanT2V: # sample videos latents = noise - freqs = get_rotary_pos_embed(frame_num, size[1], size[0], enable_RIFLEx= enable_RIFLEx) + freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx) arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self} arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self} arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}