From bda410f367b452a5d23f34206c34e5af29d5d7d3 Mon Sep 17 00:00:00 2001 From: deepbeepmeep Date: Wed, 16 Jul 2025 00:47:23 +0200 Subject: [PATCH] added better missing model error handling --- wan/utils/utils.py | 9 +++++++++ wgp.py | 31 ++++++++++++++++++------------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/wan/utils/utils.py b/wan/utils/utils.py index 9e61fcb..cbd34e9 100644 --- a/wan/utils/utils.py +++ b/wan/utils/utils.py @@ -34,6 +34,15 @@ def seed_everything(seed: int): if torch.backends.mps.is_available(): torch.mps.manual_seed(seed) +def expand_slist(slist, num_inference_steps ): + new_slist= [] + inc = len(slist) / num_inference_steps + pos = 0 + for i in range(num_inference_steps): + new_slist.append(slist[ int(pos)]) + pos += inc + return new_slist + def update_loras_slists(trans, slists, num_inference_steps ): from mmgp import offload slists = [ expand_slist(slist, num_inference_steps ) if isinstance(slist, list) else slist for slist in slists ] diff --git a/wgp.py b/wgp.py index 7888417..ab12f50 100644 --- a/wgp.py +++ b/wgp.py @@ -16,6 +16,7 @@ import json import wan from wan.utils import notification_sound from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS +from wan.utils.utils import expand_slist, update_loras_slists from wan.utils.utils import cache_video, convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files @@ -160,6 +161,7 @@ def process_prompt_and_add_tasks(state, model_choice): model_filename = state["model_filename"] model_type = state["model_type"] inputs = get_model_settings(state, model_type) + if model_choice != model_type or inputs ==None: raise gr.Error("Webform can not be used as the App has been restarted since the form was displayed. Please refresh the page") @@ -1740,6 +1742,8 @@ def get_model_type(model_filename): def get_model_family(model_type): model_type = get_base_model_type(model_type) + if model_type == None: + return "unknown" if "hunyuan" in model_type : return "hunyuan" elif "ltxv" in model_type: @@ -1799,7 +1803,8 @@ def get_computed_fps(force_fps, base_model_type , video_guide, video_source ): def get_model_name(model_type, description_container = [""]): model_def = get_model_def(model_type) - if model_def == None: raise Exception(f"Unknown model {model_type}") + if model_def == None: + return f"Unknown model {model_type}" model_name = model_def["name"] description = model_def["description"] description_container[0] = description @@ -1832,7 +1837,7 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = "", is_m if choices == None: raise Exception(f"Invalid Module Id '{model_type}'") else: model_def = models_def.get(model_type, None) - if model_def == None: raise Exception(f"Unknown model type {model_type}") + if model_def == None: return None URLs = model_def["URLs"] if isinstance(URLs, str): if len(stack) > 10: raise Exception(f"Circular Reference in Model URLs dependencies: {stack}") @@ -2120,11 +2125,18 @@ model_types = models_def.keys() displayed_model_types= [] for model_type in model_types: model_def = get_model_def(model_type) - if not (model_def != None and model_def.get("visible", True) == False): + if not model_def is None and model_def.get("visible", True): displayed_model_types.append(model_type) transformer_types = server_config.get("transformer_types", []) +new_transformer_types = [] +for model_type in transformer_types: + if get_model_def(model_type) == None: + print(f"Model '{model_type}' is missing. Either install it in the finetune folder or remove this model in wgp_config.json") + else: + new_transformer_types.append(model_type) +transformer_types = new_transformer_types transformer_type = server_config.get("last_model_type", None) advanced = server_config.get("last_advanced_choice", False) if args.advanced: advanced = True @@ -2729,7 +2741,7 @@ def generate_header(model_type, compile, attention_mode): description_container = [""] get_model_name(model_type, description_container) - model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) + model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) or "" description = description_container[0] header = "
" + description + "
" @@ -3250,14 +3262,7 @@ def select_video(state, input_file_list, event_data: gr.EventData): html = get_default_video_info() visible= len(file_list) > 0 return choice, html, gr.update(visible=visible and not is_image) , gr.update(visible=visible and is_image), gr.update(visible=visible and not is_image) -def expand_slist(slist, num_inference_steps ): - new_slist= [] - inc = len(slist) / num_inference_steps - pos = 0 - for i in range(num_inference_steps): - new_slist.append(slist[ int(pos)]) - pos += inc - return new_slist + def convert_image(image): from PIL import ImageOps @@ -7871,7 +7876,7 @@ def generate_info_tab(): def get_sorted_dropdown(dropdown_types): - families_order = {"wan":0, "ltxv":1, "hunyuan":2, "flux": 3 } + families_order = {"wan":0, "ltxv":1, "hunyuan":2, "flux": 3, "unknown": 100 } dropdown_classes = [ families_order[get_model_family(type)] for type in dropdown_types] dropdown_names = [ get_model_name(type) for type in dropdown_types]