Add a config option to choose when the model is reloaded: when changing tabs or when pressing generate

This commit is contained in:
Chris Malone 2025-03-24 22:23:51 +11:00
parent 646d9ace13
commit 347ab55d4a

View File

@ -491,7 +491,8 @@ if not Path(server_config_filename).is_file():
"default_ui": "t2v", "default_ui": "t2v",
"boost" : 1, "boost" : 1,
"vae_config": 0, "vae_config": 0,
"profile" : profile_type.LowRAM_LowVRAM } "profile" : profile_type.LowRAM_LowVRAM,
"reload_model": 1 }
with open(server_config_filename, "w", encoding="utf-8") as writer: with open(server_config_filename, "w", encoding="utf-8") as writer:
writer.write(json.dumps(server_config)) writer.write(json.dumps(server_config))
@ -915,7 +916,8 @@ def apply_changes( state,
vae_config_choice, vae_config_choice,
metadata_choice, metadata_choice,
default_ui_choice ="t2v", default_ui_choice ="t2v",
boost_choice = 1 boost_choice = 1,
reload_choice = 1
): ):
if args.lock_config: if args.lock_config:
return return
@ -934,6 +936,7 @@ def apply_changes( state,
"metadata_choice": metadata_choice, "metadata_choice": metadata_choice,
"default_ui" : default_ui_choice, "default_ui" : default_ui_choice,
"boost" : boost_choice, "boost" : boost_choice,
"reload_model" : reload_choice,
} }
if Path(server_config_filename).is_file(): if Path(server_config_filename).is_file():
@ -1075,6 +1078,7 @@ def generate_video(
global wan_model, offloadobj global wan_model, offloadobj
reload_needed = state.get("_reload_needed", False) reload_needed = state.get("_reload_needed", False)
file_model_needed = model_needed(image2video) file_model_needed = model_needed(image2video)
if(server_config.get("reload_model",1) == 2):
if file_model_needed != model_filename or reload_needed: if file_model_needed != model_filename or reload_needed:
if offloadobj is not None: if offloadobj is not None:
offloadobj.release() offloadobj.release()
@ -1190,7 +1194,7 @@ def generate_video(
if len(list_mult_choices_nums ) < len(loras_choices): if len(list_mult_choices_nums ) < len(loras_choices):
list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) ) list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices] loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices]
pinnedLora = profile !=5 #False # # # pinnedLora = False !=5 #False # # #
offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None) offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None)
errors = trans._loras_errors errors = trans._loras_errors
if len(errors) > 0: if len(errors) > 0:
@ -2236,6 +2240,14 @@ def generate_configuration_tab():
value=metadata, value=metadata,
label="Metadata Handling" label="Metadata Handling"
) )
reload_choice = gr.Dropdown(
choices=[
("When changing tabs", 1),
("When pressing generate", 2),
],
value=server_config.get("reload_model",1),
label="Reload model"
)
msg = gr.Markdown() msg = gr.Markdown()
apply_btn = gr.Button("Apply Changes") apply_btn = gr.Button("Apply Changes")
apply_btn.click( apply_btn.click(
@ -2253,6 +2265,7 @@ def generate_configuration_tab():
metadata_choice, metadata_choice,
default_ui_choice, default_ui_choice,
boost_choice, boost_choice,
reload_choice,
], ],
outputs= msg outputs= msg
) )
@ -2270,13 +2283,24 @@ def generate_about_tab():
def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData): def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
global lora_model_filename, use_image2video global lora_model_filename, use_image2video
t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
new_t2v = evt.index == 0 new_t2v = evt.index == 0
new_i2v = evt.index == 1 new_i2v = evt.index == 1
use_image2video = new_i2v use_image2video = new_i2v
if(server_config.get("reload_model",1) == 1):
global wan_model, offloadobj
if wan_model is not None:
if offloadobj is not None:
offloadobj.release()
offloadobj = None
wan_model = None
gc.collect()
torch.cuda.empty_cache()
wan_model, offloadobj, trans = load_models(use_image2video)
t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
if new_t2v: if new_t2v:
lora_model_filename = t2v_state["loras_model"] lora_model_filename = t2v_state["loras_model"]
if ("1.3B" in transformer_filename_t2v and not "1.3B" in lora_model_filename or "14B" in transformer_filename_t2v and not "14B" in lora_model_filename): if ("1.3B" in transformer_filename_t2v and not "1.3B" in lora_model_filename or "14B" in transformer_filename_t2v and not "14B" in lora_model_filename):