mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-05 14:33:15 +00:00
added default models selection option
This commit is contained in:
parent
fbf2793b2e
commit
5f280408f5
160
wgp.py
160
wgp.py
@ -441,6 +441,12 @@ def _parse_args():
|
|||||||
help="Prevent modifying the configuration from the web interface"
|
help="Prevent modifying the configuration from the web interface"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lock-model",
|
||||||
|
action="store_true",
|
||||||
|
help="Prevent switch models"
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--preload",
|
"--preload",
|
||||||
type=str,
|
type=str,
|
||||||
@ -700,11 +706,15 @@ transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts
|
|||||||
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ]
|
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ]
|
||||||
transformer_choices = transformer_choices_t2v + transformer_choices_i2v
|
transformer_choices = transformer_choices_t2v + transformer_choices_i2v
|
||||||
text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
|
text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
|
||||||
server_config_filename = "gradio_config.json"
|
server_config_filename = "wgp_config.json"
|
||||||
|
|
||||||
|
if not os.path.isfile(server_config_filename) and os.path.isfile("gradio_config.json"):
|
||||||
|
import shutil
|
||||||
|
shutil.move("gradio_config.json", server_config_filename)
|
||||||
|
|
||||||
if not Path(server_config_filename).is_file():
|
if not Path(server_config_filename).is_file():
|
||||||
server_config = {"attention_mode" : "auto",
|
server_config = {"attention_mode" : "auto",
|
||||||
"transformer_type": "t2v",
|
"transformer_types": [],
|
||||||
"transformer_quantization": "int8",
|
"transformer_quantization": "int8",
|
||||||
"text_encoder_filename" : text_encoder_choices[1],
|
"text_encoder_filename" : text_encoder_choices[1],
|
||||||
"save_path": os.path.join(os.getcwd(), "gradio_outputs"),
|
"save_path": os.path.join(os.getcwd(), "gradio_outputs"),
|
||||||
@ -826,7 +836,8 @@ def get_default_settings(filename):
|
|||||||
ui_defaults["num_inference_steps"] = default_number_steps
|
ui_defaults["num_inference_steps"] = default_number_steps
|
||||||
return ui_defaults
|
return ui_defaults
|
||||||
|
|
||||||
transformer_type = server_config.get("transformer_type", "t2v")
|
transformer_types = server_config.get("transformer_types", [])
|
||||||
|
transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
|
||||||
transformer_quantization =server_config.get("transformer_quantization", "int8")
|
transformer_quantization =server_config.get("transformer_quantization", "int8")
|
||||||
transformer_filename = get_model_filename(transformer_type, transformer_quantization)
|
transformer_filename = get_model_filename(transformer_type, transformer_quantization)
|
||||||
text_encoder_filename = server_config["text_encoder_filename"]
|
text_encoder_filename = server_config["text_encoder_filename"]
|
||||||
@ -1213,22 +1224,25 @@ def get_model_name(model_filename):
|
|||||||
# return header
|
# return header
|
||||||
|
|
||||||
|
|
||||||
def generate_header(compile, attention_mode):
|
def generate_header(model_filename, compile, attention_mode):
|
||||||
|
|
||||||
header = "<DIV style='align:right;width:100%'><FONT SIZE=2>Attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
|
header = "<DIV style='align:right;width:100%'><FONT SIZE=3>Attention mode <B>" + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
|
||||||
if attention_mode not in attention_modes_installed:
|
if attention_mode not in attention_modes_installed:
|
||||||
header += " -NOT INSTALLED-"
|
header += " -NOT INSTALLED-"
|
||||||
elif attention_mode not in attention_modes_supported:
|
elif attention_mode not in attention_modes_supported:
|
||||||
header += " -NOT SUPPORTED-"
|
header += " -NOT SUPPORTED-"
|
||||||
|
header += "</B>"
|
||||||
|
|
||||||
if compile:
|
if compile:
|
||||||
header += ", pytorch compilation ON"
|
header += ", Pytorch compilation <B>ON</B>"
|
||||||
|
if "int8" in model_filename:
|
||||||
|
header += ", Quantization <B>Int8</B>"
|
||||||
header += "<FONT></DIV>"
|
header += "<FONT></DIV>"
|
||||||
|
|
||||||
return header
|
return header
|
||||||
|
|
||||||
def apply_changes( state,
|
def apply_changes( state,
|
||||||
transformer_type_choice,
|
transformer_types_choices,
|
||||||
text_encoder_choice,
|
text_encoder_choice,
|
||||||
save_path_choice,
|
save_path_choice,
|
||||||
attention_choice,
|
attention_choice,
|
||||||
@ -1239,16 +1253,15 @@ def apply_changes( state,
|
|||||||
quantization_choice,
|
quantization_choice,
|
||||||
boost_choice = 1,
|
boost_choice = 1,
|
||||||
clear_file_list = 0,
|
clear_file_list = 0,
|
||||||
reload_choice = 1
|
reload_choice = 1,
|
||||||
):
|
):
|
||||||
if args.lock_config:
|
if args.lock_config:
|
||||||
return
|
return
|
||||||
if gen_in_progress:
|
if gen_in_progress:
|
||||||
yield "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
|
return "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
|
||||||
return
|
|
||||||
global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
|
global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
|
||||||
server_config = {"attention_mode" : attention_choice,
|
server_config = {"attention_mode" : attention_choice,
|
||||||
"transformer_type": transformer_type_choice,
|
"transformer_types": transformer_types_choices,
|
||||||
"text_encoder_filename" : text_encoder_choices[text_encoder_choice],
|
"text_encoder_filename" : text_encoder_choices[text_encoder_choice],
|
||||||
"save_path" : save_path_choice,
|
"save_path" : save_path_choice,
|
||||||
"compile" : compile_choice,
|
"compile" : compile_choice,
|
||||||
@ -1281,7 +1294,7 @@ def apply_changes( state,
|
|||||||
if v != v_old:
|
if v != v_old:
|
||||||
changes.append(k)
|
changes.append(k)
|
||||||
|
|
||||||
global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, reload_model, transformer_quantization, transformer_type
|
global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, reload_model, transformer_quantization, transformer_types
|
||||||
attention_mode = server_config["attention_mode"]
|
attention_mode = server_config["attention_mode"]
|
||||||
profile = server_config["profile"]
|
profile = server_config["profile"]
|
||||||
compile = server_config["compile"]
|
compile = server_config["compile"]
|
||||||
@ -1290,15 +1303,19 @@ def apply_changes( state,
|
|||||||
boost = server_config["boost"]
|
boost = server_config["boost"]
|
||||||
reload_model = server_config["reload_model"]
|
reload_model = server_config["reload_model"]
|
||||||
transformer_quantization = server_config["transformer_quantization"]
|
transformer_quantization = server_config["transformer_quantization"]
|
||||||
|
transformer_types = server_config["transformer_types"]
|
||||||
|
transformer_type = get_model_type(transformer_filename)
|
||||||
|
if not transformer_type in transformer_types:
|
||||||
|
transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
|
||||||
transformer_filename = get_model_filename(transformer_type, transformer_quantization)
|
transformer_filename = get_model_filename(transformer_type, transformer_quantization)
|
||||||
|
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ):
|
||||||
if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ):
|
model_choice = gr.Dropdown()
|
||||||
pass
|
|
||||||
else:
|
else:
|
||||||
reload_needed = True
|
reload_needed = True
|
||||||
|
model_choice = generate_dropdown_model_list()
|
||||||
|
|
||||||
|
header = generate_header(transformer_filename, compile=compile, attention_mode= attention_mode)
|
||||||
yield "<DIV ALIGN=CENTER>The new configuration has been succesfully applied</DIV>"
|
return "<DIV ALIGN=CENTER>The new configuration has been succesfully applied</DIV>", header, model_choice
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -2505,16 +2522,11 @@ def handle_celll_selection(state, evt: gr.SelectData):
|
|||||||
|
|
||||||
|
|
||||||
def change_model(state, model_choice):
|
def change_model(state, model_choice):
|
||||||
model_filename = ""
|
if model_choice == None:
|
||||||
for filename in model_list:
|
|
||||||
if get_model_type(filename) == model_choice:
|
|
||||||
model_filename = filename
|
|
||||||
break
|
|
||||||
if len(model_filename) == 0:
|
|
||||||
return
|
return
|
||||||
|
model_filename = get_model_filename(model_choice, transformer_quantization)
|
||||||
state["model_filename"] = model_filename
|
state["model_filename"] = model_filename
|
||||||
header = generate_header(compile=compile, attention_mode=attention_mode)
|
header = generate_header(model_filename, compile=compile, attention_mode=attention_mode)
|
||||||
return header
|
return header
|
||||||
|
|
||||||
def fill_inputs(state):
|
def fill_inputs(state):
|
||||||
@ -3014,53 +3026,32 @@ def generate_download_tab(lset_name,loras_choices, state):
|
|||||||
download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
|
download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
|
||||||
|
|
||||||
|
|
||||||
def generate_configuration_tab():
|
def generate_configuration_tab(header, model_choice):
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
state = gr.State(state_dict)
|
state = gr.State(state_dict)
|
||||||
gr.Markdown("Please click Apply Changes at the bottom so that the changes are effective. Some choices below may be locked if the app has been launched by specifying a config preset.")
|
gr.Markdown("Please click Apply Changes at the bottom so that the changes are effective. Some choices below may be locked if the app has been launched by specifying a config preset.")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
index = transformer_choices.index(transformer_filename)
|
|
||||||
index = 0 if index ==0 else index
|
|
||||||
|
|
||||||
model_list = []
|
model_list = []
|
||||||
|
|
||||||
for model_type in model_types:
|
for model_type in model_types:
|
||||||
choice = get_model_filename(model_type, transformer_quantization)
|
choice = get_model_filename(model_type, transformer_quantization)
|
||||||
model_list.append(choice)
|
model_list.append(choice)
|
||||||
dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
|
dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
|
||||||
transformer_type_choice = gr.Dropdown(
|
transformer_types_choices = gr.Dropdown(
|
||||||
choices= dropdown_choices,
|
choices= dropdown_choices,
|
||||||
value= get_model_type(transformer_filename),
|
value= transformer_types,
|
||||||
label= "Default Wan Transformer Model",
|
label= "Selectable Wan Transformer Models (keep empty to get All of them)",
|
||||||
scale= 2
|
scale= 2,
|
||||||
|
multiselect= True
|
||||||
)
|
)
|
||||||
|
|
||||||
# transformer_choice = gr.Dropdown(
|
|
||||||
# choices=[
|
|
||||||
# ("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0),
|
|
||||||
# ("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1),
|
|
||||||
# ("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2),
|
|
||||||
# ("WAN 2.1 VACE 1.3B Text to Video / Control Net - text generation driven by reference images or videos", 3),
|
|
||||||
# ("WAN 2.1 - 480p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 4),
|
|
||||||
# ("WAN 2.1 - 480p 14B Image to Video quantized to 8 bits (recommended) - the default engine but quantized", 5),
|
|
||||||
# ("WAN 2.1 - 720p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 6),
|
|
||||||
# ("WAN 2.1 - 720p 14B Image to Video quantized to 8 bits - the default engine but quantized", 7),
|
|
||||||
# ("WAN 2.1 - Fun InP 1.3B 16 bits - the small model for fast generations with low VRAM requirements", 8),
|
|
||||||
# ("WAN 2.1 - Fun InP 14B 16 bits - Fun InP version in its original glory, offers a slightly better image quality but slower and requires more RAM", 9),
|
|
||||||
# ("WAN 2.1 - Fun InP 14B quantized to 8 bits - quantized Fun InP version", 10),
|
|
||||||
# ],
|
|
||||||
# value= index,
|
|
||||||
# label="Transformer model for Image to Video",
|
|
||||||
# interactive= not lock_ui_transformer,
|
|
||||||
# visible = True,
|
|
||||||
# )
|
|
||||||
|
|
||||||
quantization_choice = gr.Dropdown(
|
quantization_choice = gr.Dropdown(
|
||||||
choices=[
|
choices=[
|
||||||
("Int8 Quantization (recommended)", "int8"),
|
("Int8 Quantization (recommended)", "int8"),
|
||||||
("BF16 (no quantization)", "bf16"),
|
("BF16 (no quantization)", "bf16"),
|
||||||
],
|
],
|
||||||
value= transformer_quantization,
|
value= transformer_quantization,
|
||||||
label="Wan Transformer Model Quantization (if available)",
|
label="Wan Transformer Model Quantization Type (if available)",
|
||||||
)
|
)
|
||||||
|
|
||||||
index = text_encoder_choices.index(text_encoder_filename)
|
index = text_encoder_choices.index(text_encoder_filename)
|
||||||
@ -3137,14 +3128,14 @@ def generate_configuration_tab():
|
|||||||
value= profile,
|
value= profile,
|
||||||
label="Profile (for power users only, not needed to change it)"
|
label="Profile (for power users only, not needed to change it)"
|
||||||
)
|
)
|
||||||
default_ui_choice = gr.Dropdown(
|
# default_ui_choice = gr.Dropdown(
|
||||||
choices=[
|
# choices=[
|
||||||
("Text to Video", "t2v"),
|
# ("Text to Video", "t2v"),
|
||||||
("Image to Video", "i2v"),
|
# ("Image to Video", "i2v"),
|
||||||
],
|
# ],
|
||||||
value= default_ui,
|
# value= default_ui,
|
||||||
label="Default mode when launching the App if not '--t2v' ot '--i2v' switch is specified when launching the server ",
|
# label="Default mode when launching the App if not '--t2v' ot '--i2v' switch is specified when launching the server ",
|
||||||
)
|
# )
|
||||||
metadata_choice = gr.Dropdown(
|
metadata_choice = gr.Dropdown(
|
||||||
choices=[
|
choices=[
|
||||||
("Export JSON files", "json"),
|
("Export JSON files", "json"),
|
||||||
@ -3184,7 +3175,7 @@ def generate_configuration_tab():
|
|||||||
fn=apply_changes,
|
fn=apply_changes,
|
||||||
inputs=[
|
inputs=[
|
||||||
state,
|
state,
|
||||||
transformer_type_choice,
|
transformer_types_choices,
|
||||||
text_encoder_choice,
|
text_encoder_choice,
|
||||||
save_path_choice,
|
save_path_choice,
|
||||||
attention_choice,
|
attention_choice,
|
||||||
@ -3197,7 +3188,7 @@ def generate_configuration_tab():
|
|||||||
clear_file_list_choice,
|
clear_file_list_choice,
|
||||||
reload_choice,
|
reload_choice,
|
||||||
],
|
],
|
||||||
outputs= msg
|
outputs= [msg , header, model_choice]
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_about_tab():
|
def generate_about_tab():
|
||||||
@ -3221,6 +3212,22 @@ def generate_info_tab():
|
|||||||
gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
|
gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_dropdown_model_list():
|
||||||
|
dropdown_types= transformer_types if len(transformer_types) > 0 else model_types
|
||||||
|
current_model_type = get_model_type(transformer_filename)
|
||||||
|
if current_model_type not in dropdown_types:
|
||||||
|
dropdown_types.append(current_model_type)
|
||||||
|
model_list = []
|
||||||
|
for model_type in dropdown_types:
|
||||||
|
choice = get_model_filename(model_type, transformer_quantization)
|
||||||
|
model_list.append(choice)
|
||||||
|
dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
|
||||||
|
return gr.Dropdown(
|
||||||
|
choices= dropdown_choices,
|
||||||
|
value= current_model_type,
|
||||||
|
show_label= False,
|
||||||
|
scale= 2
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -3457,22 +3464,15 @@ def create_demo():
|
|||||||
with gr.Tabs(selected="video_gen", ) as main_tabs:
|
with gr.Tabs(selected="video_gen", ) as main_tabs:
|
||||||
with gr.Tab("Video Generator", id="video_gen") as t2v_tab:
|
with gr.Tab("Video Generator", id="video_gen") as t2v_tab:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
header = gr.Markdown(generate_header(compile, attention_mode), visible= True)
|
if args.lock_model:
|
||||||
|
gr.Markdown("<div class='title-with-lines'><div class=line></div><h2>" + get_model_name(transformer_filename) + "</h2><div class=line></div>")
|
||||||
|
model_choice = gr.Dropdown(visible=False, value= get_model_type(transformer_filename))
|
||||||
|
else:
|
||||||
|
gr.Markdown("<div class='title-with-lines'><div class=line width=100%></div></div>")
|
||||||
|
model_choice = generate_dropdown_model_list()
|
||||||
|
gr.Markdown("<div class='title-with-lines'><div class=line width=100%></div></div>")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
gr.Markdown("<div class='title-with-lines'><div class=line width=100%></div></div>")
|
header = gr.Markdown(generate_header(transformer_filename, compile, attention_mode), visible= True)
|
||||||
|
|
||||||
model_list = []
|
|
||||||
for model_type in model_types:
|
|
||||||
choice = get_model_filename(model_type, transformer_quantization)
|
|
||||||
model_list.append(choice)
|
|
||||||
dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
|
|
||||||
model_choice = gr.Dropdown(
|
|
||||||
choices= dropdown_choices,
|
|
||||||
value= get_model_type(transformer_filename),
|
|
||||||
show_label= False,
|
|
||||||
scale= 2
|
|
||||||
)
|
|
||||||
gr.Markdown("<div class='title-with-lines'><div class=line width=100%></div></div>")
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
|
||||||
loras_choices, lset_name, state = generate_video_tab(model_choice = model_choice, header = header)
|
loras_choices, lset_name, state = generate_video_tab(model_choice = model_choice, header = header)
|
||||||
@ -3482,7 +3482,7 @@ def create_demo():
|
|||||||
with gr.Tab("Downloads", id="downloads") as downloads_tab:
|
with gr.Tab("Downloads", id="downloads") as downloads_tab:
|
||||||
generate_download_tab(lset_name, loras_choices, state)
|
generate_download_tab(lset_name, loras_choices, state)
|
||||||
with gr.Tab("Configuration"):
|
with gr.Tab("Configuration"):
|
||||||
generate_configuration_tab()
|
generate_configuration_tab(header, model_choice)
|
||||||
with gr.Tab("About"):
|
with gr.Tab("About"):
|
||||||
generate_about_tab()
|
generate_about_tab()
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user