mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-05 14:33:15 +00:00
add config option to set whether to reload model upon changing tabs or pressing generate
This commit is contained in:
parent
646d9ace13
commit
347ab55d4a
@ -491,7 +491,8 @@ if not Path(server_config_filename).is_file():
|
|||||||
"default_ui": "t2v",
|
"default_ui": "t2v",
|
||||||
"boost" : 1,
|
"boost" : 1,
|
||||||
"vae_config": 0,
|
"vae_config": 0,
|
||||||
"profile" : profile_type.LowRAM_LowVRAM }
|
"profile" : profile_type.LowRAM_LowVRAM,
|
||||||
|
"reload_model": 1 }
|
||||||
|
|
||||||
with open(server_config_filename, "w", encoding="utf-8") as writer:
|
with open(server_config_filename, "w", encoding="utf-8") as writer:
|
||||||
writer.write(json.dumps(server_config))
|
writer.write(json.dumps(server_config))
|
||||||
@ -915,7 +916,8 @@ def apply_changes( state,
|
|||||||
vae_config_choice,
|
vae_config_choice,
|
||||||
metadata_choice,
|
metadata_choice,
|
||||||
default_ui_choice ="t2v",
|
default_ui_choice ="t2v",
|
||||||
boost_choice = 1
|
boost_choice = 1,
|
||||||
|
reload_choice = 1
|
||||||
):
|
):
|
||||||
if args.lock_config:
|
if args.lock_config:
|
||||||
return
|
return
|
||||||
@ -934,6 +936,7 @@ def apply_changes( state,
|
|||||||
"metadata_choice": metadata_choice,
|
"metadata_choice": metadata_choice,
|
||||||
"default_ui" : default_ui_choice,
|
"default_ui" : default_ui_choice,
|
||||||
"boost" : boost_choice,
|
"boost" : boost_choice,
|
||||||
|
"reload_model" : reload_choice,
|
||||||
}
|
}
|
||||||
|
|
||||||
if Path(server_config_filename).is_file():
|
if Path(server_config_filename).is_file():
|
||||||
@ -1075,6 +1078,7 @@ def generate_video(
|
|||||||
global wan_model, offloadobj
|
global wan_model, offloadobj
|
||||||
reload_needed = state.get("_reload_needed", False)
|
reload_needed = state.get("_reload_needed", False)
|
||||||
file_model_needed = model_needed(image2video)
|
file_model_needed = model_needed(image2video)
|
||||||
|
if(server_config.get("reload_model",1) == 2):
|
||||||
if file_model_needed != model_filename or reload_needed:
|
if file_model_needed != model_filename or reload_needed:
|
||||||
if offloadobj is not None:
|
if offloadobj is not None:
|
||||||
offloadobj.release()
|
offloadobj.release()
|
||||||
@ -1190,7 +1194,7 @@ def generate_video(
|
|||||||
if len(list_mult_choices_nums ) < len(loras_choices):
|
if len(list_mult_choices_nums ) < len(loras_choices):
|
||||||
list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
|
list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
|
||||||
loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices]
|
loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices]
|
||||||
pinnedLora = profile !=5 #False # # #
|
pinnedLora = False !=5 #False # # #
|
||||||
offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None)
|
offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None)
|
||||||
errors = trans._loras_errors
|
errors = trans._loras_errors
|
||||||
if len(errors) > 0:
|
if len(errors) > 0:
|
||||||
@ -2236,6 +2240,14 @@ def generate_configuration_tab():
|
|||||||
value=metadata,
|
value=metadata,
|
||||||
label="Metadata Handling"
|
label="Metadata Handling"
|
||||||
)
|
)
|
||||||
|
reload_choice = gr.Dropdown(
|
||||||
|
choices=[
|
||||||
|
("When changing tabs", 1),
|
||||||
|
("When pressing generate", 2),
|
||||||
|
],
|
||||||
|
value=server_config.get("reload_model",1),
|
||||||
|
label="Reload model"
|
||||||
|
)
|
||||||
msg = gr.Markdown()
|
msg = gr.Markdown()
|
||||||
apply_btn = gr.Button("Apply Changes")
|
apply_btn = gr.Button("Apply Changes")
|
||||||
apply_btn.click(
|
apply_btn.click(
|
||||||
@ -2253,6 +2265,7 @@ def generate_configuration_tab():
|
|||||||
metadata_choice,
|
metadata_choice,
|
||||||
default_ui_choice,
|
default_ui_choice,
|
||||||
boost_choice,
|
boost_choice,
|
||||||
|
reload_choice,
|
||||||
],
|
],
|
||||||
outputs= msg
|
outputs= msg
|
||||||
)
|
)
|
||||||
@ -2270,13 +2283,24 @@ def generate_about_tab():
|
|||||||
def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
|
def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
|
||||||
global lora_model_filename, use_image2video
|
global lora_model_filename, use_image2video
|
||||||
|
|
||||||
t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
|
|
||||||
i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
|
|
||||||
|
|
||||||
new_t2v = evt.index == 0
|
new_t2v = evt.index == 0
|
||||||
new_i2v = evt.index == 1
|
new_i2v = evt.index == 1
|
||||||
use_image2video = new_i2v
|
use_image2video = new_i2v
|
||||||
|
|
||||||
|
if(server_config.get("reload_model",1) == 1):
|
||||||
|
global wan_model, offloadobj
|
||||||
|
if wan_model is not None:
|
||||||
|
if offloadobj is not None:
|
||||||
|
offloadobj.release()
|
||||||
|
offloadobj = None
|
||||||
|
wan_model = None
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
wan_model, offloadobj, trans = load_models(use_image2video)
|
||||||
|
|
||||||
|
t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
|
||||||
|
i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
|
||||||
|
|
||||||
if new_t2v:
|
if new_t2v:
|
||||||
lora_model_filename = t2v_state["loras_model"]
|
lora_model_filename = t2v_state["loras_model"]
|
||||||
if ("1.3B" in transformer_filename_t2v and not "1.3B" in lora_model_filename or "14B" in transformer_filename_t2v and not "14B" in lora_model_filename):
|
if ("1.3B" in transformer_filename_t2v and not "1.3B" in lora_model_filename or "14B" in transformer_filename_t2v and not "14B" in lora_model_filename):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user