added better missing model error handling

This commit is contained in:
deepbeepmeep 2025-07-16 00:47:23 +02:00
parent afbf94e44f
commit bda410f367
2 changed files with 27 additions and 13 deletions

View File

@ -34,6 +34,15 @@ def seed_everything(seed: int):
if torch.backends.mps.is_available():
torch.mps.manual_seed(seed)
def expand_slist(slist, num_inference_steps ):
new_slist= []
inc = len(slist) / num_inference_steps
pos = 0
for i in range(num_inference_steps):
new_slist.append(slist[ int(pos)])
pos += inc
return new_slist
def update_loras_slists(trans, slists, num_inference_steps ):
from mmgp import offload
slists = [ expand_slist(slist, num_inference_steps ) if isinstance(slist, list) else slist for slist in slists ]

31
wgp.py
View File

@ -16,6 +16,7 @@ import json
import wan
from wan.utils import notification_sound
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS
from wan.utils.utils import expand_slist, update_loras_slists
from wan.utils.utils import cache_video, convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video
from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files
@ -160,6 +161,7 @@ def process_prompt_and_add_tasks(state, model_choice):
model_filename = state["model_filename"]
model_type = state["model_type"]
inputs = get_model_settings(state, model_type)
if model_choice != model_type or inputs ==None:
raise gr.Error("Webform can not be used as the App has been restarted since the form was displayed. Please refresh the page")
@ -1740,6 +1742,8 @@ def get_model_type(model_filename):
def get_model_family(model_type):
model_type = get_base_model_type(model_type)
if model_type == None:
return "unknown"
if "hunyuan" in model_type :
return "hunyuan"
elif "ltxv" in model_type:
@ -1799,7 +1803,8 @@ def get_computed_fps(force_fps, base_model_type , video_guide, video_source ):
def get_model_name(model_type, description_container = [""]):
model_def = get_model_def(model_type)
if model_def == None: raise Exception(f"Unknown model {model_type}")
if model_def == None:
return f"Unknown model {model_type}"
model_name = model_def["name"]
description = model_def["description"]
description_container[0] = description
@ -1832,7 +1837,7 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = "", is_m
if choices == None: raise Exception(f"Invalid Module Id '{model_type}'")
else:
model_def = models_def.get(model_type, None)
if model_def == None: raise Exception(f"Unknown model type {model_type}")
if model_def == None: return None
URLs = model_def["URLs"]
if isinstance(URLs, str):
if len(stack) > 10: raise Exception(f"Circular Reference in Model URLs dependencies: {stack}")
@ -2120,11 +2125,18 @@ model_types = models_def.keys()
displayed_model_types= []
for model_type in model_types:
model_def = get_model_def(model_type)
if not (model_def != None and model_def.get("visible", True) == False):
if not model_def is None and model_def.get("visible", True):
displayed_model_types.append(model_type)
transformer_types = server_config.get("transformer_types", [])
new_transformer_types = []
for model_type in transformer_types:
if get_model_def(model_type) == None:
print(f"Model '{model_type}' is missing. Either install it in the finetune folder or remove this model in wgp_config.json")
else:
new_transformer_types.append(model_type)
transformer_types = new_transformer_types
transformer_type = server_config.get("last_model_type", None)
advanced = server_config.get("last_advanced_choice", False)
if args.advanced: advanced = True
@ -2729,7 +2741,7 @@ def generate_header(model_type, compile, attention_mode):
description_container = [""]
get_model_name(model_type, description_container)
model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy)
model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) or ""
description = description_container[0]
header = "<DIV style='height:40px'>" + description + "</DIV>"
@ -3250,14 +3262,7 @@ def select_video(state, input_file_list, event_data: gr.EventData):
html = get_default_video_info()
visible= len(file_list) > 0
return choice, html, gr.update(visible=visible and not is_image) , gr.update(visible=visible and is_image), gr.update(visible=visible and not is_image)
def expand_slist(slist, num_inference_steps ):
new_slist= []
inc = len(slist) / num_inference_steps
pos = 0
for i in range(num_inference_steps):
new_slist.append(slist[ int(pos)])
pos += inc
return new_slist
def convert_image(image):
from PIL import ImageOps
@ -7871,7 +7876,7 @@ def generate_info_tab():
def get_sorted_dropdown(dropdown_types):
families_order = {"wan":0, "ltxv":1, "hunyuan":2, "flux": 3 }
families_order = {"wan":0, "ltxv":1, "hunyuan":2, "flux": 3, "unknown": 100 }
dropdown_classes = [ families_order[get_model_family(type)] for type in dropdown_types]
dropdown_names = [ get_model_name(type) for type in dropdown_types]