Mirror of https://github.com/Wan-Video/Wan2.1.git
added better missing model error handling

commit bda410f367 (parent afbf94e44f)
wan/utils/utils.py

@@ -34,6 +34,15 @@ def seed_everything(seed: int):
     if torch.backends.mps.is_available():
         torch.mps.manual_seed(seed)
 
+def expand_slist(slist, num_inference_steps):
+    new_slist = []
+    inc = len(slist) / num_inference_steps
+    pos = 0
+    for i in range(num_inference_steps):
+        new_slist.append(slist[int(pos)])
+        pos += inc
+    return new_slist
+
 def update_loras_slists(trans, slists, num_inference_steps):
     from mmgp import offload
     slists = [expand_slist(slist, num_inference_steps) if isinstance(slist, list) else slist for slist in slists]
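For reference, expand_slist (now shared via wan.utils.utils, and removed from wgp.py further down in this commit) stretches a short per-phase list of values across the full number of inference steps, walking the source list at a fractional stride of len(slist) / num_inference_steps and truncating the position to an index. A minimal usage sketch; the sample values are illustrative, not from the commit:

    from wan.utils.utils import expand_slist

    # 2 entries over 4 steps: stride = 2/4 = 0.5, so int(pos) repeats each index twice
    print(expand_slist([0.5, 1.0], 4))   # [0.5, 0.5, 1.0, 1.0]
    print(expand_slist([1, 2, 3], 6))    # [1, 1, 2, 2, 3, 3]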
wgp.py (31 changed lines)
@@ -16,6 +16,7 @@ import json
 import wan
 from wan.utils import notification_sound
 from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS
+from wan.utils.utils import expand_slist, update_loras_slists
 from wan.utils.utils import cache_video, convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video
 from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files

@@ -160,6 +161,7 @@ def process_prompt_and_add_tasks(state, model_choice):
     model_filename = state["model_filename"]
     model_type = state["model_type"]
     inputs = get_model_settings(state, model_type)
 
     if model_choice != model_type or inputs == None:
         raise gr.Error("Webform can not be used as the App has been restarted since the form was displayed. Please refresh the page")
@@ -1740,6 +1742,8 @@ def get_model_type(model_filename):
 
 def get_model_family(model_type):
     model_type = get_base_model_type(model_type)
+    if model_type == None:
+        return "unknown"
     if "hunyuan" in model_type:
         return "hunyuan"
    elif "ltxv" in model_type:
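With this guard, a model type whose base type cannot be resolved maps to an explicit "unknown" family instead of raising a TypeError when "in" is applied to None. A standalone sketch of the pattern, with simplified names that are not the project's actual code:

    # Guard-first family lookup: unresolved types get a sentinel family.
    def family_of(base_type):
        if base_type is None:
            return "unknown"
        if "hunyuan" in base_type:
            return "hunyuan"
        return "wan"

    print(family_of(None))           # "unknown" (previously: TypeError)
    print(family_of("hunyuan_t2v"))  # "hunyuan"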
@@ -1799,7 +1803,8 @@ def get_computed_fps(force_fps, base_model_type, video_guide, video_source):
 
 def get_model_name(model_type, description_container = [""]):
     model_def = get_model_def(model_type)
-    if model_def == None: raise Exception(f"Unknown model {model_type}")
+    if model_def == None:
+        return f"Unknown model {model_type}"
     model_name = model_def["name"]
     description = model_def["description"]
     description_container[0] = description
@@ -1832,7 +1837,7 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = "", is_m
         if choices == None: raise Exception(f"Invalid Module Id '{model_type}'")
     else:
         model_def = models_def.get(model_type, None)
-        if model_def == None: raise Exception(f"Unknown model type {model_type}")
+        if model_def == None: return None
         URLs = model_def["URLs"]
         if isinstance(URLs, str):
             if len(stack) > 10: raise Exception(f"Circular Reference in Model URLs dependencies: {stack}")
@@ -2120,11 +2125,18 @@ model_types = models_def.keys()
 displayed_model_types = []
 for model_type in model_types:
     model_def = get_model_def(model_type)
-    if not (model_def != None and model_def.get("visible", True) == False):
+    if not model_def is None and model_def.get("visible", True):
         displayed_model_types.append(model_type)
 
 
 transformer_types = server_config.get("transformer_types", [])
+new_transformer_types = []
+for model_type in transformer_types:
+    if get_model_def(model_type) == None:
+        print(f"Model '{model_type}' is missing. Either install it in the finetune folder or remove this model in wgp_config.json")
+    else:
+        new_transformer_types.append(model_type)
+transformer_types = new_transformer_types
 transformer_type = server_config.get("last_model_type", None)
 advanced = server_config.get("last_advanced_choice", False)
 if args.advanced: advanced = True
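The new transformer_types filter matches the commit message most directly: entries from wgp_config.json that no longer resolve to an installed model definition are reported and dropped up front, instead of crashing later lookups. The same defensive filter in isolation, with illustrative names and data:

    # Keep only configured entries that resolve against the known definitions.
    configured = ["wan_t2v", "removed_finetune"]
    known_defs = {"wan_t2v": {"name": "Wan t2v"}}

    kept = []
    for t in configured:
        if known_defs.get(t) is None:
            print(f"Model '{t}' is missing.")  # warn, do not raise
        else:
            kept.append(t)
    configured = kept  # ["wan_t2v"]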
@@ -2729,7 +2741,7 @@ def generate_header(model_type, compile, attention_mode):
 
     description_container = [""]
     get_model_name(model_type, description_container)
-    model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy)
+    model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) or ""
     description = description_container[0]
     header = "<DIV style='height:40px'>" + description + "</DIV>"
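Together with get_model_filename now returning None for an unknown model type rather than raising, callers coerce the missing value to a safe default so the header still renders. The pattern in isolation, with illustrative names:

    # Return None on a missing key, then let the caller pick a fallback.
    def filename_for(defs, model_type):
        model_def = defs.get(model_type)
        if model_def is None:
            return None
        return model_def["URLs"]

    filename = filename_for({}, "missing_model") or ""   # "" instead of an exception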
@@ -3250,14 +3262,7 @@ def select_video(state, input_file_list, event_data: gr.EventData):
     html = get_default_video_info()
     visible = len(file_list) > 0
     return choice, html, gr.update(visible=visible and not is_image), gr.update(visible=visible and is_image), gr.update(visible=visible and not is_image)
-def expand_slist(slist, num_inference_steps):
-    new_slist = []
-    inc = len(slist) / num_inference_steps
-    pos = 0
-    for i in range(num_inference_steps):
-        new_slist.append(slist[int(pos)])
-        pos += inc
-    return new_slist
 
 def convert_image(image):
 
     from PIL import ImageOps
@@ -7871,7 +7876,7 @@ def generate_info_tab():
 
 
 def get_sorted_dropdown(dropdown_types):
-    families_order = {"wan": 0, "ltxv": 1, "hunyuan": 2, "flux": 3}
+    families_order = {"wan": 0, "ltxv": 1, "hunyuan": 2, "flux": 3, "unknown": 100}
     dropdown_classes = [families_order[get_model_family(type)] for type in dropdown_types]
     dropdown_names = [get_model_name(type) for type in dropdown_types]
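The new "unknown": 100 entry makes the fallback family sortable: models whose family cannot be resolved now land at the end of the dropdown instead of raising a KeyError during sorting. A quick check of the ordering, with an illustrative list:

    families_order = {"wan": 0, "ltxv": 1, "hunyuan": 2, "flux": 3, "unknown": 100}
    families = ["unknown", "wan", "flux"]
    print(sorted(families, key=families_order.get))   # ['wan', 'flux', 'unknown']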