add config option to set whether to reload model upon changing tabs or pressing generate

This commit is contained in:
Chris Malone 2025-03-24 22:23:51 +11:00
parent 646d9ace13
commit 347ab55d4a

View File

@ -491,7 +491,8 @@ if not Path(server_config_filename).is_file():
"default_ui": "t2v",
"boost" : 1,
"vae_config": 0,
"profile" : profile_type.LowRAM_LowVRAM }
"profile" : profile_type.LowRAM_LowVRAM,
"reload_model": 1 }
with open(server_config_filename, "w", encoding="utf-8") as writer:
writer.write(json.dumps(server_config))
@ -915,7 +916,8 @@ def apply_changes( state,
vae_config_choice,
metadata_choice,
default_ui_choice ="t2v",
boost_choice = 1
boost_choice = 1,
reload_choice = 1
):
if args.lock_config:
return
@ -934,6 +936,7 @@ def apply_changes( state,
"metadata_choice": metadata_choice,
"default_ui" : default_ui_choice,
"boost" : boost_choice,
"reload_model" : reload_choice,
}
if Path(server_config_filename).is_file():
@ -1075,6 +1078,7 @@ def generate_video(
global wan_model, offloadobj
reload_needed = state.get("_reload_needed", False)
file_model_needed = model_needed(image2video)
if(server_config.get("reload_model",1) == 2):
if file_model_needed != model_filename or reload_needed:
if offloadobj is not None:
offloadobj.release()
@ -1190,7 +1194,7 @@ def generate_video(
if len(list_mult_choices_nums ) < len(loras_choices):
list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices]
pinnedLora = profile !=5 #False # # #
pinnedLora = False !=5 #False # # #
offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None)
errors = trans._loras_errors
if len(errors) > 0:
@ -2236,6 +2240,14 @@ def generate_configuration_tab():
value=metadata,
label="Metadata Handling"
)
reload_choice = gr.Dropdown(
choices=[
("When changing tabs", 1),
("When pressing generate", 2),
],
value=server_config.get("reload_model",1),
label="Reload model"
)
msg = gr.Markdown()
apply_btn = gr.Button("Apply Changes")
apply_btn.click(
@ -2253,6 +2265,7 @@ def generate_configuration_tab():
metadata_choice,
default_ui_choice,
boost_choice,
reload_choice,
],
outputs= msg
)
@ -2270,13 +2283,24 @@ def generate_about_tab():
def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
global lora_model_filename, use_image2video
t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
new_t2v = evt.index == 0
new_i2v = evt.index == 1
use_image2video = new_i2v
if(server_config.get("reload_model",1) == 1):
global wan_model, offloadobj
if wan_model is not None:
if offloadobj is not None:
offloadobj.release()
offloadobj = None
wan_model = None
gc.collect()
torch.cuda.empty_cache()
wan_model, offloadobj, trans = load_models(use_image2video)
t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
if new_t2v:
lora_model_filename = t2v_state["loras_model"]
if ("1.3B" in transformer_filename_t2v and not "1.3B" in lora_model_filename or "14B" in transformer_filename_t2v and not "14B" in lora_model_filename):