mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 22:26:36 +00:00
add config option to set whether to reload model upon changing tabs or pressing generate
This commit is contained in:
parent
646d9ace13
commit
347ab55d4a
@ -491,7 +491,8 @@ if not Path(server_config_filename).is_file():
|
||||
"default_ui": "t2v",
|
||||
"boost" : 1,
|
||||
"vae_config": 0,
|
||||
"profile" : profile_type.LowRAM_LowVRAM }
|
||||
"profile" : profile_type.LowRAM_LowVRAM,
|
||||
"reload_model": 1 }
|
||||
|
||||
with open(server_config_filename, "w", encoding="utf-8") as writer:
|
||||
writer.write(json.dumps(server_config))
|
||||
@ -915,7 +916,8 @@ def apply_changes( state,
|
||||
vae_config_choice,
|
||||
metadata_choice,
|
||||
default_ui_choice ="t2v",
|
||||
boost_choice = 1
|
||||
boost_choice = 1,
|
||||
reload_choice = 1
|
||||
):
|
||||
if args.lock_config:
|
||||
return
|
||||
@ -934,6 +936,7 @@ def apply_changes( state,
|
||||
"metadata_choice": metadata_choice,
|
||||
"default_ui" : default_ui_choice,
|
||||
"boost" : boost_choice,
|
||||
"reload_model" : reload_choice,
|
||||
}
|
||||
|
||||
if Path(server_config_filename).is_file():
|
||||
@ -1075,16 +1078,17 @@ def generate_video(
|
||||
global wan_model, offloadobj
|
||||
reload_needed = state.get("_reload_needed", False)
|
||||
file_model_needed = model_needed(image2video)
|
||||
if file_model_needed != model_filename or reload_needed:
|
||||
if offloadobj is not None:
|
||||
offloadobj.release()
|
||||
offloadobj = None
|
||||
wan_model = None
|
||||
gc.collect()
|
||||
print(f"Loading model {get_model_name(file_model_needed)}...")
|
||||
wan_model, offloadobj, trans = load_models(image2video)
|
||||
print(f"Model loaded")
|
||||
state["_reload_needed"] = False
|
||||
if(server_config.get("reload_model",1) == 2):
|
||||
if file_model_needed != model_filename or reload_needed:
|
||||
if offloadobj is not None:
|
||||
offloadobj.release()
|
||||
offloadobj = None
|
||||
wan_model = None
|
||||
gc.collect()
|
||||
print(f"Loading model {get_model_name(file_model_needed)}...")
|
||||
wan_model, offloadobj, trans = load_models(image2video)
|
||||
print(f"Model loaded")
|
||||
state["_reload_needed"] = False
|
||||
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
@ -1190,7 +1194,7 @@ def generate_video(
|
||||
if len(list_mult_choices_nums ) < len(loras_choices):
|
||||
list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
|
||||
loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices]
|
||||
pinnedLora = profile !=5 #False # # #
|
||||
pinnedLora = False !=5 #False # # #
|
||||
offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None)
|
||||
errors = trans._loras_errors
|
||||
if len(errors) > 0:
|
||||
@ -2236,6 +2240,14 @@ def generate_configuration_tab():
|
||||
value=metadata,
|
||||
label="Metadata Handling"
|
||||
)
|
||||
reload_choice = gr.Dropdown(
|
||||
choices=[
|
||||
("When changing tabs", 1),
|
||||
("When pressing generate", 2),
|
||||
],
|
||||
value=server_config.get("reload_model",1),
|
||||
label="Reload model"
|
||||
)
|
||||
msg = gr.Markdown()
|
||||
apply_btn = gr.Button("Apply Changes")
|
||||
apply_btn.click(
|
||||
@ -2253,6 +2265,7 @@ def generate_configuration_tab():
|
||||
metadata_choice,
|
||||
default_ui_choice,
|
||||
boost_choice,
|
||||
reload_choice,
|
||||
],
|
||||
outputs= msg
|
||||
)
|
||||
@ -2270,13 +2283,24 @@ def generate_about_tab():
|
||||
def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
|
||||
global lora_model_filename, use_image2video
|
||||
|
||||
t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
|
||||
i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
|
||||
|
||||
new_t2v = evt.index == 0
|
||||
new_i2v = evt.index == 1
|
||||
use_image2video = new_i2v
|
||||
|
||||
if(server_config.get("reload_model",1) == 1):
|
||||
global wan_model, offloadobj
|
||||
if wan_model is not None:
|
||||
if offloadobj is not None:
|
||||
offloadobj.release()
|
||||
offloadobj = None
|
||||
wan_model = None
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
wan_model, offloadobj, trans = load_models(use_image2video)
|
||||
|
||||
t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
|
||||
i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
|
||||
|
||||
if new_t2v:
|
||||
lora_model_filename = t2v_state["loras_model"]
|
||||
if ("1.3B" in transformer_filename_t2v and not "1.3B" in lora_model_filename or "14B" in transformer_filename_t2v and not "14B" in lora_model_filename):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user