From 347ab55d4aea645997056211634b3c8663dfa120 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Mon, 24 Mar 2025 22:23:51 +1100 Subject: [PATCH] add config option to set whether to reload model upon changing tabs or pressing generate --- gradio_server.py | 56 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/gradio_server.py b/gradio_server.py index 2d551cb..2c71042 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -491,7 +491,8 @@ if not Path(server_config_filename).is_file(): "default_ui": "t2v", "boost" : 1, "vae_config": 0, - "profile" : profile_type.LowRAM_LowVRAM } + "profile" : profile_type.LowRAM_LowVRAM, + "reload_model": 1 } with open(server_config_filename, "w", encoding="utf-8") as writer: writer.write(json.dumps(server_config)) @@ -915,7 +916,8 @@ def apply_changes( state, vae_config_choice, metadata_choice, default_ui_choice ="t2v", - boost_choice = 1 + boost_choice = 1, + reload_choice = 1 ): if args.lock_config: return @@ -934,6 +936,7 @@ def apply_changes( state, "metadata_choice": metadata_choice, "default_ui" : default_ui_choice, "boost" : boost_choice, + "reload_model" : reload_choice, } if Path(server_config_filename).is_file(): @@ -1075,16 +1078,17 @@ def generate_video( global wan_model, offloadobj reload_needed = state.get("_reload_needed", False) file_model_needed = model_needed(image2video) - if file_model_needed != model_filename or reload_needed: - if offloadobj is not None: - offloadobj.release() - offloadobj = None - wan_model = None - gc.collect() - print(f"Loading model {get_model_name(file_model_needed)}...") - wan_model, offloadobj, trans = load_models(image2video) - print(f"Model loaded") - state["_reload_needed"] = False + if(server_config.get("reload_model",1) == 2): + if file_model_needed != model_filename or reload_needed: + if offloadobj is not None: + offloadobj.release() + offloadobj = None + wan_model = None + gc.collect() + print(f"Loading model {get_model_name(file_model_needed)}...") + wan_model, offloadobj, trans = load_models(image2video) + print(f"Model loaded") + state["_reload_needed"] = False from PIL import Image import numpy as np @@ -1190,7 +1194,7 @@ def generate_video( if len(list_mult_choices_nums ) < len(loras_choices): list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) ) loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices] - pinnedLora = profile !=5 #False # # # + pinnedLora = False !=5 #False # # # offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None) errors = trans._loras_errors if len(errors) > 0: @@ -2236,6 +2240,14 @@ def generate_configuration_tab(): value=metadata, label="Metadata Handling" ) + reload_choice = gr.Dropdown( + choices=[ + ("When changing tabs", 1), + ("When pressing generate", 2), + ], + value=server_config.get("reload_model",1), + label="Reload model" + ) msg = gr.Markdown() apply_btn = gr.Button("Apply Changes") apply_btn.click( @@ -2253,6 +2265,7 @@ def generate_configuration_tab(): metadata_choice, default_ui_choice, boost_choice, + reload_choice, ], outputs= msg ) @@ -2270,13 +2283,24 @@ def generate_about_tab(): def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData): global lora_model_filename, use_image2video - t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode) - i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode) - new_t2v = evt.index == 0 new_i2v = evt.index == 1 use_image2video = new_i2v + if(server_config.get("reload_model",1) == 1): + global wan_model, offloadobj + if wan_model is not None: + if offloadobj is not None: + offloadobj.release() + offloadobj = None + wan_model = None + gc.collect() + torch.cuda.empty_cache() + wan_model, offloadobj, trans = load_models(use_image2video) + + t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode) + i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode) + if new_t2v: lora_model_filename = t2v_state["loras_model"] if ("1.3B" in transformer_filename_t2v and not "1.3B" in lora_model_filename or "14B" in transformer_filename_t2v and not "14B" in lora_model_filename):