added default models selection option

This commit is contained in:
DeepBeepMeep 2025-04-07 23:18:47 +02:00
parent fbf2793b2e
commit 5f280408f5

162
wgp.py
View File

@ -441,6 +441,12 @@ def _parse_args():
help="Prevent modifying the configuration from the web interface" help="Prevent modifying the configuration from the web interface"
) )
parser.add_argument(
"--lock-model",
action="store_true",
help="Prevent switch models"
)
parser.add_argument( parser.add_argument(
"--preload", "--preload",
type=str, type=str,
@ -700,11 +706,15 @@ transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ] transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ]
transformer_choices = transformer_choices_t2v + transformer_choices_i2v transformer_choices = transformer_choices_t2v + transformer_choices_i2v
text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"] text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
server_config_filename = "gradio_config.json" server_config_filename = "wgp_config.json"
if not os.path.isfile(server_config_filename) and os.path.isfile("gradio_config.json"):
import shutil
shutil.move("gradio_config.json", server_config_filename)
if not Path(server_config_filename).is_file(): if not Path(server_config_filename).is_file():
server_config = {"attention_mode" : "auto", server_config = {"attention_mode" : "auto",
"transformer_type": "t2v", "transformer_types": [],
"transformer_quantization": "int8", "transformer_quantization": "int8",
"text_encoder_filename" : text_encoder_choices[1], "text_encoder_filename" : text_encoder_choices[1],
"save_path": os.path.join(os.getcwd(), "gradio_outputs"), "save_path": os.path.join(os.getcwd(), "gradio_outputs"),
@ -826,7 +836,8 @@ def get_default_settings(filename):
ui_defaults["num_inference_steps"] = default_number_steps ui_defaults["num_inference_steps"] = default_number_steps
return ui_defaults return ui_defaults
transformer_type = server_config.get("transformer_type", "t2v") transformer_types = server_config.get("transformer_types", [])
transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
transformer_quantization =server_config.get("transformer_quantization", "int8") transformer_quantization =server_config.get("transformer_quantization", "int8")
transformer_filename = get_model_filename(transformer_type, transformer_quantization) transformer_filename = get_model_filename(transformer_type, transformer_quantization)
text_encoder_filename = server_config["text_encoder_filename"] text_encoder_filename = server_config["text_encoder_filename"]
@ -1213,22 +1224,25 @@ def get_model_name(model_filename):
# return header # return header
def generate_header(compile, attention_mode): def generate_header(model_filename, compile, attention_mode):
header = "<DIV style='align:right;width:100%'><FONT SIZE=2>Attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() ) header = "<DIV style='align:right;width:100%'><FONT SIZE=3>Attention mode <B>" + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
if attention_mode not in attention_modes_installed: if attention_mode not in attention_modes_installed:
header += " -NOT INSTALLED-" header += " -NOT INSTALLED-"
elif attention_mode not in attention_modes_supported: elif attention_mode not in attention_modes_supported:
header += " -NOT SUPPORTED-" header += " -NOT SUPPORTED-"
header += "</B>"
if compile: if compile:
header += ", pytorch compilation ON" header += ", Pytorch compilation <B>ON</B>"
if "int8" in model_filename:
header += ", Quantization <B>Int8</B>"
header += "<FONT></DIV>" header += "<FONT></DIV>"
return header return header
def apply_changes( state, def apply_changes( state,
transformer_type_choice, transformer_types_choices,
text_encoder_choice, text_encoder_choice,
save_path_choice, save_path_choice,
attention_choice, attention_choice,
@ -1239,16 +1253,15 @@ def apply_changes( state,
quantization_choice, quantization_choice,
boost_choice = 1, boost_choice = 1,
clear_file_list = 0, clear_file_list = 0,
reload_choice = 1 reload_choice = 1,
): ):
if args.lock_config: if args.lock_config:
return return
if gen_in_progress: if gen_in_progress:
yield "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>" return "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
return
global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
server_config = {"attention_mode" : attention_choice, server_config = {"attention_mode" : attention_choice,
"transformer_type": transformer_type_choice, "transformer_types": transformer_types_choices,
"text_encoder_filename" : text_encoder_choices[text_encoder_choice], "text_encoder_filename" : text_encoder_choices[text_encoder_choice],
"save_path" : save_path_choice, "save_path" : save_path_choice,
"compile" : compile_choice, "compile" : compile_choice,
@ -1281,7 +1294,7 @@ def apply_changes( state,
if v != v_old: if v != v_old:
changes.append(k) changes.append(k)
global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, reload_model, transformer_quantization, transformer_type global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, reload_model, transformer_quantization, transformer_types
attention_mode = server_config["attention_mode"] attention_mode = server_config["attention_mode"]
profile = server_config["profile"] profile = server_config["profile"]
compile = server_config["compile"] compile = server_config["compile"]
@ -1290,15 +1303,19 @@ def apply_changes( state,
boost = server_config["boost"] boost = server_config["boost"]
reload_model = server_config["reload_model"] reload_model = server_config["reload_model"]
transformer_quantization = server_config["transformer_quantization"] transformer_quantization = server_config["transformer_quantization"]
transformer_filename = get_model_filename(transformer_type, transformer_quantization) transformer_types = server_config["transformer_types"]
transformer_type = get_model_type(transformer_filename)
if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ): if not transformer_type in transformer_types:
pass transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
transformer_filename = get_model_filename(transformer_type, transformer_quantization)
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ):
model_choice = gr.Dropdown()
else: else:
reload_needed = True reload_needed = True
model_choice = generate_dropdown_model_list()
header = generate_header(transformer_filename, compile=compile, attention_mode= attention_mode)
yield "<DIV ALIGN=CENTER>The new configuration has been succesfully applied</DIV>" return "<DIV ALIGN=CENTER>The new configuration has been succesfully applied</DIV>", header, model_choice
@ -2505,16 +2522,11 @@ def handle_celll_selection(state, evt: gr.SelectData):
def change_model(state, model_choice): def change_model(state, model_choice):
model_filename = "" if model_choice == None:
for filename in model_list:
if get_model_type(filename) == model_choice:
model_filename = filename
break
if len(model_filename) == 0:
return return
model_filename = get_model_filename(model_choice, transformer_quantization)
state["model_filename"] = model_filename state["model_filename"] = model_filename
header = generate_header(compile=compile, attention_mode=attention_mode) header = generate_header(model_filename, compile=compile, attention_mode=attention_mode)
return header return header
def fill_inputs(state): def fill_inputs(state):
@ -3014,53 +3026,32 @@ def generate_download_tab(lset_name,loras_choices, state):
download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
def generate_configuration_tab(): def generate_configuration_tab(header, model_choice):
state_dict = {} state_dict = {}
state = gr.State(state_dict) state = gr.State(state_dict)
gr.Markdown("Please click Apply Changes at the bottom so that the changes are effective. Some choices below may be locked if the app has been launched by specifying a config preset.") gr.Markdown("Please click Apply Changes at the bottom so that the changes are effective. Some choices below may be locked if the app has been launched by specifying a config preset.")
with gr.Column(): with gr.Column():
index = transformer_choices.index(transformer_filename)
index = 0 if index ==0 else index
model_list = [] model_list = []
for model_type in model_types: for model_type in model_types:
choice = get_model_filename(model_type, transformer_quantization) choice = get_model_filename(model_type, transformer_quantization)
model_list.append(choice) model_list.append(choice)
dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list] dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
transformer_type_choice = gr.Dropdown( transformer_types_choices = gr.Dropdown(
choices= dropdown_choices, choices= dropdown_choices,
value= get_model_type(transformer_filename), value= transformer_types,
label= "Default Wan Transformer Model", label= "Selectable Wan Transformer Models (keep empty to get All of them)",
scale= 2 scale= 2,
multiselect= True
) )
# transformer_choice = gr.Dropdown(
# choices=[
# ("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0),
# ("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1),
# ("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2),
# ("WAN 2.1 VACE 1.3B Text to Video / Control Net - text generation driven by reference images or videos", 3),
# ("WAN 2.1 - 480p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 4),
# ("WAN 2.1 - 480p 14B Image to Video quantized to 8 bits (recommended) - the default engine but quantized", 5),
# ("WAN 2.1 - 720p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 6),
# ("WAN 2.1 - 720p 14B Image to Video quantized to 8 bits - the default engine but quantized", 7),
# ("WAN 2.1 - Fun InP 1.3B 16 bits - the small model for fast generations with low VRAM requirements", 8),
# ("WAN 2.1 - Fun InP 14B 16 bits - Fun InP version in its original glory, offers a slightly better image quality but slower and requires more RAM", 9),
# ("WAN 2.1 - Fun InP 14B quantized to 8 bits - quantized Fun InP version", 10),
# ],
# value= index,
# label="Transformer model for Image to Video",
# interactive= not lock_ui_transformer,
# visible = True,
# )
quantization_choice = gr.Dropdown( quantization_choice = gr.Dropdown(
choices=[ choices=[
("Int8 Quantization (recommended)", "int8"), ("Int8 Quantization (recommended)", "int8"),
("BF16 (no quantization)", "bf16"), ("BF16 (no quantization)", "bf16"),
], ],
value= transformer_quantization, value= transformer_quantization,
label="Wan Transformer Model Quantization (if available)", label="Wan Transformer Model Quantization Type (if available)",
) )
index = text_encoder_choices.index(text_encoder_filename) index = text_encoder_choices.index(text_encoder_filename)
@ -3137,14 +3128,14 @@ def generate_configuration_tab():
value= profile, value= profile,
label="Profile (for power users only, not needed to change it)" label="Profile (for power users only, not needed to change it)"
) )
default_ui_choice = gr.Dropdown( # default_ui_choice = gr.Dropdown(
choices=[ # choices=[
("Text to Video", "t2v"), # ("Text to Video", "t2v"),
("Image to Video", "i2v"), # ("Image to Video", "i2v"),
], # ],
value= default_ui, # value= default_ui,
label="Default mode when launching the App if not '--t2v' ot '--i2v' switch is specified when launching the server ", # label="Default mode when launching the App if not '--t2v' ot '--i2v' switch is specified when launching the server ",
) # )
metadata_choice = gr.Dropdown( metadata_choice = gr.Dropdown(
choices=[ choices=[
("Export JSON files", "json"), ("Export JSON files", "json"),
@ -3184,7 +3175,7 @@ def generate_configuration_tab():
fn=apply_changes, fn=apply_changes,
inputs=[ inputs=[
state, state,
transformer_type_choice, transformer_types_choices,
text_encoder_choice, text_encoder_choice,
save_path_choice, save_path_choice,
attention_choice, attention_choice,
@ -3197,7 +3188,7 @@ def generate_configuration_tab():
clear_file_list_choice, clear_file_list_choice,
reload_choice, reload_choice,
], ],
outputs= msg outputs= [msg , header, model_choice]
) )
def generate_about_tab(): def generate_about_tab():
@ -3221,6 +3212,22 @@ def generate_info_tab():
gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.") gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
def generate_dropdown_model_list():
dropdown_types= transformer_types if len(transformer_types) > 0 else model_types
current_model_type = get_model_type(transformer_filename)
if current_model_type not in dropdown_types:
dropdown_types.append(current_model_type)
model_list = []
for model_type in dropdown_types:
choice = get_model_filename(model_type, transformer_quantization)
model_list.append(choice)
dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
return gr.Dropdown(
choices= dropdown_choices,
value= current_model_type,
show_label= False,
scale= 2
)
@ -3457,22 +3464,15 @@ def create_demo():
with gr.Tabs(selected="video_gen", ) as main_tabs: with gr.Tabs(selected="video_gen", ) as main_tabs:
with gr.Tab("Video Generator", id="video_gen") as t2v_tab: with gr.Tab("Video Generator", id="video_gen") as t2v_tab:
with gr.Row(): with gr.Row():
header = gr.Markdown(generate_header(compile, attention_mode), visible= True) if args.lock_model:
gr.Markdown("<div class='title-with-lines'><div class=line></div><h2>" + get_model_name(transformer_filename) + "</h2><div class=line></div>")
model_choice = gr.Dropdown(visible=False, value= get_model_type(transformer_filename))
else:
gr.Markdown("<div class='title-with-lines'><div class=line width=100%></div></div>")
model_choice = generate_dropdown_model_list()
gr.Markdown("<div class='title-with-lines'><div class=line width=100%></div></div>")
with gr.Row(): with gr.Row():
gr.Markdown("<div class='title-with-lines'><div class=line width=100%></div></div>") header = gr.Markdown(generate_header(transformer_filename, compile, attention_mode), visible= True)
model_list = []
for model_type in model_types:
choice = get_model_filename(model_type, transformer_quantization)
model_list.append(choice)
dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
model_choice = gr.Dropdown(
choices= dropdown_choices,
value= get_model_type(transformer_filename),
show_label= False,
scale= 2
)
gr.Markdown("<div class='title-with-lines'><div class=line width=100%></div></div>")
with gr.Row(): with gr.Row():
loras_choices, lset_name, state = generate_video_tab(model_choice = model_choice, header = header) loras_choices, lset_name, state = generate_video_tab(model_choice = model_choice, header = header)
@ -3482,7 +3482,7 @@ def create_demo():
with gr.Tab("Downloads", id="downloads") as downloads_tab: with gr.Tab("Downloads", id="downloads") as downloads_tab:
generate_download_tab(lset_name, loras_choices, state) generate_download_tab(lset_name, loras_choices, state)
with gr.Tab("Configuration"): with gr.Tab("Configuration"):
generate_configuration_tab() generate_configuration_tab(header, model_choice)
with gr.Tab("About"): with gr.Tab("About"):
generate_about_tab() generate_about_tab()