added better missing model error handling

This commit is contained in:
deepbeepmeep 2025-07-16 00:47:23 +02:00
parent afbf94e44f
commit bda410f367
2 changed files with 27 additions and 13 deletions

View File

@ -34,6 +34,15 @@ def seed_everything(seed: int):
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
torch.mps.manual_seed(seed) torch.mps.manual_seed(seed)
def expand_slist(slist, num_inference_steps ):
new_slist= []
inc = len(slist) / num_inference_steps
pos = 0
for i in range(num_inference_steps):
new_slist.append(slist[ int(pos)])
pos += inc
return new_slist
def update_loras_slists(trans, slists, num_inference_steps ): def update_loras_slists(trans, slists, num_inference_steps ):
from mmgp import offload from mmgp import offload
slists = [ expand_slist(slist, num_inference_steps ) if isinstance(slist, list) else slist for slist in slists ] slists = [ expand_slist(slist, num_inference_steps ) if isinstance(slist, list) else slist for slist in slists ]

31
wgp.py
View File

@ -16,6 +16,7 @@ import json
import wan import wan
from wan.utils import notification_sound from wan.utils import notification_sound
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS
from wan.utils.utils import expand_slist, update_loras_slists
from wan.utils.utils import cache_video, convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video from wan.utils.utils import cache_video, convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video
from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files
@ -160,6 +161,7 @@ def process_prompt_and_add_tasks(state, model_choice):
model_filename = state["model_filename"] model_filename = state["model_filename"]
model_type = state["model_type"] model_type = state["model_type"]
inputs = get_model_settings(state, model_type) inputs = get_model_settings(state, model_type)
if model_choice != model_type or inputs ==None: if model_choice != model_type or inputs ==None:
raise gr.Error("Webform can not be used as the App has been restarted since the form was displayed. Please refresh the page") raise gr.Error("Webform can not be used as the App has been restarted since the form was displayed. Please refresh the page")
@ -1740,6 +1742,8 @@ def get_model_type(model_filename):
def get_model_family(model_type): def get_model_family(model_type):
model_type = get_base_model_type(model_type) model_type = get_base_model_type(model_type)
if model_type == None:
return "unknown"
if "hunyuan" in model_type : if "hunyuan" in model_type :
return "hunyuan" return "hunyuan"
elif "ltxv" in model_type: elif "ltxv" in model_type:
@ -1799,7 +1803,8 @@ def get_computed_fps(force_fps, base_model_type , video_guide, video_source ):
def get_model_name(model_type, description_container = [""]): def get_model_name(model_type, description_container = [""]):
model_def = get_model_def(model_type) model_def = get_model_def(model_type)
if model_def == None: raise Exception(f"Unknown model {model_type}") if model_def == None:
return f"Unknown model {model_type}"
model_name = model_def["name"] model_name = model_def["name"]
description = model_def["description"] description = model_def["description"]
description_container[0] = description description_container[0] = description
@ -1832,7 +1837,7 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = "", is_m
if choices == None: raise Exception(f"Invalid Module Id '{model_type}'") if choices == None: raise Exception(f"Invalid Module Id '{model_type}'")
else: else:
model_def = models_def.get(model_type, None) model_def = models_def.get(model_type, None)
if model_def == None: raise Exception(f"Unknown model type {model_type}") if model_def == None: return None
URLs = model_def["URLs"] URLs = model_def["URLs"]
if isinstance(URLs, str): if isinstance(URLs, str):
if len(stack) > 10: raise Exception(f"Circular Reference in Model URLs dependencies: {stack}") if len(stack) > 10: raise Exception(f"Circular Reference in Model URLs dependencies: {stack}")
@ -2120,11 +2125,18 @@ model_types = models_def.keys()
displayed_model_types= [] displayed_model_types= []
for model_type in model_types: for model_type in model_types:
model_def = get_model_def(model_type) model_def = get_model_def(model_type)
if not (model_def != None and model_def.get("visible", True) == False): if not model_def is None and model_def.get("visible", True):
displayed_model_types.append(model_type) displayed_model_types.append(model_type)
transformer_types = server_config.get("transformer_types", []) transformer_types = server_config.get("transformer_types", [])
new_transformer_types = []
for model_type in transformer_types:
if get_model_def(model_type) == None:
print(f"Model '{model_type}' is missing. Either install it in the finetune folder or remove this model in wgp_config.json")
else:
new_transformer_types.append(model_type)
transformer_types = new_transformer_types
transformer_type = server_config.get("last_model_type", None) transformer_type = server_config.get("last_model_type", None)
advanced = server_config.get("last_advanced_choice", False) advanced = server_config.get("last_advanced_choice", False)
if args.advanced: advanced = True if args.advanced: advanced = True
@ -2729,7 +2741,7 @@ def generate_header(model_type, compile, attention_mode):
description_container = [""] description_container = [""]
get_model_name(model_type, description_container) get_model_name(model_type, description_container)
model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) or ""
description = description_container[0] description = description_container[0]
header = "<DIV style='height:40px'>" + description + "</DIV>" header = "<DIV style='height:40px'>" + description + "</DIV>"
@ -3250,14 +3262,7 @@ def select_video(state, input_file_list, event_data: gr.EventData):
html = get_default_video_info() html = get_default_video_info()
visible= len(file_list) > 0 visible= len(file_list) > 0
return choice, html, gr.update(visible=visible and not is_image) , gr.update(visible=visible and is_image), gr.update(visible=visible and not is_image) return choice, html, gr.update(visible=visible and not is_image) , gr.update(visible=visible and is_image), gr.update(visible=visible and not is_image)
def expand_slist(slist, num_inference_steps ):
new_slist= []
inc = len(slist) / num_inference_steps
pos = 0
for i in range(num_inference_steps):
new_slist.append(slist[ int(pos)])
pos += inc
return new_slist
def convert_image(image): def convert_image(image):
from PIL import ImageOps from PIL import ImageOps
@ -7871,7 +7876,7 @@ def generate_info_tab():
def get_sorted_dropdown(dropdown_types): def get_sorted_dropdown(dropdown_types):
families_order = {"wan":0, "ltxv":1, "hunyuan":2, "flux": 3 } families_order = {"wan":0, "ltxv":1, "hunyuan":2, "flux": 3, "unknown": 100 }
dropdown_classes = [ families_order[get_model_family(type)] for type in dropdown_types] dropdown_classes = [ families_order[get_model_family(type)] for type in dropdown_types]
dropdown_names = [ get_model_name(type) for type in dropdown_types] dropdown_names = [ get_model_name(type) for type in dropdown_types]