diff --git a/wgp.py b/wgp.py
index c81ab9a..db02e07 100644
--- a/wgp.py
+++ b/wgp.py
@@ -441,6 +441,12 @@ def _parse_args():
         help="Prevent modifying the configuration from the web interface"
     )
 
+    parser.add_argument(
+        "--lock-model",
+        action="store_true",
+        help="Prevent switching models"
+    )
+
     parser.add_argument(
         "--preload",
         type=str,
@@ -700,11 +706,15 @@ transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts
 transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ]
 transformer_choices = transformer_choices_t2v + transformer_choices_i2v
 text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
-server_config_filename = "gradio_config.json"
+server_config_filename = "wgp_config.json"
+
+if not os.path.isfile(server_config_filename) and os.path.isfile("gradio_config.json"):
+    import shutil
+    shutil.move("gradio_config.json", server_config_filename)
 
 if not Path(server_config_filename).is_file():
     server_config = {"attention_mode" : "auto",
-                     "transformer_type": "t2v",
+                     "transformer_types": [],
                      "transformer_quantization": "int8",
                      "text_encoder_filename" : text_encoder_choices[1],
                      "save_path": os.path.join(os.getcwd(), "gradio_outputs"),
@@ -826,7 +836,8 @@ def get_default_settings(filename):
         ui_defaults["num_inference_steps"] = default_number_steps
     return ui_defaults
 
-transformer_type = server_config.get("transformer_type", "t2v")
+transformer_types = server_config.get("transformer_types", [])
+transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
 transformer_quantization =server_config.get("transformer_quantization", "int8")
 transformer_filename = get_model_filename(transformer_type, transformer_quantization)
 text_encoder_filename = server_config["text_encoder_filename"]
@@ -1213,22 +1224,25 @@ def get_model_name(model_filename):
 #     return header
 
-def generate_header(compile, attention_mode):
+def generate_header(model_filename, compile, attention_mode):
Attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() ) + header = "
Attention mode " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() ) if attention_mode not in attention_modes_installed: header += " -NOT INSTALLED-" elif attention_mode not in attention_modes_supported: header += " -NOT SUPPORTED-" + header += "" if compile: - header += ", pytorch compilation ON" + header += ", Pytorch compilation ON" + if "int8" in model_filename: + header += ", Quantization Int8" header += "
" return header def apply_changes( state, - transformer_type_choice, + transformer_types_choices, text_encoder_choice, save_path_choice, attention_choice, @@ -1239,16 +1253,15 @@ def apply_changes( state, quantization_choice, boost_choice = 1, clear_file_list = 0, - reload_choice = 1 + reload_choice = 1, ): if args.lock_config: return if gen_in_progress: - yield "
-        yield "<DIV>Unable to change config when a generation is in progress</DIV>"
-        return
+        return "<DIV>Unable to change config when a generation is in progress</DIV>"
     global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
     server_config = {"attention_mode" : attention_choice,
-                     "transformer_type": transformer_type_choice,
+                     "transformer_types": transformer_types_choices,
                      "text_encoder_filename" : text_encoder_choices[text_encoder_choice],
                      "save_path" : save_path_choice,
                      "compile" : compile_choice,
@@ -1281,7 +1294,7 @@ def apply_changes( state,
         if v != v_old:
             changes.append(k)
 
-    global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, reload_model, transformer_quantization, transformer_type
+    global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, reload_model, transformer_quantization, transformer_types
     attention_mode = server_config["attention_mode"]
     profile = server_config["profile"]
     compile = server_config["compile"]
@@ -1290,15 +1303,19 @@ def apply_changes( state,
     boost = server_config["boost"]
     reload_model = server_config["reload_model"]
     transformer_quantization = server_config["transformer_quantization"]
-    transformer_filename = get_model_filename(transformer_type, transformer_quantization)
-
-    if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ):
-        pass
+    transformer_types = server_config["transformer_types"]
+    transformer_type = get_model_type(transformer_filename)
+    if not transformer_type in transformer_types:
+        transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
+    transformer_filename = get_model_filename(transformer_type, transformer_quantization)
+    if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ):
+        model_choice = gr.Dropdown()
     else:
         reload_needed = True
+        model_choice = generate_dropdown_model_list()
-
-    yield "<DIV>The new configuration has been succesfully applied</DIV>"
+    header = generate_header(transformer_filename, compile=compile, attention_mode= attention_mode)
+    return "<DIV>The new configuration has been successfully applied</DIV>", header, model_choice
", header, model_choice @@ -2505,16 +2522,11 @@ def handle_celll_selection(state, evt: gr.SelectData): def change_model(state, model_choice): - model_filename = "" - for filename in model_list: - if get_model_type(filename) == model_choice: - model_filename = filename - break - if len(model_filename) == 0: + if model_choice == None: return - + model_filename = get_model_filename(model_choice, transformer_quantization) state["model_filename"] = model_filename - header = generate_header(compile=compile, attention_mode=attention_mode) + header = generate_header(model_filename, compile=compile, attention_mode=attention_mode) return header def fill_inputs(state): @@ -3014,53 +3026,32 @@ def generate_download_tab(lset_name,loras_choices, state): download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) -def generate_configuration_tab(): +def generate_configuration_tab(header, model_choice): state_dict = {} state = gr.State(state_dict) gr.Markdown("Please click Apply Changes at the bottom so that the changes are effective. Some choices below may be locked if the app has been launched by specifying a config preset.") with gr.Column(): - index = transformer_choices.index(transformer_filename) - index = 0 if index ==0 else index - model_list = [] + for model_type in model_types: choice = get_model_filename(model_type, transformer_quantization) model_list.append(choice) dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list] - transformer_type_choice = gr.Dropdown( + transformer_types_choices = gr.Dropdown( choices= dropdown_choices, - value= get_model_type(transformer_filename), - label= "Default Wan Transformer Model", - scale= 2 + value= transformer_types, + label= "Selectable Wan Transformer Models (keep empty to get All of them)", + scale= 2, + multiselect= True ) - # transformer_choice = gr.Dropdown( - # choices=[ - # ("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0), - # ("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1), - # ("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2), - # ("WAN 2.1 VACE 1.3B Text to Video / Control Net - text generation driven by reference images or videos", 3), - # ("WAN 2.1 - 480p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 4), - # ("WAN 2.1 - 480p 14B Image to Video quantized to 8 bits (recommended) - the default engine but quantized", 5), - # ("WAN 2.1 - 720p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 6), - # ("WAN 2.1 - 720p 14B Image to Video quantized to 8 bits - the default engine but quantized", 7), - # ("WAN 2.1 - Fun InP 1.3B 16 bits - the small model for fast generations with low VRAM requirements", 8), - # ("WAN 2.1 - Fun InP 14B 16 bits - Fun InP version in its original glory, offers a slightly better image quality but slower and requires more RAM", 9), - # ("WAN 2.1 - Fun InP 14B quantized to 8 bits - quantized Fun InP version", 10), - # ], - # value= index, - # label="Transformer model for Image to Video", - # interactive= not 
-        #     visible = True,
-        #     )
-
         quantization_choice = gr.Dropdown(
             choices=[
                 ("Int8 Quantization (recommended)", "int8"),
                 ("BF16 (no quantization)", "bf16"),
             ],
             value= transformer_quantization,
-            label="Wan Transformer Model Quantization (if available)",
+            label="Wan Transformer Model Quantization Type (if available)",
         )
 
         index = text_encoder_choices.index(text_encoder_filename)
@@ -3137,14 +3128,14 @@
             value= profile,
             label="Profile (for power users only, not needed to change it)"
         )
-        default_ui_choice = gr.Dropdown(
-            choices=[
-                ("Text to Video", "t2v"),
-                ("Image to Video", "i2v"),
-            ],
-            value= default_ui,
-            label="Default mode when launching the App if not '--t2v' ot '--i2v' switch is specified when launching the server ",
-        )
+        # default_ui_choice = gr.Dropdown(
+        #     choices=[
+        #         ("Text to Video", "t2v"),
+        #         ("Image to Video", "i2v"),
+        #     ],
+        #     value= default_ui,
+        #     label="Default mode when launching the App if not '--t2v' ot '--i2v' switch is specified when launching the server ",
+        # )
         metadata_choice = gr.Dropdown(
             choices=[
                 ("Export JSON files", "json"),
@@ -3184,7 +3175,7 @@
                 fn=apply_changes,
                 inputs=[
                     state,
-                    transformer_type_choice,
+                    transformer_types_choices,
                     text_encoder_choice,
                     save_path_choice,
                     attention_choice,
@@ -3197,7 +3188,7 @@
                     clear_file_list_choice,
                     reload_choice,
                 ],
-                outputs= msg
+                outputs= [msg , header, model_choice]
             )
 
 def generate_about_tab():
@@ -3221,6 +3212,22 @@ def generate_info_tab():
     gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
 
+def generate_dropdown_model_list():
+    dropdown_types= transformer_types if len(transformer_types) > 0 else model_types
+    current_model_type = get_model_type(transformer_filename)
+    if current_model_type not in dropdown_types:
+        dropdown_types.append(current_model_type)
+    model_list = []
+    for model_type in dropdown_types:
+        choice = get_model_filename(model_type, transformer_quantization)
+        model_list.append(choice)
+    dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
+    return gr.Dropdown(
+        choices= dropdown_choices,
+        value= current_model_type,
+        show_label= False,
+        scale= 2
+        )
 
@@ -3457,22 +3464,15 @@ def create_demo():
         with gr.Tabs(selected="video_gen", ) as main_tabs:
             with gr.Tab("Video Generator", id="video_gen") as t2v_tab:
                 with gr.Row():
-                    header = gr.Markdown(generate_header(compile, attention_mode), visible= True)
+                    if args.lock_model:
+                        gr.Markdown("<div class='title-with-lines'> <div class=line></div><h2>" + get_model_name(transformer_filename) + "</h2> <div class=line></div> </div>")
+                        model_choice = gr.Dropdown(visible=False, value= get_model_type(transformer_filename))
+                    else:
+                        gr.Markdown("<div class='title-with-lines'> <div class=line></div>")
") + model_choice = generate_dropdown_model_list() + gr.Markdown("
") with gr.Row(): - gr.Markdown("
") - - model_list = [] - for model_type in model_types: - choice = get_model_filename(model_type, transformer_quantization) - model_list.append(choice) - dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list] - model_choice = gr.Dropdown( - choices= dropdown_choices, - value= get_model_type(transformer_filename), - show_label= False, - scale= 2 - ) - gr.Markdown("
") + header = gr.Markdown(generate_header(transformer_filename, compile, attention_mode), visible= True) with gr.Row(): loras_choices, lset_name, state = generate_video_tab(model_choice = model_choice, header = header) @@ -3482,7 +3482,7 @@ def create_demo(): with gr.Tab("Downloads", id="downloads") as downloads_tab: generate_download_tab(lset_name, loras_choices, state) with gr.Tab("Configuration"): - generate_configuration_tab() + generate_configuration_tab(header, model_choice) with gr.Tab("About"): generate_about_tab()