From fbf2793b2e5e3419a6bfa82690aac1b9a43c0cff Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Mon, 7 Apr 2025 02:15:13 +0200
Subject: [PATCH] New model selection logic / removed tabs

---
 README.md                      |   27 +-
 wan/utils/vace_preprocessor.py |    6 +-
 gradio_server.py => wgp.py     | 1729 +++++++++++++++-----------------
 3 files changed, 838 insertions(+), 924 deletions(-)
 rename gradio_server.py => wgp.py (69%)

diff --git a/README.md b/README.md
index 175c0cc..c0380c0 100644
--- a/README.md
+++ b/README.md
@@ -174,35 +174,35 @@ pip install -e .
 To run the text to video generator (in Low VRAM mode):
 ```bash
-python gradio_server.py
+python wgp.py
 #or
-python gradio_server.py --t2v #launch the default text 2 video model
+python wgp.py --t2v #launch the default text 2 video model
 #or
-python gradio_server.py --t2v-14B #for the 14B model
+python wgp.py --t2v-14B #for the 14B model
 #or
-python gradio_server.py --t2v-1-3B #for the 1.3B model
+python wgp.py --t2v-1-3B #for the 1.3B model
 ```

 To run the image to video generator (in Low VRAM mode):
 ```bash
-python gradio_server.py --i2v
+python wgp.py --i2v
 ```

 To run the 1.3B Fun InP image to video generator (in Low VRAM mode):
 ```bash
-python gradio_server.py --i2v-1-3B
+python wgp.py --i2v-1-3B
 ```

 To be able to input multiple images with the image to video generator:
 ```bash
-python gradio_server.py --i2v --multiple-images
+python wgp.py --i2v --multiple-images
 ```

 Within the application you can configure which video generator will be launched without specifying a command line switch.

 To run the application while loading the diffusion model entirely in VRAM (slightly faster, but requires 24 GB of VRAM for an 8-bit quantized 14B model):
 ```bash
-python gradio_server.py --profile 3
+python wgp.py --profile 3
 ```

 **Troubleshooting**:\
@@ -215,7 +215,7 @@ Therefore you may have no choice but to fallback to sdpa attention, to do so:
 or
 - Launch the application this way:
 ```bash
-python gradio_server.py --attention sdpa
+python wgp.py --attention sdpa
 ```

 ### Loras support
@@ -249,7 +249,7 @@ Each preset, is a file with ".lset" extension stored in the loras directory and
 Last but not least you can pre-activate the corresponding Loras and prefill a prompt (comments only or full prompt) by specifying a preset when launching the gradio server:
 ```bash
-python gradio_server.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder
+python wgp.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder
 ```

 You will find prebuilt Loras on https://civitai.com/ or you will be able to build them with tools such as kohya or onetrainer.
@@ -274,7 +274,7 @@ You can define multiple lines of macros. If there is only one macro line, the ap
 Vace is a ControlNet 1.3B text2video model that, on top of a text prompt, allows you to provide visual hints to guide the generation. It can do more things than image2video, although it is not as good for just starting a video with an image because it is only a 1.3B model (in fact 3B) versus 14B and it is not specialized for start frames. However, with Vace you can inject people or objects into the scene, animate a person, perform inpainting or outpainting, continue a video, ...

-First you need to switch the t2v model to Vace 1.3 in the Configuration Tab. Please note that Vace works well for the moment only with videos up to 5s (81 frames).
+First you need to select the Vace 1.3B model in the drop-down box at the top (see the launch example below).
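+
+Alternatively, the new `--vace-1-3B` command line switch that this version adds to the argument parser should let you launch the app with the Vace model already selected:
+```bash
+python wgp.py --vace-1-3B
+```
+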
+Please note that Vace works well for the moment only with videos up to 5s (81 frames).

 Besides the usual Text Prompt, three new types of visual hints can be provided (and combined!):
 - reference Images: use this to inject people or objects in the video. You can select multiple reference Images. The integration of the image is more efficient if the background is replaced with plain white. You can do that with your preferred background remover or use the built-in background remover by checking the box *Remove background*
@@ -296,6 +296,8 @@ There are lots of possible combinations. Some of them require to prepare some ma
 Vace provides on its GitHub (https://github.com/ali-vilab/VACE/tree/main/vace/gradios) annotator / preprocessor Gradio tools that can help you build some of these materials depending on the task you want to achieve.

 There is also a guide that describes the various combinations of hints (https://github.com/ali-vilab/VACE/blob/main/UserGuide.md). Good luck!
+
+It seems you will get better results if you turn on "Skip Layer Guidance" with its default configuration.
 ### Command line parameters for Gradio Server
 --i2v : launch the image to video generator\
 --t2v : launch the text to video generator (default defined in the configuration)\
@@ -303,6 +305,7 @@
 --t2v-1-3B : launch the 1.3B model text to video generator\
 --i2v-14B : launch the 14B model image to video generator\
 --i2v-1-3B : launch the Fun InP 1.3B model image to video generator\
+--vace-1-3B : launch the Vace ControlNet 1.3B model text to video generator\
 --quantize-transformer bool: (default True) : enable / disable on the fly transformer quantization\
 --lora-dir path : Path of directory that contains Loras in diffusers / safetensor format\
 --lora-preset preset : name of preset file (without the extension) to preload
@@ -324,8 +327,6 @@ There is also a guide that describes the various combination of hints (https://g
 --slg : turn on skip layer guidance for improved quality\
 --check-loras : filter loras that are incompatible (will take a few seconds while refreshing the lora list or while starting the app)\
 --advanced : turn on the advanced mode while launching the app\
---i2v-settings : path to launch settings for i2v\
---t2v-settings : path to launch settings for t2v\
 --listen : make server accessible on network\
 --gpu device : run Wan on device for instance "cuda:1"

diff --git a/wan/utils/vace_preprocessor.py b/wan/utils/vace_preprocessor.py
index 7c10719..912ae39 100644
--- a/wan/utils/vace_preprocessor.py
+++ b/wan/utils/vace_preprocessor.py
@@ -183,7 +183,7 @@ class VaceVideoProcessor(object):
     def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng, max_frames= 0):
         import math
         target_fps = self.max_fps
-        video_duration = frame_timestamps[-1][1]
+        video_frames_count = len(frame_timestamps)
         video_frame_duration = 1 /fps
         target_frame_duration = 1 / target_fps

@@ -197,9 +197,9 @@ class VaceVideoProcessor(object):
             frame_ids.append(frame_no)
             cur_time += add_frames_count * video_frame_duration
             target_time += target_frame_duration
-            if cur_time > video_duration:
+            if frame_no >= video_frames_count -1:
                 break
-
+        frame_ids = frame_ids[:video_frames_count]
         x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
         h, w = y2 - y1, x2 - x1
         ratio = h / w
diff --git a/gradio_server.py b/wgp.py
similarity index 69%
rename from gradio_server.py
rename to wgp.py
index 7f4e345..c81ab9a 100644
--- a/gradio_server.py
+++ 
b/wgp.py @@ -21,7 +21,9 @@ import torch import gc import traceback import math +import typing import asyncio +import inspect from wan.utils import prompt_parser import base64 import io @@ -39,7 +41,6 @@ current_task_id = None task_id = 0 # progress_tracker = {} # tracker_lock = threading.Lock() -last_model_type = None def format_time(seconds): if seconds < 60: @@ -81,63 +82,41 @@ def pil_to_base64_uri(pil_image, format="png", quality=75): return None -def process_prompt_and_add_tasks( - prompt, - negative_prompt, - resolution, - video_length, - seed, - num_inference_steps, - guidance_scale, - flow_shift, - embedded_guidance_scale, - repeat_generation, - multi_images_gen_type, - tea_cache, - tea_cache_start_step_perc, - loras_choices, - loras_mult_choices, - image_prompt_type, - image_source1, - image_source2, - image_source3, - max_frames, - remove_background_image_ref, - temporal_upsampling, - spatial_upsampling, - RIFLEx_setting, - slg_switch, - slg_layers, - slg_start, - slg_end, - cfg_star_switch, - cfg_zero_step, - state, - image2video -): +def process_prompt_and_add_tasks(state, model_choice): if state.get("validate_success",0) != 1: - gr.Info("Validation failed, not adding tasks.") return state["validate_success"] = 0 + + model_filename = state["model_filename"] + + if model_choice != get_model_type(model_filename): + raise gr.Error("Webform can not be used as the App has been restarted since the form was displayed. Please refresh the page") + + inputs = state.get(get_model_type(model_filename), None) + inputs["state"] = state + if inputs == None: + return + prompt = inputs["prompt"] if len(prompt) ==0: return prompt, errors = prompt_parser.process_template(prompt) if len(errors) > 0: gr.Info("Error processing prompt template: " + errors) return + + inputs["model_filename"] = model_filename prompts = prompt.replace("\r", "").split("\n") prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")] if len(prompts) ==0: return - file_model_needed = model_needed(image2video) + resolution = inputs["resolution"] width, height = resolution.split("x") width, height = int(width), int(height) - if image2video: - - if "480p" in file_model_needed and not "Fun" in file_model_needed and width * height > 848*480: + if test_class_i2v(model_filename): + if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480: gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P") return resolution = str(width) + "*" + str(height) @@ -145,133 +124,124 @@ def process_prompt_and_add_tasks( gr.Info(f"Resolution {resolution} not supported by image 2 video") return - if "1.3B" in file_model_needed and width * height > 848*480: + if "1.3B" in model_filename and width * height > 848*480: gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P") return - - if not image2video: - if "Vace" in file_model_needed and "1.3B" in file_model_needed : + + + if "Vace" in model_filename: + video_prompt_type = inputs["video_prompt_type"] + image_refs = inputs["image_refs"] + video_guide = inputs["video_guide"] + video_mask = inputs["video_mask"] + if "Vace" in model_filename and "1.3B" in model_filename : resolution_reformated = str(height) + "*" + str(width) if not resolution_reformated in VACE_SIZE_CONFIGS: res = VACE_SIZE_CONFIGS.keys().join(" and ") gr.Info(f"Video Resolution for Vace model is not supported. 
Only {res} resolutions are allowed.") return + if not "I" in video_prompt_type: + image_refs = None + if not "V" in video_prompt_type: + video_guide = None + if not "M" in video_prompt_type: + video_mask = None - if not "I" in image_prompt_type: - image_source1 = None - if not "V" in image_prompt_type: - image_source2 = None - if not "M" in image_prompt_type: - image_source3 = None - - if isinstance(image_source1, list): - image_source1 = [ convert_image(tup[0]) for tup in image_source1 ] + if isinstance(image_refs, list): + image_refs = [ convert_image(tup[0]) for tup in image_refs ] from wan.utils.utils import resize_and_remove_background - image_source1 = resize_and_remove_background(image_source1, width, height, remove_background_image_ref ==1) + image_refs = resize_and_remove_background(image_refs, width, height, inputs["remove_background_image_ref"] ==1) - image_source1 = [ image_source1 ] * len(prompts) - image_source2 = [ image_source2 ] * len(prompts) - image_source3 = [ image_source3 ] * len(prompts) - else: - if image_source1 == None or isinstance(image_source1, list) and len(image_source1) == 0: + for single_prompt in prompts: + extra_inputs = { + "prompt" : single_prompt, + "image_refs": image_refs, + "video_guide" : video_guide, + "video_mask" : video_mask , + } + inputs.update(extra_inputs) + add_video_task(**inputs) + elif "image2video" in model_filename or "Fun_InP" in model_filename : + image_prompt_type = inputs["image_prompt_type"] + + image_start = inputs["image_start"] + image_end = inputs["image_end"] + if image_start == None or isinstance(image_start, list) and len(image_start) == 0: return - if image_prompt_type == 0: - image_source2 = None - if isinstance(image_source1, list): - image_source1 = [ convert_image(tup[0]) for tup in image_source1 ] + if not "E" in image_prompt_type: + image_end = None + if isinstance(image_start, list): + image_start = [ convert_image(tup[0]) for tup in image_start ] else: - image_source1 = [convert_image(image_source1)] - if image_source2 != None: - if isinstance(image_source2 , list): - image_source2 = [ convert_image(tup[0]) for tup in image_source2 ] + image_start = [convert_image(image_start)] + if image_end != None: + if isinstance(image_end , list): + image_end = [ convert_image(tup[0]) for tup in image_end ] else: - image_source2 = [convert_image(image_source2) ] - if len(image_source1) != len(image_source2): + image_end = [convert_image(image_end) ] + if len(image_start) != len(image_end): gr.Info("The number of start and end images should be the same ") return - if multi_images_gen_type == 0: + if inputs["multi_images_gen_type"] == 0: new_prompts = [] - new_image_source1 = [] - new_image_source2 = [] - for i in range(len(prompts) * len(image_source1) ): + new_image_start = [] + new_image_end = [] + for i in range(len(prompts) * len(image_start) ): new_prompts.append( prompts[ i % len(prompts)] ) - new_image_source1.append(image_source1[i // len(prompts)] ) - if image_source2 != None: - new_image_source2.append(image_source2[i // len(prompts)] ) + new_image_start.append(image_start[i // len(prompts)] ) + if image_end != None: + new_image_end.append(image_end[i // len(prompts)] ) prompts = new_prompts - image_source1 = new_image_source1 - if image_source2 != None: - image_source2 = new_image_source2 + image_start = new_image_start + if image_end != None: + image_end = new_image_end else: - if len(prompts) >= len(image_source1): - if len(prompts) % len(image_source1) !=0: + if len(prompts) >= len(image_start): + if len(prompts) 
% len(image_start) != 0: raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images") - rep = len(prompts) // len(image_source1) - new_image_source1 = [] - new_image_source2 = [] + rep = len(prompts) // len(image_start) + new_image_start = [] + new_image_end = [] for i, _ in enumerate(prompts): - new_image_source1.append(image_source1[i//rep] ) - if image_source2 != None: - new_image_source2.append(image_source2[i//rep] ) - image_source1 = new_image_source1 - if image_source2 != None: - image_source2 = new_image_source2 + new_image_start.append(image_start[i//rep] ) + if image_end != None: + new_image_end.append(image_end[i//rep] ) + image_start = new_image_start + if image_end != None: + image_end = new_image_end else: - if len(image_source1) % len(prompts) !=0: + if len(image_start) % len(prompts) !=0: raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts") - rep = len(image_source1) // len(prompts) + rep = len(image_start) // len(prompts) new_prompts = [] - for i, _ in enumerate(image_source1): + for i, _ in enumerate(image_start): new_prompts.append( prompts[ i//rep] ) prompts = new_prompts - if image_source1 == None: - image_source1 = [None] * len(prompts) - if image_source2 == None: - image_source2 = [None] * len(prompts) - if image_source3 == None: - image_source3 = [None] * len(prompts) + if image_start == None: + image_start = [None] * len(prompts) + if image_end == None: + image_end = [None] * len(prompts) - for single_prompt, image_source1, image_source2, image_source3 in zip(prompts, image_source1, image_source2, image_source3) : - kwargs = { - "prompt" : single_prompt, - "negative_prompt" : negative_prompt, - "resolution" : resolution, - "video_length" : video_length, - "seed" : seed, - "num_inference_steps" : num_inference_steps, - "guidance_scale" : guidance_scale, - "flow_shift" : flow_shift, - "embedded_guidance_scale" : embedded_guidance_scale, - "repeat_generation" : repeat_generation, - "multi_images_gen_type" : multi_images_gen_type, - "tea_cache" : tea_cache, - "tea_cache_start_step_perc" : tea_cache_start_step_perc, - "loras_choices" : loras_choices, - "loras_mult_choices" : loras_mult_choices, - "image_prompt_type" : image_prompt_type, - "image_source1": image_source1, - "image_source2" : image_source2, - "image_source3" : image_source3 , - "max_frames" : max_frames, - "remove_background_image_ref" : remove_background_image_ref, - "temporal_upsampling" : temporal_upsampling, - "spatial_upsampling" : spatial_upsampling, - "RIFLEx_setting" : RIFLEx_setting, - "slg_switch" : slg_switch, - "slg_layers" : slg_layers, - "slg_start" : slg_start, - "slg_end" : slg_end, - "cfg_star_switch" : cfg_star_switch, - "cfg_zero_step" : cfg_zero_step, - "state" : state, - "image2video" : image2video - } - add_video_task(**kwargs) + for single_prompt, start, end in zip(prompts, image_start, image_end) : + extra_inputs = { + "prompt" : single_prompt, + "image_start": start, + "image_end" : end, + } + inputs.update(extra_inputs) + add_video_task(**inputs) + else: + for single_prompt in prompts : + extra_inputs = { + "prompt" : single_prompt, + } + inputs.update(extra_inputs) + add_video_task(**inputs) gen = get_gen_info(state) gen["prompts_max"] = len(prompts) + gen.get("prompts_max",0) @@ -282,29 +252,37 @@ def process_prompt_and_add_tasks( -def add_video_task(**kwargs): +def add_video_task(**inputs): global task_id - state = 
kwargs["state"] + state = inputs["state"] gen = get_gen_info(state) queue = gen["queue"] task_id += 1 current_task_id = task_id - start_image_data = kwargs["image_source1"] - start_image_data = [start_image_data] if not isinstance(start_image_data, list) else start_image_data - end_image_data = kwargs["image_source2"] + inputs_to_query = ["image_start", "image_end", "image_refs", "video_guide", "video_mask"] + start_image_data = None + end_image_data = None + for name in inputs_to_query: + image= inputs.get(name, None) + if image != None: + image= [image] if not isinstance(image, list) else image + if start_image_data == None: + start_image_data = image + else: + end_image_data = image + break queue.append({ "id": current_task_id, - "image2video": kwargs["image2video"], - "params": kwargs.copy(), - "repeats": kwargs["repeat_generation"], - "length": kwargs["video_length"], - "steps": kwargs["num_inference_steps"], - "prompt": kwargs["prompt"], + "params": inputs.copy(), + "repeats": inputs["repeat_generation"], + "length": inputs["video_length"], + "steps": inputs["num_inference_steps"], + "prompt": inputs["prompt"], "start_image_data": start_image_data, "end_image_data": end_image_data, - "start_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in start_image_data], - "end_image_data_base64": pil_to_base64_uri(end_image_data, format="jpeg", quality=70) + "start_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in start_image_data] if start_image_data != None else None, + "end_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in end_image_data] if end_image_data != None else None }) return update_queue_data(queue) @@ -372,6 +350,7 @@ def get_queue_table(queue): start_img_uri =item.get('start_image_data_base64') start_img_uri = start_img_uri[0] if start_img_uri !=None else None end_img_uri = item.get('end_image_data_base64') + end_img_uri = end_img_uri[0] if end_img_uri !=None else None thumbnail_size = "50px" num_steps = item.get('steps') length = item.get('length') @@ -435,35 +414,6 @@ def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True): """ return html -# def refresh_progress(): -# global current_task_id, progress_tracker, last_status_string -# task_id_to_check = current_task_id -# is_idle = True -# status_string = "Starting..." -# progress_percent = 0.0 -# html_content = "" - -# with tracker_lock: -# with lock: -# processing_or_queued = any(item['state'] in ["Processing", "Queued"] for item in queue) -# if task_id_to_check is not None: -# progress_data = progress_tracker.get(task_id_to_check) -# if progress_data: -# is_idle = False -# current_step = progress_data.get('current_step', 0) -# total_steps = progress_data.get('total_steps', 0) -# status = progress_data.get('status', "Starting...") -# repeats = progress_data.get("repeats", 1) - -# if total_steps > 0: -# progress_float = min(1.0, max(0.0, float(current_step) / float(total_steps))) -# progress_percent = progress_float * 100 -# status_string = f"{status} [{repeats}] - {progress_percent:.1f}% complete ({current_step}/{total_steps} steps)" -# else: -# progress_percent = 0.0 -# status_string = f"{status} [{repeats}] - Initializing..." 
-# html_content = create_html_progress_bar(progress_percent, status_string, is_idle) -# return gr.update(value=html_content) def update_generation_status(html_content): if(html_content): @@ -534,19 +484,19 @@ def _parse_args(): help="Lora preset to preload" ) - parser.add_argument( - "--i2v-settings", - type=str, - default="i2v_settings.json", - help="Path to settings file for i2v" - ) + # parser.add_argument( + # "--i2v-settings", + # type=str, + # default="i2v_settings.json", + # help="Path to settings file for i2v" + # ) - parser.add_argument( - "--t2v-settings", - type=str, - default="t2v_settings.json", - help="Path to settings file for t2v" - ) + # parser.add_argument( + # "--t2v-settings", + # type=str, + # default="t2v_settings.json", + # help="Path to settings file for t2v" + # ) # parser.add_argument( # "--lora-preset-i2v", @@ -645,6 +595,12 @@ def _parse_args(): action="store_true", help="text to video mode 1.3B model" ) + + parser.add_argument( + "--vace-1-3B", + action="store_true", + help="Vace ControlNet 1.3B model" + ) parser.add_argument( "--i2v-1-3B", action="store_true", @@ -700,8 +656,9 @@ def _parse_args(): return args -def get_lora_dir(i2v): +def get_lora_dir(model_filename): lora_dir =args.lora_dir + i2v = test_class_i2v(model_filename) if i2v and len(lora_dir)==0: lora_dir =args.lora_dir_i2v if len(lora_dir) > 0: @@ -709,7 +666,7 @@ def get_lora_dir(i2v): root_lora_dir = "loras_i2v" if i2v else "loras" - if "1.3B" in (transformer_filename_i2v if i2v else transformer_filename_t2v) : + if "1.3B" in model_filename : lora_dir_1_3B = os.path.join(root_lora_dir, "1.3B") if os.path.isdir(lora_dir_1_3B ): return lora_dir_1_3B @@ -719,6 +676,7 @@ def get_lora_dir(i2v): return lora_dir_14B return root_lora_dir + attention_modes_installed = get_attention_modes() attention_modes_supported = get_supported_attention_modes() args = _parse_args() @@ -740,14 +698,14 @@ advanced = args.advanced transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors"] transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ] +transformer_choices = transformer_choices_t2v + transformer_choices_i2v text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"] - server_config_filename = "gradio_config.json" if not Path(server_config_filename).is_file(): server_config = {"attention_mode" : "auto", - "transformer_filename": transformer_choices_t2v[0], - "transformer_filename_i2v": transformer_choices_i2v[1], + "transformer_type": "t2v", + "transformer_quantization": "int8", "text_encoder_filename" : text_encoder_choices[1], "save_path": os.path.join(os.getcwd(), "gradio_outputs"), "compile" : "", @@ -766,23 +724,68 @@ else: text = reader.read() server_config = json.loads(text) -def get_settings_file_name(i2v): - return args.i2v_settings if i2v else args.t2v_settings -def get_default_settings(filename, i2v): - def get_default_prompt(i2v): +model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp"] 
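+# model_types lists the short identifiers used as per-model keys for settings files and UI state
+# (see get_model_type / get_settings_file_name below); model_signatures maps each identifier to the
+# substring that singles out its checkpoint in transformer_choices, which get_model_filename uses to
+# resolve a model type plus quantization level to an actual .safetensors file.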
+model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B", + "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B" } + + +def get_model_type(model_filename): + if "text2video" in model_filename and "14B" in model_filename: + return "t2v" + elif "text2video" in model_filename and "1.3B" in model_filename: + return "t2v_1.3B" + elif "Fun_InP" in model_filename and "1.3B" in model_filename: + return "fun_inp_1.3B" + elif "Fun_InP" in model_filename and "14B" in model_filename: + return "fun_inp" + elif "image2video_480p" in model_filename : + return "i2v" + elif "image2video_720p" in model_filename : + return "i2v_720p" + elif "Vace" in model_filename and "1.3B" in model_filename: + return "vace_1.3B" + elif "Vace" in model_filename and "14B" in model_filename: + return "vace" + else: + raise Exception("Unknown model:" + model_filename) + +def test_class_i2v(model_filename): + return "image2video" in model_filename or "Fun_InP" in model_filename + + +def get_model_filename(model_type, quantization): + signature = model_signatures[model_type] + + choices = [ name for name in transformer_choices if signature in name] + if len(quantization) == 0: + quantization = "bf16" + + if len(choices) <= 1: + return choices[0] + + sub_choices = [ name for name in choices if quantization in name] + if len(sub_choices) > 0: + return sub_choices[0] + else: + return choices[0] + +def get_settings_file_name(model_filename): + return get_model_type(model_filename) + "_settings.json" + +def get_default_settings(filename): + def get_default_prompt(i2v): if i2v: return "Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field." else: return "A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect." 
- - defaults_filename = get_settings_file_name(i2v) + i2v = "image2video" in file_name + defaults_filename = get_settings_file_name(filename) if not Path(defaults_filename).is_file(): ui_defaults = { "prompts": get_default_prompt(i2v), "resolution": "832x480", "video_length": 81, - "image_prompt_type" : 0 if i2v else "", "num_inference_steps": 30, "seed": -1, "repeat_generation": 1, @@ -805,6 +808,12 @@ def get_default_settings(filename, i2v): else: with open(defaults_filename, "r", encoding="utf-8") as f: ui_defaults = json.load(f) + prompts = ui_defaults.get("prompts", "") + if len(prompts) > 0: + ui_defaults["prompt"] = prompts + image_prompt_type = ui_defaults.get("image_prompt_type", None) + if image_prompt_type !=None and not isinstance(image_prompt_type, str): + ui_defaults["image_prompt_type"] = "S" if image_prompt_type == 0 else "SE" default_seed = args.seed if default_seed > -1: @@ -817,9 +826,9 @@ def get_default_settings(filename, i2v): ui_defaults["num_inference_steps"] = default_number_steps return ui_defaults -transformer_filename_t2v = server_config["transformer_filename"] -transformer_filename_i2v = server_config.get("transformer_filename_i2v", transformer_choices_i2v[1]) - +transformer_type = server_config.get("transformer_type", "t2v") +transformer_quantization =server_config.get("transformer_quantization", "int8") +transformer_filename = get_model_filename(transformer_type, transformer_quantization) text_encoder_filename = server_config["text_encoder_filename"] attention_mode = server_config["attention_mode"] if len(args.attention)> 0: @@ -839,41 +848,27 @@ if len(args.vae_config) > 0: reload_needed = False default_ui = server_config.get("default_ui", "t2v") save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs")) -use_image2video = default_ui != "t2v" -if args.t2v: - use_image2video = False -if args.i2v: - use_image2video = True -if args.t2v_14B: - use_image2video = False - if not "14B" in transformer_filename_t2v: - transformer_filename_t2v = transformer_choices_t2v[2] - lock_ui_transformer = False +reload_model = server_config.get("reload_model", 2) -if args.i2v_14B: - use_image2video = True - if not "14B" in transformer_filename_i2v: - transformer_filename_i2v = transformer_choices_t2v[3] - lock_ui_transformer = False + +if args.t2v_14B or args.t2v: + transformer_filename = get_model_filename("t2v", transformer_quantization) + +if args.i2v_14B or args.i2v: + transformer_filename = get_model_filename("i2v", transformer_quantization) if args.t2v_1_3B: - transformer_filename_t2v = transformer_choices_t2v[0] - use_image2video = False - lock_ui_transformer = False + transformer_filename = get_model_filename("t2v_1.3B", transformer_quantization) if args.i2v_1_3B: - transformer_filename_i2v = transformer_choices_i2v[4] - use_image2video = True - lock_ui_transformer = False + transformer_filename = get_model_filename("fun_inp_1.3B", transformer_quantization) + +if args.vace_1_3B: + transformer_filename = get_model_filename("vace_1.3B", transformer_quantization) only_allow_edit_in_advanced = False lora_preselected_preset = args.lora_preset -# if args.fast : #or args.fastest -# transformer_filename_t2v = transformer_choices_t2v[2] -# attention_mode="sage2" if "sage2" in attention_modes_supported else "sage" -# default_tea_cache = 0.15 -# lock_ui_attention = True -# lock_ui_transformer = True +lora_preset_model = transformer_filename if args.compile: #args.fastest or compile="transformer" @@ -982,19 +977,19 @@ for file_name in to_remove: except: 
pass -download_models(transformer_filename_i2v if use_image2video else transformer_filename_t2v, text_encoder_filename) +download_models(transformer_filename, text_encoder_filename) def sanitize_file_name(file_name, rep =""): return file_name.replace("/",rep).replace("\\",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep) -def extract_preset(image2video, lset_name, loras): +def extract_preset(model_filename, lset_name, loras): loras_choices = [] loras_choices_files = [] loras_mult_choices = "" prompt ="" full_prompt ="" lset_name = sanitize_file_name(lset_name) - lora_dir = get_lora_dir(image2video) + lora_dir = get_lora_dir(model_filename) if not lset_name.endswith(".lset"): lset_name_filename = os.path.join(lora_dir, lset_name + ".lset" ) else: @@ -1028,7 +1023,7 @@ def extract_preset(image2video, lset_name, loras): -def setup_loras(i2v, transformer, lora_dir, lora_preselected_preset, split_linear_modules_map = None): +def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset, split_linear_modules_map = None): loras =[] loras_names = [] default_loras_choices = [] @@ -1039,7 +1034,7 @@ def setup_loras(i2v, transformer, lora_dir, lora_preselected_preset, split_line from pathlib import Path - lora_dir = get_lora_dir(i2v) + lora_dir = get_lora_dir(model_filename) if lora_dir != None : if not os.path.isdir(lora_dir): raise Exception("--lora-dir should be a path to a directory that contains Loras") @@ -1065,7 +1060,7 @@ def setup_loras(i2v, transformer, lora_dir, lora_preselected_preset, split_line if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")): raise Exception(f"Unknown preset '{lora_preselected_preset}'") default_lora_preset = lora_preselected_preset - default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, _ , error = extract_preset(i2v, default_lora_preset, loras) + default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, _ , error = extract_preset(model_filename, default_lora_preset, loras) if len(error) > 0: print(error[:200]) return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset @@ -1133,14 +1128,14 @@ def load_i2v_model(model_filename, value): raise Exception("Model i2v {value} not supported") return wan_model, pipe -def model_needed(i2v): - return transformer_filename_i2v if i2v else transformer_filename_t2v -def load_models(i2v): - global model_filename - model_filename = model_needed(i2v) + +def load_models(model_filename): + global transformer_filename + + transformer_filename = model_filename download_models(model_filename, text_encoder_filename) - if i2v: + if test_class_i2v(model_filename): res720P = "720p" in model_filename wan_model, pipe = load_i2v_model(model_filename, "720P" if res720P else "480P") else: @@ -1159,11 +1154,15 @@ def load_models(i2v): return wan_model, offloadobj, pipe["transformer"] -wan_model, offloadobj, transformer = load_models(use_image2video) -if check_loras: - setup_loras(use_image2video, transformer, get_lora_dir(use_image2video), "", None) - exit() -del transformer +if reload_model ==3: + wan_model, offloadobj, transformer = None, None, None + reload_needed = True +else: + wan_model, offloadobj, transformer = load_models(transformer_filename) + if check_loras: + setup_loras(model_filename, transformer, get_lora_dir(transformer_filename), "", None) + exit() + del transformer gen_in_progress = False @@ -1182,7 +1181,7 
@@ def get_model_name(model_filename): model_name = "Fun InP image2video" model_name += " 14B" if "14B" in model_filename else " 1.3B" elif "Vace" in model_filename: - model_name = "Vace ControlNet text2video" + model_name = "Vace ControlNet" model_name += " 14B" if "14B" in model_filename else " 1.3B" elif "image" in model_filename: model_name = "Wan2.1 image2video" @@ -1193,14 +1192,30 @@ def get_model_name(model_filename): return model_name -def generate_header(model_filename, compile, attention_mode): +# def generate_header(model_filename, compile, attention_mode): - header = "
"
+#     header = ""
-    model_name = get_model_name(model_filename)
+#     model_name = get_model_name(model_filename)
-    header += model_name
-    header += " (attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
+#     header += model_name
+#     header += " (attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
+#     if attention_mode not in attention_modes_installed:
+#         header += " -NOT INSTALLED-"
+#     elif attention_mode not in attention_modes_supported:
+#         header += " -NOT SUPPORTED-"
+
+#     if compile:
+#         header += ", pytorch compilation ON"
+#     header += ")"
+
+#     return header
+
+
+def generate_header(compile, attention_mode):
+
+    header = "Attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
     if attention_mode not in attention_modes_installed:
         header += " -NOT INSTALLED-"
     elif attention_mode not in attention_modes_supported:
@@ -1208,14 +1223,12 @@ def generate_header(model_filename, compile, attention_mode):
     if compile:
         header += ", pytorch compilation ON"
-    header += ")"
-
+    header += "
" return header def apply_changes( state, - transformer_t2v_choice, - transformer_i2v_choice, + transformer_type_choice, text_encoder_choice, save_path_choice, attention_choice, @@ -1223,7 +1236,7 @@ def apply_changes( state, profile_choice, vae_config_choice, metadata_choice, - default_ui_choice ="t2v", + quantization_choice, boost_choice = 1, clear_file_list = 0, reload_choice = 1 @@ -1235,15 +1248,14 @@ def apply_changes( state, return global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets server_config = {"attention_mode" : attention_choice, - "transformer_filename": transformer_choices_t2v[transformer_t2v_choice], - "transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice], + "transformer_type": transformer_type_choice, "text_encoder_filename" : text_encoder_choices[text_encoder_choice], "save_path" : save_path_choice, "compile" : compile_choice, "profile" : profile_choice, "vae_config" : vae_config_choice, "metadata_choice": metadata_choice, - "default_ui" : default_ui_choice, + "transformer_quantization" : quantization_choice, "boost" : boost_choice, "clear_file_list" : clear_file_list, "reload_model" : reload_choice, @@ -1255,7 +1267,6 @@ def apply_changes( state, old_server_config = json.loads(text) if lock_ui_transformer: server_config["transformer_filename"] = old_server_config["transformer_filename"] - server_config["transformer_filename_i2v"] = old_server_config["transformer_filename_i2v"] if lock_ui_attention: server_config["attention_mode"] = old_server_config["attention_mode"] if lock_ui_compile: @@ -1270,15 +1281,17 @@ def apply_changes( state, if v != v_old: changes.append(k) - global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost, lora_dir, reload_needed + global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, reload_model, transformer_quantization, transformer_type attention_mode = server_config["attention_mode"] profile = server_config["profile"] compile = server_config["compile"] - transformer_filename_t2v = server_config["transformer_filename"] - transformer_filename_i2v = server_config["transformer_filename_i2v"] text_encoder_filename = server_config["text_encoder_filename"] vae_config = server_config["vae_config"] boost = server_config["boost"] + reload_model = server_config["reload_model"] + transformer_quantization = server_config["transformer_quantization"] + transformer_filename = get_model_filename(transformer_type, transformer_quantization) + if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ): pass else: @@ -1357,20 +1370,12 @@ def abort_generation(state): else: return "", gr.Button(interactive= True) -def is_gen_location(state): - gen = get_gen_info(state) - gen_location = gen.get("location",None) - if gen_location == None: - return None - return state["image2video"] == gen_location - def refresh_gallery(state, msg): gen = get_gen_info(state) - if is_gen_location(state): - gen["last_msg"] = msg + gen["last_msg"] = msg file_list = gen.get("file_list", None) choice = gen.get("selected",0) in_progress = "in_progress" in gen @@ -1391,6 +1396,7 @@ def refresh_gallery(state, msg): start_img_uri = task.get('start_image_data_base64') start_img_uri = start_img_uri[0] if 
start_img_uri !=None else None end_img_uri = task.get('end_image_data_base64') + end_img_uri = end_img_uri[0] if end_img_uri !=None else None thumbnail_size = "100px" if start_img_uri: start_img_md = f'Start' @@ -1489,14 +1495,17 @@ def generate_video( embedded_guidance_scale, repeat_generation, multi_images_gen_type, - tea_cache, + tea_cache_setting, tea_cache_start_step_perc, - loras_choices, - loras_mult_choices, + activated_loras, + loras_multipliers, image_prompt_type, - image_source1, - image_source2, - image_source3, + image_start, + image_end, + video_prompt_type, + image_refs, + video_guide, + video_mask, max_frames, remove_background_image_ref, temporal_upsampling, @@ -1504,39 +1513,41 @@ def generate_video( RIFLEx_setting, slg_switch, slg_layers, - slg_start, - slg_end, + slg_start_perc, + slg_end_perc, cfg_star_switch, cfg_zero_step, state, - image2video + model_filename ): - global wan_model, offloadobj, reload_needed, last_model_type + global wan_model, offloadobj, reload_needed gen = get_gen_info(state) file_list = gen["file_list"] prompt_no = gen["prompt_no"] - - file_model_needed = model_needed(image2video) - # queue = gen.get("queue", []) - # with lock: - # queue_not_empty = len(queue) > 0 - # if(last_model_type != image2video and (queue_not_empty or server_config.get("reload_model",1) == 2) and (file_model_needed != model_filename or reload_needed)): - if file_model_needed != model_filename or reload_needed: - del wan_model + + + # if wan_model == None: + # gr.Info("Unable to generate a Video while a new configuration is being applied.") + # return + + if reload_model !=3 : + while wan_model == None: + time.sleep(1) + + if model_filename != transformer_filename or reload_needed: + wan_model = None if offloadobj is not None: offloadobj.release() - del offloadobj + offloadobj = None gc.collect() - yield f"Loading model {get_model_name(file_model_needed)}..." - wan_model, offloadobj, trans = load_models(image2video) + yield f"Loading model {get_model_name(model_filename)}..." + wan_model, offloadobj, trans = load_models(model_filename) yield f"Model loaded" reload_needed= False - if wan_model == None: - gr.Info("Unable to generate a Video while a new configuration is being applied.") if attention_mode == "auto": attn = get_auto_attention() elif attention_mode in attention_modes_supported: @@ -1544,11 +1555,10 @@ def generate_video( else: gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. 
You should either install it or switch to the default 'sdpa' attention.") return - - if not image2video: - width, height = resolution.split("x") - width, height = int(width), int(height) + width, height = resolution.split("x") + width, height = int(width), int(height) + resolution_reformated = str(height) + "*" + str(width) if slg_switch == 0: slg_layers = None @@ -1589,11 +1599,11 @@ def generate_video( except ValueError: return False list_mult_choices_nums = [] - if len(loras_mult_choices) > 0: - loras_mult_choices_list = loras_mult_choices.replace("\r", "").split("\n") + if len(loras_multipliers) > 0: + loras_mult_choices_list = loras_multipliers.replace("\r", "").split("\n") loras_mult_choices_list = [multi for multi in loras_mult_choices_list if len(multi)>0 and not multi.startswith("#")] - loras_mult_choices = " ".join(loras_mult_choices_list) - list_mult_choices_str = loras_mult_choices.split(" ") + loras_multipliers = " ".join(loras_mult_choices_list) + list_mult_choices_str = loras_multipliers.split(" ") for i, mult in enumerate(list_mult_choices_str): mult = mult.strip() if "," in mult: @@ -1609,9 +1619,9 @@ def generate_video( if not is_float(mult): raise gr.Error(f"Lora Multiplier no {i+1} ({mult}) is invalid") list_mult_choices_nums.append(float(mult)) - if len(list_mult_choices_nums ) < len(loras_choices): - list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) ) - loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices] + if len(list_mult_choices_nums ) < len(activated_loras): + list_mult_choices_nums += [1.0] * ( len(activated_loras) - len(list_mult_choices_nums ) ) + loras_selected = [ lora for lora in loras if os.path.basename(lora) in activated_loras] pinnedLora = profile !=5 #False # # # offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None) errors = trans._loras_errors @@ -1620,43 +1630,42 @@ def generate_video( raise gr.Error("Error while loading Loras: " + ", ".join(error_files)) seed = None if seed == -1 else seed # negative_prompt = "" # not applicable in the inference - + image2video = test_class_i2v(model_filename) enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1 # VAE Tiling device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576 joint_pass = boost ==1 #and profile != 1 and profile != 3 # TeaCache - trans.enable_teacache = tea_cache > 0 + trans.enable_teacache = tea_cache_setting > 0 if trans.enable_teacache: - trans.teacache_multiplier = tea_cache + trans.teacache_multiplier = tea_cache_setting trans.rel_l1_thresh = 0 trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100) if image2video: - if '480p' in transformer_filename_i2v: + if '480p' in model_filename: # teacache_thresholds = [0.13, .19, 0.26] trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01] - elif '720p' in transformer_filename_i2v: + elif '720p' in model_filename: teacache_thresholds = [0.18, 0.2 , 0.3] trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] else: raise gr.Error("Teacache not supported for this model") else: - if '1.3B' in transformer_filename_t2v: + if '1.3B' in model_filename: # teacache_thresholds= [0.05, 0.07, 0.08] trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 
1.37887774e-01] - elif '14B' in transformer_filename_t2v: + elif '14B' in model_filename: # teacache_thresholds = [0.14, 0.15, 0.2] trans.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] else: raise gr.Error("Teacache not supported for this model") if "Vace" in model_filename: - resolution_reformated = str(height) + "*" + str(width) - src_video, src_mask, src_ref_images = wan_model.prepare_source([image_source2], - [image_source3], - [image_source1], + src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide], + [video_mask], + [image_refs], video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu", trim_video=max_frames) else: @@ -1721,10 +1730,10 @@ def generate_video( if image2video: samples = wan_model.generate( prompt, - image_source1, - image_source2 if image_source2 != None else None, + image_start, + image_end if image_end != None else None, frame_num=(video_length // 4)* 4 + 1, - max_area=MAX_AREA_CONFIGS[resolution], + max_area=MAX_AREA_CONFIGS[resolution_reformated], shift=flow_shift, sampling_steps=num_inference_steps, guide_scale=guidance_scale, @@ -1736,11 +1745,11 @@ def generate_video( VAE_tile_size = VAE_tile_size, joint_pass = joint_pass, slg_layers = slg_layers, - slg_start = slg_start/100, - slg_end = slg_end/100, + slg_start = slg_start_perc/100, + slg_end = slg_end_perc/100, cfg_star_switch = cfg_star_switch, cfg_zero_step = cfg_zero_step, - add_frames_for_end_image = not "Fun" in transformer_filename_i2v, + add_frames_for_end_image = not "Fun_InP" in model_filename, ) else: samples = wan_model.generate( @@ -1761,8 +1770,8 @@ def generate_video( VAE_tile_size = VAE_tile_size, joint_pass = joint_pass, slg_layers = slg_layers, - slg_start = slg_start/100, - slg_end = slg_end/100, + slg_start = slg_start_perc/100, + slg_end = slg_end_perc/100, cfg_star_switch = cfg_star_switch, cfg_zero_step = cfg_zero_step, ) @@ -1863,6 +1872,8 @@ def generate_video( h, w = sample.shape[-2:] h *= scale w *= scale + h = int(h) + w = int(w) new_frames =[] for i in range( sample.shape[1] ): frame = sample[:, i] @@ -1881,9 +1892,10 @@ def generate_video( nrow=1, normalize=True, value_range=(-1, 1)) - - configs = get_settings_dict(state, image2video, True, prompt, image_prompt_type, max_frames , remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, - loras_mult_choices, tea_cache , tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step) + + inputs = get_function_arguments(generate_video, locals()) + inputs.pop("progress") + configs = prepare_inputs_dict("metadata", inputs) metadata_choice = server_config.get("metadata_choice","metadata") if metadata_choice == "json": @@ -1899,8 +1911,6 @@ def generate_video( file_list.append(video_path) state['update_gallery'] = True seed += 1 - - last_model_type = image2video if temp_filename!= None and os.path.isfile(temp_filename): os.remove(temp_filename) @@ -1913,36 +1923,6 @@ def prepare_generate_video(state): return gr.Button(visible= False), gr.Button(visible= True), gr.Column(visible= True) -def wait_tasks_done(state, progress=gr.Progress()): - - gen = get_gen_info(state) - gen_location = is_gen_location(state) - - last_msg = gen.get("last_msg", "") - if len(last_msg) > 0: - yield last_msg - - if gen_location == None or gen_location: - return gr.Text() - - 
- while True: - - msg = gen.get("last_msg", "") - if len(msg) > 0 and last_msg != msg: - yield msg - last_msg = msg - progress_args = gen.get("progress_args", None) - if progress_args != None: - progress(*progress_args) - - in_progress= gen.get("in_progress", False) - if not in_progress: - break - time.sleep(0.5) - - - def process_tasks(state, progress=gr.Progress()): gen = get_gen_info(state) queue = gen.get("queue", []) @@ -1950,7 +1930,6 @@ def process_tasks(state, progress=gr.Progress()): if len(queue) == 0: return gen = get_gen_info(state) - gen["location"] = state["image2video"] clear_file_list = server_config.get("clear_file_list", 0) file_list = gen.get("file_list", []) if clear_file_list > 0: @@ -2108,7 +2087,7 @@ def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_ lset_name_filename = lset_name + ".lset" - full_lset_name_filename = os.path.join(get_lora_dir(state["image2video"]), lset_name_filename) + full_lset_name_filename = os.path.join(get_lora_dir(state["model_filename"]), lset_name_filename) with open(full_lset_name_filename, "w", encoding="utf-8") as writer: writer.write(json.dumps(lset, indent=4)) @@ -2125,7 +2104,7 @@ def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_ def delete_lset(state, lset_name): loras_presets = state["loras_presets"] - lset_name_filename = os.path.join( get_lora_dir(state["image2video"]), sanitize_file_name(lset_name) + ".lset" ) + lset_name_filename = os.path.join( get_lora_dir(state["model_filename"]), sanitize_file_name(lset_name) + ".lset" ) if len(lset_name) > 0 and lset_name != get_new_preset_msg(True) and lset_name != get_new_preset_msg(False): if not os.path.isfile(lset_name_filename): raise gr.Error(f"Preset '{lset_name}' not found ") @@ -2146,8 +2125,8 @@ def delete_lset(state, lset_name): def refresh_lora_list(state, lset_name, loras_choices): loras_names = state["loras_names"] prev_lora_names_selected = [ loras_names[int(i)] for i in loras_choices] - image2video= state["image2video"] - loras, loras_names, loras_presets, _, _, _, _ = setup_loras(image2video, None, get_lora_dir(image2video), lora_preselected_preset, None) + model_filename= state["model_filename"] + loras, loras_names, loras_presets, _, _, _, _ = setup_loras(model_filename, None, get_lora_dir(model_filename), lora_preselected_preset, None) state["loras"] = loras state["loras_names"] = loras_names state["loras_presets"] = loras_presets @@ -2187,7 +2166,7 @@ def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_m gr.Info("Please choose a preset in the list or create one") else: loras = state["loras"] - loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(state["image2video"], lset_name, loras) + loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(state["model_filename"], lset_name, loras) if len(error) > 0: gr.Info(error) else: @@ -2344,71 +2323,110 @@ def switch_advanced(state, new_advanced, lset_name): return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name) -def get_settings_dict(state, i2v, image_metadata, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, - loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, 
spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step): - +def prepare_inputs_dict(target, inputs ): + + state = inputs.pop("state") loras = state["loras"] - activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ] + if "loras_choices" in inputs: + loras_choices = inputs.pop("loras_choices") + inputs.pop("model_filename", None) + activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ] + inputs["activated_loras"] = activated_loras - ui_settings = { - "prompts": prompt, - "resolution": resolution, - "video_length": video_length, - "num_inference_steps": num_inference_steps, - "seed": seed, - "repeat_generation": repeat_generation, - "multi_images_gen_type": multi_images_gen_type, - "guidance_scale": guidance_scale, - "flow_shift": flow_shift, - "negative_prompt": negative_prompt, - "activated_loras": activated_loras, - "loras_multipliers": loras_mult_choices, - "tea_cache": tea_cache_setting, - "tea_cache_start_step_perc": tea_cache_start_step_perc, - "temporal_upsampling" : temporal_upsampling, - "spatial_upsampling" : spatial_upsampling, - "RIFLEx_setting": RIFLEx_setting, - "slg_switch": slg_switch, - "slg_layers": slg_layers, - "slg_start_perc": slg_start_perc, - "slg_end_perc": slg_end_perc, - "cfg_star_switch": cfg_star_switch, - "cfg_zero_step": cfg_zero_step - } + if target == "state": + return inputs - if i2v: - ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - image2video" - ui_settings["image_prompt_type"] = image_prompt_type - else: - if "Vace" in transformer_filename_t2v or not image_metadata: - ui_settings["image_prompt_type"] = image_prompt_type - ui_settings["max_frames"] = max_frames - ui_settings["remove_background_image_ref"] = remove_background_image_ref - ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video" + unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_mask"] + for k in unsaved_params: + inputs.pop(k) - return ui_settings + model_filename = state["model_filename"] + inputs["type"] = "Wan2.1GP by DeepBeepMeep - " + get_model_name(model_filename) -def save_settings(state, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, - loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step): + if target == "settings": + return inputs + + if not any(k in model_filename for k in ["image2video", "Fun_InP"]): + inputs.pop("image_prompt_type") - if state.get("validate_success",0) != 1: - return - image2video = state["image2video"] - ui_defaults = get_settings_dict(state, image2video, False, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, - loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step) + if not "Vace" in model_filename: + unsaved_params = ["video_prompt_type", "max_frames", "remove_background_image_ref"] + for k in unsaved_params: + inputs.pop(k) - defaults_filename = 
get_settings_file_name(image2video) + if target == "metadata": + inputs = {k: v for k,v in inputs.items() if v != None } - with open(defaults_filename, "w", encoding="utf-8") as f: - json.dump(ui_defaults, f, indent=4) + return inputs - gr.Info("New Default Settings saved") +def get_function_arguments(func, locals): + args_names = list(inspect.signature(func).parameters) + kwargs = typing.OrderedDict() + for k in args_names: + kwargs[k] = locals[k] + return kwargs + + +def save_inputs( + target, + prompt, + negative_prompt, + resolution, + video_length, + seed, + num_inference_steps, + guidance_scale, + flow_shift, + embedded_guidance_scale, + repeat_generation, + multi_images_gen_type, + tea_cache_setting, + tea_cache_start_step_perc, + loras_choices, + loras_multipliers, + image_prompt_type, + image_start, + image_end, + video_prompt_type, + image_refs, + video_guide, + video_mask, + max_frames, + remove_background_image_ref, + temporal_upsampling, + spatial_upsampling, + RIFLEx_setting, + slg_switch, + slg_layers, + slg_start_perc, + slg_end_perc, + cfg_star_switch, + cfg_zero_step, + state, +): + + + # if state.get("validate_success",0) != 1: + # return + model_filename = state["model_filename"] + inputs = get_function_arguments(save_inputs, locals()) + inputs.pop("target") + cleaned_inputs = prepare_inputs_dict(target, inputs) + if target == "settings": + defaults_filename = get_settings_file_name(model_filename) + + with open(defaults_filename, "w", encoding="utf-8") as f: + json.dump(cleaned_inputs, f, indent=4) + + gr.Info("New Default Settings saved") + elif target == "state": + state[get_model_type(model_filename)] = cleaned_inputs def download_loras(): from huggingface_hub import snapshot_download yield gr.Row(visible=True), "Please wait while the Loras are being downloaded", *[gr.Column(visible=False)] * 2 - lora_dir = get_lora_dir(True) + lora_dir = get_lora_dir(get_model_filename("i2v"), quantizeTransformer) log_path = os.path.join(lora_dir, "log.txt") if not os.path.isfile(log_path): import shutil @@ -2433,42 +2451,130 @@ def download_loras(): writer.write(f"Loras downloaded on the {dt} at {time.time()} on the {time.time()}") return -def refresh_i2v_image_prompt_type_radio(state, image_prompt_type_radio): +def refresh_image_prompt_type(state, image_prompt_type): if args.multiple_images: - return gr.Gallery(visible = (image_prompt_type_radio == 1) ) + return gr.Gallery(visible = "S" in image_prompt_type ), gr.Gallery(visible = "E" in image_prompt_type ) else: - return gr.Image(visible = (image_prompt_type_radio == 1) ) + return gr.Image(visible = "S" in image_prompt_type ), gr.Image(visible = "E" in image_prompt_type ) -def refresh_t2v_image_prompt_type_radio(state, image_prompt_type_radio): - vace_model = "Vace" in state["image_input_type_model"] and not state["image2video"] - return gr.Column(visible= vace_model), gr.Radio(value= image_prompt_type_radio), gr.Gallery(visible = "I" in image_prompt_type_radio), gr.Video(visible= "V" in image_prompt_type_radio),gr.Video(visible= "M" in image_prompt_type_radio ), gr.Text(visible= "V" in image_prompt_type_radio) , gr.Checkbox(visible= "I" in image_prompt_type_radio) +def refresh_video_prompt_type(state, video_prompt_type): + return gr.Gallery(visible = "I" in video_prompt_type), gr.Video(visible= "V" in video_prompt_type),gr.Video(visible= "M" in video_prompt_type ), gr.Text(visible= "V" in video_prompt_type) , gr.Checkbox(visible= "I" in video_prompt_type) -def check_refresh_input_type(state): - if not state["image2video"]: - 
model_file_name = state["image_input_type_model"] - model_file_needed= model_needed(False) - if model_file_name != model_file_needed: - state["image_input_type_model"] = model_file_needed - return gr.Text(value= str(time.time())) + +def handle_celll_selection(state, evt: gr.SelectData): + gen = get_gen_info(state) + queue = gen.get("queue", []) + + if evt.index is None: + return gr.update(), gr.update(), gr.update(visible=False) + row_index, col_index = evt.index + cell_value = None + if col_index in [6, 7, 8]: + if col_index == 6: cell_value = "↑" + elif col_index == 7: cell_value = "↓" + elif col_index == 8: cell_value = "✖" + if col_index == 6: + new_df_data = move_up(queue, [row_index]) + return new_df_data, gr.update(), gr.update(visible=False) + elif col_index == 7: + new_df_data = move_down(queue, [row_index]) + return new_df_data, gr.update(), gr.update(visible=False) + elif col_index == 8: + new_df_data = remove_task(queue, [row_index]) + gen["prompts_max"] = gen.get("prompts_max",0) - 1 + update_status(state) + return new_df_data, gr.update(), gr.update(visible=False) + start_img_col_idx = 4 + end_img_col_idx = 5 + image_data_to_show = None + if col_index == start_img_col_idx: + with lock: + row_index += 1 + if row_index < len(queue): + image_data_to_show = queue[row_index].get('start_image_data') + elif col_index == end_img_col_idx: + with lock: + row_index += 1 + if row_index < len(queue): + image_data_to_show = queue[row_index].get('end_image_data') + + if image_data_to_show: + return gr.update(), gr.update(value=image_data_to_show[0]), gr.update(visible=True) + else: + return gr.update(), gr.update(), gr.update(visible=False) + + +def change_model(state, model_choice): + model_filename = "" + for filename in model_list: + if get_model_type(filename) == model_choice: + model_filename = filename + break + if len(model_filename) == 0: + return + + state["model_filename"] = model_filename + header = generate_header(compile=compile, attention_mode=attention_mode) + return header + +def fill_inputs(state): + model_filename = state["model_filename"] + prefix = get_model_type(model_filename) + ui_defaults = state.get(prefix, None) + if ui_defaults == None: + ui_defaults = get_default_settings(model_filename) + + return generate_video_tab(update_form = True, state_dict = state, ui_defaults = ui_defaults) + +def preload_model(state): + global reload_needed, wan_model, offloadobj + if reload_model == 1: + model_filename = state["model_filename"] + if state["model_filename"] != transformer_filename: + wan_model = None + if offloadobj is not None: + offloadobj.release() + offloadobj = None + gc.collect() + yield f"Loading model {get_model_name(model_filename)}..." 
+ wan_model, offloadobj, _ = load_models(model_filename) + yield f"Model loaded" + reload_needed= False + return return gr.Text() -def generate_video_tab(image2video=False): - filename = transformer_filename_i2v if image2video else transformer_filename_t2v - ui_defaults= get_default_settings(filename, image2video) +def unload_model_if_needed(state): + global reload_needed, wan_model, offloadobj + if reload_model == 3: + if wan_model != None: + wan_model = None + if offloadobj is not None: + offloadobj.release() + offloadobj = None + gc.collect() + reload_needed= True - state_dict = {} - state_dict["advanced"] = advanced - state_dict["loras_model"] = filename - state_dict["image_input_type_model"] = filename - state_dict["image2video"] = image2video - gen = dict() - gen["queue"] = [] - state_dict["gen"] = gen +def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_choice = None, header = None): + global inputs_names #, advanced - preset_to_load = lora_preselected_preset if use_image2video == image2video else "" + if update_form: + model_filename = state_dict["model_filename"] + advanced_ui = state_dict["advanced"] + else: + model_filename = transformer_filename + advanced_ui = advanced + ui_defaults= get_default_settings(model_filename) + state_dict = {} + state_dict["model_filename"] = model_filename + state_dict["advanced"] = advanced_ui + gen = dict() + gen["queue"] = [] + state_dict["gen"] = gen - loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset = setup_loras(image2video, None, get_lora_dir(image2video), preset_to_load, None) + preset_to_load = lora_preselected_preset if lora_preset_model == model_filename else "" + + loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset = setup_loras(model_filename, None, get_lora_dir(model_filename), preset_to_load, None) state_dict["loras"] = loras state_dict["loras_presets"] = loras_presets @@ -2479,18 +2585,19 @@ def generate_video_tab(image2video=False): launch_loras = [] launch_multis_str = "" - if len(default_lora_preset) > 0 and image2video == use_image2video: + if update_form: + pass + if len(default_lora_preset) > 0 and lora_preset_model == model_filename: launch_preset = default_lora_preset launch_prompt = default_lora_preset_prompt launch_loras = default_loras_choices launch_multis_str = default_loras_multis_str if len(launch_prompt) == 0: - launch_prompt = ui_defaults["prompts"] + launch_prompt = ui_defaults.get("prompt","") if len(launch_loras) == 0: - activated_loras = ui_defaults["activated_loras"] - launch_multis_str = ui_defaults["loras_multipliers"] - + launch_multis_str = ui_defaults.get("loras_multipliers","") + activated_loras = ui_defaults.get("activated_loras",[]) if len(activated_loras) > 0: lora_filenames = [os.path.basename(lora_path) for lora_path in loras] activated_indices = [] @@ -2502,24 +2609,20 @@ def generate_video_tab(image2video=False): print(f"Warning: Lora file {lora_file} from config not found in loras directory") launch_loras = activated_indices - - header = gr.Markdown(generate_header(model_filename, compile, attention_mode)) with gr.Row(): with gr.Column(): with gr.Column(visible=False, elem_id="image-modal-container") as modal_container: with gr.Row(elem_id="image-modal-close-button-row"): close_modal_button = gr.Button("❌", size="sm") modal_image_display = gr.Image(label="Full Resolution Image", interactive=False, 
show_label=False) - progress_update_trigger = gr.Textbox(value="0", visible=False, label="_progress_trigger") - gallery_update_trigger = gr.Textbox(value="0", visible=False, label="_gallery_trigger") - with gr.Row(visible= len(loras)>0) as presets_column: - lset_choices = [ (preset, preset) for preset in loras_presets ] + [(get_new_preset_msg(advanced), "")] + with gr.Row(visible= True): #len(loras)>0) as presets_column: + lset_choices = [ (preset, preset) for preset in loras_presets ] + [(get_new_preset_msg(advanced_ui), "")] with gr.Column(scale=6): lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=True, choices= lset_choices, value=launch_preset) with gr.Column(scale=1): with gr.Row(height=17): apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1) - refresh_lora_btn = gr.Button("Refresh", size="sm", min_width= 1, visible=advanced or not only_allow_edit_in_advanced) + refresh_lora_btn = gr.Button("Refresh", size="sm", min_width= 1, visible=advanced_ui or not only_allow_edit_in_advanced) save_lset_prompt_drop= gr.Dropdown( choices=[ ("Save Prompt Comments Only", 0), @@ -2529,62 +2632,50 @@ def generate_video_tab(image2video=False): with gr.Row(height=17, visible=False) as refresh2_row: refresh_lora_btn2 = gr.Button("Refresh", size="sm", min_width= 1) - with gr.Row(height=17, visible=advanced or not only_allow_edit_in_advanced) as preset_buttons_rows: + with gr.Row(height=17, visible=advanced_ui or not only_allow_edit_in_advanced) as preset_buttons_rows: confirm_save_lset_btn = gr.Button("Go Ahead Save it !", size="sm", min_width= 1, visible=False) confirm_delete_lset_btn = gr.Button("Go Ahead Delete it !", size="sm", min_width= 1, visible=False) save_lset_btn = gr.Button("Save", size="sm", min_width= 1) delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1) cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False) - state = gr.State(state_dict) - vace_model = "Vace" in filename and not image2video + if not update_form: + state = gr.State(state_dict) trigger_refresh_input_type = gr.Text(interactive= False, visible= False) - with gr.Column(visible= image2video or vace_model) as image_prompt_column: - if image2video: - image_source3 = gr.Video(label= "Placeholder", visible= image2video and False) + with gr.Column(visible= "image2video" in model_filename or "Fun_InP" in model_filename ) as image_prompt_column: + image_prompt_type_value= ui_defaults.get("image_prompt_type","S") + image_prompt_type = gr.Radio( [("Use only a Start Image", "S"),("Use both a Start and an End Image", "SE")], value =image_prompt_type_value, label="Location", show_label= False, scale= 3) - image_prompt_type= ui_defaults.get("image_prompt_type",0) - image_prompt_type_radio = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =image_prompt_type, label="Location", show_label= False, scale= 3) - - if args.multiple_images: - image_source1 = gr.Gallery( - label="Images as starting points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True) - else: - image_source1 = gr.Image(label= "Image as a starting point for a new video", type ="pil") - - if args.multiple_images: - image_source2 = gr.Gallery( - label="Images as ending points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, 
visible=image_prompt_type==1) - else: - image_source2 = gr.Image(label= "Last Image for a new video", type ="pil", visible=image_prompt_type==1) - - - image_prompt_type_radio.change(fn=refresh_i2v_image_prompt_type_radio, inputs=[state, image_prompt_type_radio], outputs=[image_source2]) - max_frames = gr.Slider(1, 100,step=1, visible = False) - remove_background_image_ref = gr.Text(visible = False) + if args.multiple_images: + image_start = gr.Gallery( + label="Images as starting points for new videos", type ="pil", #file_types= "image", + columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value) else: - image_prompt_type= ui_defaults.get("image_prompt_type","I") - if not isinstance(image_prompt_type, str): - image_prompt_type ="I" - image_prompt_type_radio = gr.Radio( [("Use Images Ref", "I"),("a Video", "V"), ("Images + a Video", "IV"), ("Video + Video Mask", "VM"), ("Images + Video + Mask", "IVM")], value =image_prompt_type, label="Location", show_label= False, scale= 3, visible = vace_model) - image_source1 = gr.Gallery( - label="Reference Images of Faces and / or Object to be found in the Video", type ="pil", - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in image_prompt_type ) + image_start = gr.Image(label= "Image as a starting point for a new video", type ="pil",value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) - image_source2 = gr.Video(label= "Reference Video", visible= "V" in image_prompt_type ) - with gr.Row(): - max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Reference Video to use in Video (0 for as many as possible)", visible= "V" in image_prompt_type, scale = 2 ) - remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Images Ref. 
Background", visible= "I" in image_prompt_type, scale =1 ) + if args.multiple_images: + image_end = gr.Gallery( + label="Images as ending points for new videos", type ="pil", #file_types= "image", + columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) + else: + image_end = gr.Image(label= "Last Image for a new video", type ="pil", visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) - image_source3 = gr.Video(label= "Video Mask (white pixels = Mask)", visible= "M" in image_prompt_type ) + with gr.Column(visible= "Vace" in model_filename ) as video_prompt_column: + video_prompt_type_value= ui_defaults.get("video_prompt_type","I") + video_prompt_type = gr.Radio( [("Use Images Ref", "I"),("a Video", "V"), ("Images + a Video", "IV"), ("Video + Video Mask", "VM"), ("Images + Video + Mask", "IVM")], value =video_prompt_type_value, label="Location", show_label= False, scale= 3) + image_refs = gr.Gallery( + label="Reference Images of Faces and / or Object to be found in the Video", type ="pil", + columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value, value= ui_defaults.get("image_refs", None) ) - - gr.on(triggers=[image_prompt_type_radio.change, trigger_refresh_input_type.change], fn=refresh_t2v_image_prompt_type_radio, inputs=[state, image_prompt_type_radio], outputs=[image_prompt_column, image_prompt_type_radio, image_source1, image_source2, image_source3, max_frames, remove_background_image_ref]) + video_guide = gr.Video(label= "Reference Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None) ) + with gr.Row(): + max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Ref. Video (0 = as many as possible)", visible= "V" in video_prompt_type_value, scale = 2 ) + remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Images Ref. 
Background", visible= "I" in video_prompt_type_value, scale =1 ) + + video_mask = gr.Video(label= "Video Mask (white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None) ) - advanced_prompt = advanced + advanced_prompt = advanced_ui prompt_vars=[] if advanced_prompt: @@ -2615,14 +2706,14 @@ def generate_video_tab(image2video=False): wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False) wizard_variables_var = gr.Text(wizard_variables, visible = False) with gr.Row(): - if image2video: + if "image2video" in model_filename or "Fun_InP" in model_filename: resolution = gr.Dropdown( choices=[ # 720p ("720p", "1280x720"), ("480p", "832x480"), ], - value=ui_defaults["resolution"], + value=ui_defaults.get("resolution","480p"), label="Resolution (video will have the same height / width ratio than the original image)" ) else: @@ -2642,33 +2733,33 @@ def generate_video_tab(image2video=False): # ("624x832 (3:4, 540p)", "624x832"), # ("720x720 (1:1, 540p)", "720x720"), ], - value=ui_defaults["resolution"], + value=ui_defaults.get("resolution","832x480"), label="Resolution" ) with gr.Row(): with gr.Column(): - video_length = gr.Slider(5, 193, value=ui_defaults["video_length"], step=4, label="Number of frames (16 = 1s)") + video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)") with gr.Column(): - num_inference_steps = gr.Slider(1, 100, value=ui_defaults["num_inference_steps"], step=1, label="Number of Inference Steps") - show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced) - with gr.Row(visible=advanced) as advanced_row: + num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps") + show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced_ui) + with gr.Row(visible=advanced_ui) as advanced_row: with gr.Column(): seed = gr.Slider(-1, 999999999, value=ui_defaults["seed"], step=1, label="Seed (-1 for random)") with gr.Row(): - repeat_generation = gr.Slider(1, 25.0, value=ui_defaults["repeat_generation"], step=1, label="Default Number of Generated Videos per Prompt") - multi_images_gen_type = gr.Dropdown( value=ui_defaults["multi_images_gen_type"], + repeat_generation = gr.Slider(1, 25.0, value=ui_defaults.get("repeat_generation",1), step=1, label="Default Number of Generated Videos per Prompt") + multi_images_gen_type = gr.Dropdown( value=ui_defaults.get("multi_images_gen_type",0), choices=[ ("Generate every combination of images and texts", 0), ("Match images and text prompts", 1), ], visible= args.multiple_images, label= "Multiple Images as Texts Prompts" ) with gr.Row(): - guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults["guidance_scale"], step=0.5, label="Guidance Scale", visible=True) + guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance Scale", visible=True) embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False) - flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults["flow_shift"], step=0.1, label="Shift Scale") + flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale") with gr.Row(): - negative_prompt = gr.Textbox(label="Negative Prompt", value=ui_defaults["negative_prompt"]) - with gr.Column(visible = len(loras)>0) as loras_column: + negative_prompt = gr.Textbox(label="Negative Prompt", 
value=ui_defaults.get("negative_prompt", "") ) + with gr.Column(visible = True): #as loras_column: gr.Markdown("Loras can be used to create special effects on the video by mentioning a trigger word in the Prompt. You can save Loras combinations in presets.") loras_choices = gr.Dropdown( choices=[ @@ -2678,7 +2769,7 @@ def generate_video_tab(image2video=False): multiselect= True, label="Activated Loras" ) - loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, line that starts with # are ignored", value=launch_multis_str) + loras_multipliers = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, line that starts with # are ignored", value=launch_multis_str) with gr.Row(): gr.Markdown("Tea Cache accelerates by skipping intelligently some steps, the more steps are skipped the lower the quality of the video (Tea Cache consumes also VRAM)") with gr.Row(): @@ -2691,16 +2782,16 @@ def generate_video_tab(image2video=False): ("around x2.25 speed up", 2.25), ("around x2.5 speed up", 2.5), ], - value=float(ui_defaults["tea_cache"]), + value=float(ui_defaults.get("tea_cache_setting",0)), visible=True, label="Tea Cache Global Acceleration" ) - tea_cache_start_step_perc = gr.Slider(0, 100, value=ui_defaults["tea_cache_start_step_perc"], step=1, label="Tea Cache starting moment in % of generation") + tea_cache_start_step_perc = gr.Slider(0, 100, value=ui_defaults.get("tea_cache_start_step_perc",0), step=1, label="Tea Cache starting moment in % of generation") with gr.Row(): gr.Markdown("Upsampling - postprocessing that may improve fluidity and the size of the video") with gr.Row(): - temporal_upsampling_choice = gr.Dropdown( + temporal_upsampling = gr.Dropdown( choices=[ ("Disabled", ""), ("Rife x2 (32 frames/s)", "rife2"), @@ -2711,7 +2802,7 @@ def generate_video_tab(image2video=False): scale = 1, label="Temporal Upsampling" ) - spatial_upsampling_choice = gr.Dropdown( + spatial_upsampling = gr.Dropdown( choices=[ ("Disabled", ""), ("Lanczos x1.5", "lanczos1.5"), @@ -2730,7 +2821,7 @@ def generate_video_tab(image2video=False): ("Always ON", 1), ("Always OFF", 2), ], - value=ui_defaults["RIFLEx_setting"], + value=ui_defaults.get("RIFLEx_setting",0), label="RIFLEx positional embedding to generate long video" ) with gr.Row(): @@ -2741,7 +2832,7 @@ def generate_video_tab(image2video=False): ("OFF", 0), ("ON", 1), ], - value=ui_defaults["slg_switch"], + value=ui_defaults.get("slg_switch",0), visible=True, scale = 1, label="Skip Layer guidance" @@ -2750,14 +2841,14 @@ def generate_video_tab(image2video=False): choices=[ (str(i), i ) for i in range(40) ], - value=ui_defaults["slg_layers"], + value=ui_defaults.get("slg_layers", ["9"]), multiselect= True, label="Skip Layers", scale= 3 ) with gr.Row(): - slg_start_perc = gr.Slider(0, 100, value=ui_defaults["slg_start_perc"], step=1, label="Denoising Steps % start") - slg_end_perc = gr.Slider(0, 100, value=ui_defaults["slg_end_perc"], step=1, label="Denoising Steps % end") + slg_start_perc = gr.Slider(0, 100, value=ui_defaults.get("slg_start_perc",10), step=1, label="Denoising Steps % start") + slg_end_perc = gr.Slider(0, 100, value=ui_defaults.get("slg_end_perc",90), step=1, label="Denoising Steps % end") with gr.Row(): gr.Markdown("Experimental: Classifier-Free Guidance Zero Star, better adherence to Text Prompt") @@ -2777,239 +2868,142 @@ def generate_video_tab(image2video=False): with gr.Row(): save_settings_btn = gr.Button("Set Settings as 
Default", visible = not args.lock_config) + + if not update_form: + with gr.Column(): + gen_status = gr.Text(interactive= False, label = "Status") + output = gr.Gallery( label="Generated videos", show_label=False, elem_id="gallery" , columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False) + generate_btn = gr.Button("Generate") + add_to_queue_btn = gr.Button("Add New Prompt To Queue", visible = False) + + with gr.Column(visible= False) as current_gen_column: + with gr.Row(): + gen_info = gr.HTML(visible=False, min_height=1) + with gr.Row(): + onemore_btn = gr.Button("One More Sample Please !") + abort_btn = gr.Button("Abort") + + queue_df = gr.DataFrame( + headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""], + datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"], + column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"], + interactive=False, + col_count=(9, "fixed"), + wrap=True, + value=[], + line_breaks= True, + visible= False, + elem_id="queue_df" + ) + + extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, + prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row] # show_advanced presets_column, + if update_form: + locals_dict = locals() + gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs + return gen_inputs + else: + target_state = gr.Text(value = "state", interactive= False, visible= False) + target_settings = gr.Text(value = "settings", interactive= False, visible= False) + + image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end]) + video_prompt_type.change(fn=refresh_video_prompt_type, inputs=[state, video_prompt_type], outputs=[image_refs, video_guide, video_mask, max_frames, remove_background_image_ref]) show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then( fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]) - with gr.Column(): - gen_status = gr.Text(interactive= False) - full_sync = gr.Text(interactive= False, visible= False) - light_sync = gr.Text(interactive= False, visible= False) - - gen_progress_html = gr.HTML( - label="Status", - value="Idle", - elem_id="generation_progress_bar_container", visible= False + queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container]) + save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop]) + confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then( + save_lset, inputs=[state, lset_name, loras_choices, loras_multipliers, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop]) + 
delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ]) + confirm_delete_lset_btn.click(delete_lset, inputs=[state, lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ]) + cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ]) + apply_lset_btn.click(apply_lset, inputs=[state, wizard_prompt_activated_var, lset_name,loras_choices, loras_multipliers, prompt], outputs=[wizard_prompt_activated_var, loras_choices, loras_multipliers, prompt]).then( + fn = fill_wizard_prompt, inputs = [state, wizard_prompt_activated_var, prompt, wizard_prompt], outputs = [ wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars] ) - output = gr.Gallery( - label="Generated videos", show_label=False, elem_id="gallery" - , columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False) - generate_btn = gr.Button("Generate") - add_to_queue_btn = gr.Button("Add New Prompt To Queue", visible = False) - - with gr.Column(visible= False) as current_gen_column: - with gr.Row(): - gen_info = gr.HTML(visible=False, min_height=1) - with gr.Row(): - onemore_btn = gr.Button("One More Sample Please !") - abort_btn = gr.Button("Abort") - - queue_df = gr.DataFrame( - headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""], - datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"], - column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"], - interactive=False, - col_count=(9, "fixed"), - wrap=True, - value=[], - line_breaks= True, - visible= False, - # every=1, - elem_id="queue_df" - ) - # queue_df = gr.HTML("", - # visible= False, - # elem_id="queue_df" - # ) - - def handle_selection(state, evt: gr.SelectData): - gen = get_gen_info(state) - queue = gen.get("queue", []) - - if evt.index is None: - return gr.update(), gr.update(), gr.update(visible=False) - row_index, col_index = evt.index - cell_value = None - if col_index in [6, 7, 8]: - if col_index == 6: cell_value = "↑" - elif col_index == 7: cell_value = "↓" - elif col_index == 8: cell_value = "✖" - if col_index == 6: - new_df_data = move_up(queue, [row_index]) - return new_df_data, gr.update(), gr.update(visible=False) - elif col_index == 7: - new_df_data = move_down(queue, [row_index]) - return new_df_data, gr.update(), gr.update(visible=False) - elif col_index == 8: - new_df_data = remove_task(queue, [row_index]) - gen["prompts_max"] = gen.get("prompts_max",0) - 1 - update_status(state) - return new_df_data, gr.update(), gr.update(visible=False) - start_img_col_idx = 4 - end_img_col_idx = 5 - image_data_to_show = None - if col_index == start_img_col_idx: - with lock: - if row_index < len(queue): - image_data_to_show = queue[row_index].get('start_image_data') - elif col_index == end_img_col_idx: - with lock: - if row_index < len(queue): - image_data_to_show = queue[row_index].get('end_image_data') - - if image_data_to_show: - return gr.update(), gr.update(value=image_data_to_show), gr.update(visible=True) - else: - return gr.update(), gr.update(), gr.update(visible=False) - selected_indices = gr.State([]) - queue_df.select( - 
fn=handle_selection, - inputs=state, - outputs=[queue_df, modal_image_display, modal_container], - ) - # gallery_update_trigger.change( - # fn=refresh_gallery_on_trigger, - # inputs=[state], - # outputs=[output] - # ) - # queue_df.change( - # fn=refresh_gallery, - # inputs=[state], - # outputs=[gallery_update_trigger] - # ).then( - # fn=refresh_progress, - # inputs=None, - # outputs=[progress_update_trigger] - # ) - progress_update_trigger.change( - fn=update_generation_status, - inputs=[progress_update_trigger], - outputs=[gen_progress_html], - show_progress="hidden" - ) - save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then( - save_settings, inputs = [state, prompt, image_prompt_type_radio, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, - loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling_choice, spatial_upsampling_choice, RIFLEx_setting, slg_switch, slg_layers, - slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = []) - save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop]) - confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then( - save_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop]) - delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ]) - confirm_delete_lset_btn.click(delete_lset, inputs=[state, lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ]) - cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ]) - apply_lset_btn.click(apply_lset, inputs=[state, wizard_prompt_activated_var, lset_name,loras_choices, loras_mult_choices, prompt], outputs=[wizard_prompt_activated_var, loras_choices, loras_mult_choices, prompt]).then( - fn = fill_wizard_prompt, inputs = [state, wizard_prompt_activated_var, prompt, wizard_prompt], outputs = [ wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars] - ) - refresh_lora_btn.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) - refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) - output.select(select_video, state, None ) - - + refresh_lora_btn.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) + refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) + 
output.select(select_video, state, None ) + + gen_status.change(refresh_gallery, + inputs = [state, gen_status], + outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]) - gen_status.change(refresh_gallery, - inputs = [state, gen_status], - outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]) - - full_sync.change(fn= check_refresh_input_type, - inputs= [state], - outputs= [trigger_refresh_input_type] - ).then(fn=refresh_gallery, - inputs = [state, gen_status], - outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn] - ).then(fn=wait_tasks_done, - inputs= [state], - outputs =[gen_status], - ).then(finalize_generation, - inputs= [state], - outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info] - ) - light_sync.change(fn= check_refresh_input_type, - inputs= [state], - outputs= [trigger_refresh_input_type] - ).then(fn=refresh_gallery, - inputs = [state, gen_status], - outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn] - ) + abort_btn.click(abort_generation, [state], [gen_status, abort_btn] ) #.then(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info, queue_df] ) + onemore_btn.click(fn=one_more_sample,inputs=[state], outputs= [state]) - abort_btn.click(abort_generation, [state], [gen_status, abort_btn] ) #.then(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info, queue_df] ) - onemore_btn.click(fn=one_more_sample,inputs=[state], outputs= [state]) + inputs_names= list(inspect.signature(save_inputs).parameters)[1:-1] + locals_dict = locals() + gen_inputs = [locals_dict[k] for k in inputs_names] + [state] + save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then( + save_inputs, inputs =[target_settings] + gen_inputs, outputs = []) - gen_inputs=[ - prompt, - negative_prompt, - resolution, - video_length, - seed, - num_inference_steps, - guidance_scale, - flow_shift, - embedded_guidance_scale, - repeat_generation, - multi_images_gen_type, - tea_cache_setting, - tea_cache_start_step_perc, - loras_choices, - loras_mult_choices, - image_prompt_type_radio, - image_source1, - image_source2, - image_source3, - max_frames, - remove_background_image_ref, - temporal_upsampling_choice, - spatial_upsampling_choice, - RIFLEx_setting, - slg_switch, - slg_layers, - slg_start_perc, - slg_end_perc, - cfg_star_switch, - cfg_zero_step, - state, - gr.State(image2video) - ] + model_choice.change(fn=validate_wizard_prompt, + inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , + outputs= [prompt] + ).then(fn=save_inputs, + inputs =[target_state] + gen_inputs, + outputs= None + ).then(fn= change_model, + inputs=[state, model_choice], + outputs= [header] + ).then(fn= fill_inputs, + inputs=[state], + outputs=gen_inputs + extra_inputs + ).then(fn= preload_model, + inputs=[state], + outputs=[gen_status]) - generate_btn.click(fn=validate_wizard_prompt, - inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] - ).then(fn=process_prompt_and_add_tasks, - inputs = gen_inputs, - outputs= queue_df - ).then(fn=prepare_generate_video, - inputs= [state], - outputs= [generate_btn, add_to_queue_btn, current_gen_column], - 
).then(fn=process_tasks, - inputs= [state], - outputs= [gen_status], - ).then(finalize_generation, - inputs= [state], - outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info] - ) + generate_btn.click(fn=validate_wizard_prompt, + inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , + outputs= [prompt] + ).then(fn=save_inputs, + inputs =[target_state] + gen_inputs, + outputs= None + ).then(fn=process_prompt_and_add_tasks, + inputs = [state, model_choice], + outputs= queue_df + ).then(fn=prepare_generate_video, + inputs= [state], + outputs= [generate_btn, add_to_queue_btn, current_gen_column], + ).then(fn=process_tasks, + inputs= [state], + outputs= [gen_status], + ).then(finalize_generation, + inputs= [state], + outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info] + ).then(unload_model_if_needed, + inputs= [state], + outputs= [] + ) - add_to_queue_btn.click(fn=validate_wizard_prompt, - inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] - ).then( - fn=process_prompt_and_add_tasks, - inputs = gen_inputs, - outputs=queue_df - ).then( - fn=update_status, - inputs = [state], - ) + add_to_queue_btn.click(fn=validate_wizard_prompt, + inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , + outputs= [prompt] + ).then(fn=save_inputs, + inputs =[target_state] + gen_inputs, + outputs= None + ).then(fn=process_prompt_and_add_tasks, + inputs = [state, model_choice], + outputs=queue_df + ).then( + fn=update_status, + inputs = [state], + ) + close_modal_button.click( + lambda: gr.update(visible=False), + inputs=[], + outputs=[modal_container] + ) - close_modal_button.click( - lambda: gr.update(visible=False), - inputs=[], - outputs=[modal_container] - ) - return loras_column, loras_choices, presets_column, lset_name, header, light_sync, full_sync, state + return loras_choices, lset_name, state -def generate_download_tab(presets_column, loras_column, lset_name,loras_choices, state): +def generate_download_tab(lset_name,loras_choices, state): with gr.Row(): with gr.Row(scale =2): - gr.Markdown("Wan2GP's Lora Festival ! Press the following button to download i2v Remade Loras collection (and bonuses Loras).") + gr.Markdown("WanGP's Lora Festival ! Press the following button to download i2v Remade_AI Loras collection (and bonuses Loras).") with gr.Row(scale =1): download_loras_btn = gr.Button("---> Let the Lora's Festival Start !", scale =1) with gr.Row(scale =1): @@ -3017,7 +3011,7 @@ def generate_download_tab(presets_column, loras_column, lset_name,loras_choices, with gr.Row() as download_status_row: download_status = gr.Markdown() - download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status, presets_column, loras_column]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) + download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) def generate_configuration_tab(): @@ -3025,37 +3019,50 @@ def generate_configuration_tab(): state = gr.State(state_dict) gr.Markdown("Please click Apply Changes at the bottom so that the changes are effective. 
Some choices below may be locked if the app has been launched by specifying a config preset.") with gr.Column(): - index = transformer_choices_t2v.index(transformer_filename_t2v) + index = transformer_choices.index(transformer_filename) index = 0 if index ==0 else index - transformer_t2v_choice = gr.Dropdown( + + model_list = [] + for model_type in model_types: + choice = get_model_filename(model_type, transformer_quantization) + model_list.append(choice) + dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list] + transformer_type_choice = gr.Dropdown( + choices= dropdown_choices, + value= get_model_type(transformer_filename), + label= "Default Wan Transformer Model", + scale= 2 + ) + + # transformer_choice = gr.Dropdown( + # choices=[ + # ("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0), + # ("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1), + # ("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2), + # ("WAN 2.1 VACE 1.3B Text to Video / Control Net - text generation driven by reference images or videos", 3), + # ("WAN 2.1 - 480p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 4), + # ("WAN 2.1 - 480p 14B Image to Video quantized to 8 bits (recommended) - the default engine but quantized", 5), + # ("WAN 2.1 - 720p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 6), + # ("WAN 2.1 - 720p 14B Image to Video quantized to 8 bits - the default engine but quantized", 7), + # ("WAN 2.1 - Fun InP 1.3B 16 bits - the small model for fast generations with low VRAM requirements", 8), + # ("WAN 2.1 - Fun InP 14B 16 bits - Fun InP version in its original glory, offers a slightly better image quality but slower and requires more RAM", 9), + # ("WAN 2.1 - Fun InP 14B quantized to 8 bits - quantized Fun InP version", 10), + # ], + # value= index, + # label="Transformer model for Image to Video", + # interactive= not lock_ui_transformer, + # visible = True, + # ) + + quantization_choice = gr.Dropdown( choices=[ - ("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0), - ("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1), - ("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2), - ("WAN 2.1 VACE 1.3B Text to Video / Control Net - text generation driven by reference images or videos", 3), + ("Int8 Quantization (recommended)", "int8"), + ("BF16 (no quantization)", "bf16"), ], - value= index, - label="Transformer model for Text to Video", - interactive= not lock_ui_transformer, - visible=True - ) - index = transformer_choices_i2v.index(transformer_filename_i2v) - index = 0 if index ==0 else index - transformer_i2v_choice = gr.Dropdown( - choices=[ - ("WAN 2.1 - 480p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 0), - ("WAN 2.1 - 480p 14B Image to Video quantized to 8 bits (recommended) - the default engine but quantized", 1), - ("WAN 2.1 - 720p 
14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 2), - ("WAN 2.1 - 720p 14B Image to Video quantized to 8 bits - the default engine but quantized", 3), - ("WAN 2.1 - Fun InP 1.3B 16 bits - the small model for fast generations with low VRAM requirements", 4), - ("WAN 2.1 - Fun InP 14B 16 bits - Fun InP version in its original glory, offers a slightly better image quality but slower and requires more RAM", 5), - ("WAN 2.1 - Fun InP 14B quantized to 8 bits - quantized Fun InP version", 6), - ], - value= index, - label="Transformer model for Image to Video", - interactive= not lock_ui_transformer, - visible = True, - ) + value= transformer_quantization, + label="Wan Transformer Model Quantization (if available)", + ) + index = text_encoder_choices.index(text_encoder_filename) index = 0 if index ==0 else index text_encoder_choice = gr.Dropdown( @@ -3149,11 +3156,12 @@ def generate_configuration_tab(): ) reload_choice = gr.Dropdown( choices=[ - ("When changing tabs", 1), - ("When pressing Generate", 2), + ("Load Model When Changing Model", 1), + ("Load Model When Pressing Generate", 2), + ("Load Model When Pressing Generate and Unload Model when Finished", 3), ], value=server_config.get("reload_model",2), - label="Reload model" + label="RAM Loading / Unloading Model Policy (in any case VRAM will be freed once the queue has been processed)" ) clear_file_list_choice = gr.Dropdown( @@ -3176,8 +3184,7 @@ def generate_configuration_tab(): fn=apply_changes, inputs=[ state, - transformer_t2v_choice, - transformer_i2v_choice, + transformer_type_choice, text_encoder_choice, save_path_choice, attention_choice, @@ -3185,7 +3192,7 @@ def generate_configuration_tab(): profile_choice, vae_config_choice, metadata_choice, - default_ui_choice, + quantization_choice, boost_choice, clear_file_list_choice, reload_choice, @@ -3194,7 +3201,7 @@ def generate_configuration_tab(): ) def generate_about_tab(): - gr.Markdown("
Wan2.1GP - Wan 2.1 model for the GPU Poor by DeepBeepMeep (GitHub)
") + gr.Markdown("
WanGP - Wan 2.1 model for the GPU Poor by DeepBeepMeep (GitHub)")
     gr.Markdown("Original Wan 2.1 Model by Alibaba (GitHub)")
     gr.Markdown("Many thanks to:")
     gr.Markdown("- Alibaba Wan team for the best open source video generator")
@@ -3202,114 +3209,19 @@ def generate_about_tab():
     gr.Markdown("- Tophness : created multi tabs and queuing frameworks")
     gr.Markdown("- AmericanPresidentJimmyCarter : added original support for Skip Layer Guidance")
     gr.Markdown("- Remade_AI : for creating their awesome Loras collection")
+
+def generate_info_tab():
+    gr.Markdown("Welcome to WanGP a super fast and low VRAM AI Video Generator !")
 
-
-def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData):
-    t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
-    i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
-
-    new_t2v = evt.index == 0
-    new_i2v = evt.index == 1
-    i2v_light_sync = gr.Text()
-    t2v_light_sync = gr.Text()
-    i2v_full_sync = gr.Text()
-    t2v_full_sync = gr.Text()
-
-    last_tab_was_image2video =global_state.get("last_tab_was_image2video", None)
-    if last_tab_was_image2video == None or last_tab_was_image2video:
-        gen = i2v_state["gen"]
-        t2v_state["gen"] = gen
-    else:
-        gen = t2v_state["gen"]
-        i2v_state["gen"] = gen
+    gr.Markdown("The VRAM requirements will depend greatly on the resolution and the duration of the video, for instance :")
+    gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM")
+    gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
+    gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
+    gr.Markdown("It is not recommended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear")
+    gr.Markdown("Please note that if you turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. 
Therefore all your tests should be done with compilation turned off.") - if new_t2v or new_i2v: - if last_tab_was_image2video != None and new_t2v != new_i2v: - gen_location = gen.get("location", None) - if "in_progress" in gen and gen_location !=None and not (gen_location and new_i2v or not gen_location and new_t2v) : - if new_i2v: - i2v_full_sync = gr.Text(str(time.time())) - else: - t2v_full_sync = gr.Text(str(time.time())) - else: - if new_i2v: - i2v_light_sync = gr.Text(str(time.time())) - else: - t2v_light_sync = gr.Text(str(time.time())) - global_state["last_tab_was_image2video"] = new_i2v - - if(server_config.get("reload_model",2) == 1): - queue = gen.get("queue", []) - - queue_empty = len(queue) == 0 - if queue_empty: - global wan_model, offloadobj - if wan_model is not None: - if offloadobj is not None: - offloadobj.release() - offloadobj = None - wan_model = None - gc.collect() - torch.cuda.empty_cache() - wan_model, offloadobj, trans = load_models(new_i2v) - del trans - - if new_t2v or new_i2v: - state = i2v_state if new_i2v else t2v_state - lora_model_filename = state["loras_model"] - model_filename = model_needed(new_i2v) - if ("1.3B" in model_filename and not "1.3B" in lora_model_filename or "14B" in model_filename and not "14B" in lora_model_filename): - lora_dir = get_lora_dir(new_i2v) - loras, loras_names, loras_presets, _, _, _, _ = setup_loras(new_i2v, None, lora_dir, lora_preselected_preset, None) - state["loras"] = loras - state["loras_names"] = loras_names - state["loras_presets"] = loras_presets - state["loras_model"] = model_filename - - advanced = state["advanced"] - new_loras_choices = [(name, str(i)) for i, name in enumerate(loras_names)] - lset_choices = [(preset, preset) for preset in loras_presets] + [(get_new_preset_msg(advanced), "")] - visible = len(loras_names)>0 - if new_t2v: - return [ - gr.Column(visible= visible), - gr.Dropdown(choices=new_loras_choices, visible=visible, value=[]), - gr.Column(visible= visible), - gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible), - t2v_header, - t2v_light_sync, - t2v_full_sync, - gr.Column(), - gr.Dropdown(), - gr.Column(), - gr.Dropdown(), - gr.Markdown(), - gr.Text(), - gr.Text(), - ] - else: - return [ - gr.Column(), - gr.Dropdown(), - gr.Column(), - gr.Dropdown(), - gr.Markdown(), - gr.Text(), - gr.Text(), - gr.Text(), - gr.Column(visible= visible), - gr.Dropdown(choices=new_loras_choices, visible=visible, value=[]), - gr.Column(visible= visible), - gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible), - i2v_header, - i2v_light_sync, - i2v_full_sync, - ] - - return [gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), t2v_header, t2v_light_sync, t2v_full_sync, - gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), i2v_header, i2v_light_sync, i2v_full_sync] def create_demo(): @@ -3539,40 +3451,41 @@ def create_demo(): } """ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo: - gr.Markdown("
Wan 2.1GP v4.0 by DeepBeepMeep (Updates)
") - gr.Markdown("Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !") - - with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False): - gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance :") - gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM") - gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM") - gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM") - gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear") - gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.") - global_dict = {} - global_dict["last_tab_was_image2video"] = use_image2video - global_state = gr.State(global_dict) + gr.Markdown("
WanGP v4.0 by DeepBeepMeep ") # (Updates)
") + global model_list - with gr.Tabs(selected="i2v" if use_image2video else "t2v") as main_tabs: - with gr.Tab("Text To Video", id="t2v") as t2v_tab: - t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_light_sync, t2v_full_sync, t2v_state = generate_video_tab(False) - with gr.Tab("Image To Video", id="i2v") as i2v_tab: - i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync, i2v_state = generate_video_tab(True) + with gr.Tabs(selected="video_gen", ) as main_tabs: + with gr.Tab("Video Generator", id="video_gen") as t2v_tab: + with gr.Row(): + header = gr.Markdown(generate_header(compile, attention_mode), visible= True) + with gr.Row(): + gr.Markdown("
") + + model_list = [] + for model_type in model_types: + choice = get_model_filename(model_type, transformer_quantization) + model_list.append(choice) + dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list] + model_choice = gr.Dropdown( + choices= dropdown_choices, + value= get_model_type(transformer_filename), + show_label= False, + scale= 2 + ) + gr.Markdown("
") + with gr.Row(): + + loras_choices, lset_name, state = generate_video_tab(model_choice = model_choice, header = header) + with gr.Tab("Informations"): + generate_info_tab() if not args.lock_config: with gr.Tab("Downloads", id="downloads") as downloads_tab: - generate_download_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state) + generate_download_tab(lset_name, loras_choices, state) with gr.Tab("Configuration"): generate_configuration_tab() with gr.Tab("About"): generate_about_tab() - main_tabs.select( - fn=on_tab_select, - inputs=[global_state, t2v_state, i2v_state], - outputs=[ - t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_light_sync, t2v_full_sync, - i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync - ] - ) + return demo if __name__ == "__main__":
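
For reviewers of this patch, the sketch below (not part of the diff) illustrates the argument-capture pattern that the new `get_function_arguments` / `save_inputs` pair relies on: the helper reads the target function's signature and pulls the matching names out of the caller's `locals()`, so the long list of Gradio inputs ends up in a single dict without being enumerated twice. The shortened parameter list and the `print` call are illustrative only, not the real signature.

```python
import inspect

def get_function_arguments(func, local_vars):
    # Map every parameter name declared by `func` to the value currently
    # bound to that name in the caller's local scope, in signature order.
    return {name: local_vars[name] for name in inspect.signature(func).parameters}

def save_inputs(target, prompt, seed, state):
    # Called at the top of the function, locals() contains exactly the
    # parameters, so this captures all UI values in one shot.
    inputs = get_function_arguments(save_inputs, locals())
    inputs.pop("target")  # routing flag ("settings" / "state"), not a UI setting
    return inputs

print(save_inputs("settings", "a cat walking", 42, {"model_filename": "t2v_1.3B"}))
# {'prompt': 'a cat walking', 'seed': 42, 'state': {'model_filename': 't2v_1.3B'}}
```

In wgp.py the resulting dict is then handed to `prepare_inputs_dict`, which routes it to the per-model state, the saved settings file, or the generated video's metadata depending on the `target` value.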