From 8d12cf08b6170b6217f1c2153dd121aaf561f744 Mon Sep 17 00:00:00 2001 From: deepbeepmeep Date: Tue, 29 Jul 2025 03:52:07 +0200 Subject: [PATCH] just another version --- README.md | 14 ++ defaults/t2v_2_2.json | 24 +++ defaults/vace_14B_cocktail_2_2.json | 25 +++ defaults/vace_14B_cocktail_2_2_light.json | 19 ++ flux/flux_main.py | 39 +++- flux/sampling.py | 33 ++-- wan/any2video.py | 44 ++++- wan/diffusion_forcing.py | 4 +- wgp.py | 228 +++++++++++++++------- 9 files changed, 323 insertions(+), 107 deletions(-) create mode 100644 defaults/t2v_2_2.json create mode 100644 defaults/vace_14B_cocktail_2_2.json create mode 100644 defaults/vace_14B_cocktail_2_2_light.json diff --git a/README.md b/README.md index 70ed77b..8675ccc 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,20 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates +### July 29 2025: WanGP v7.3 : Wan 2.2 Preview + +Wan 2.2 is here. The good news is that WanGP won't require a single byte of extra VRAM to run it and it will be as fast as Wan 2.1. The bad news is that you will need much more RAM if you want to fully leverage this new model, since it has twice as many parameters. + +So here is a preview version of Wan 2.2 that, for the moment, does not include the 5B model or Wan 2.2 image to video. + +However, as I felt bad delivering only half of the wares, I am giving you instead... Wan 2.2 Vace Cocktail! + +A very good surprise indeed: the loras and Vace mostly work with Wan 2.2! I have made a light version of the Cocktail that uses only half of the parameters of Wan 2.2, so it has exactly the same RAM requirements as Wan 2.1. Videos baked with only half of the model are not as good, but maybe they are still better than Wan 2.1, so tell me whether we should keep the light version. + +Multitalk should probably work too, but I have a life to attend to, so I will let you test it. + +Bonus zone: Flux multi image conditioning has been added, or maybe not if I broke everything while being distracted by Wan... + ### July 27 2025: WanGP v7.3 : Interlude While waiting for Wan 2.2, you will appreciate the model selection hierarchy which is very useful to collect even more models. You will also appreciate that WanGP remembers which model you used last in each model family.
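A quick note on how the preview works under the hood: Wan 2.2 14B ships as a high-noise / low-noise pair of transformers (the `URLs` / `URLs2` lists in the new model definitions below). During denoising, WanGP starts on the high-noise model and, once the current timestep drops below `switch_threshold`, swaps in the low-noise model and replaces `guidance_scale` with `guidance2_scale`. Here is a minimal sketch of that switch, where `denoise_step` is a hypothetical stand-in for the real transformer call performed in `wan/any2video.py`:

```python
# Illustrative sketch of the Wan 2.2 two-expert switch inside the denoising loop.
# `denoise_step` is a hypothetical stand-in for the actual model call.
def denoise(latents, timesteps, model_high, model_low,
            guidance_scale, guidance2_scale, switch_threshold):
    trans, switched = model_high, False
    for step_no, t in enumerate(timesteps):
        if not switched and t <= switch_threshold:
            guidance_scale = guidance2_scale      # second guidance value takes over
            if model_low is not None:             # single-model setups keep the same transformer
                trans = model_low
            switched = True
        latents = trans.denoise_step(latents, t, step_no, guidance_scale)
    return latents
```

With the default `switch_threshold` of 875 (timesteps run from 1000 down to 0), only the earliest, noisiest steps are handled by the high-noise expert before the low-noise expert takes over.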
diff --git a/defaults/t2v_2_2.json b/defaults/t2v_2_2.json new file mode 100644 index 0000000..48c2408 --- /dev/null +++ b/defaults/t2v_2_2.json @@ -0,0 +1,24 @@ +{ + "model": + { + "name": "Wan2.2 Text2video 14B", + "architecture" : "t2v", + "description": "Wan 2.2 Text 2 Video model", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_quanto_mfp16_int8.safetensors" + ], + "URLs2": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_quanto_mfp16_int8.safetensors" + ], + "group": "wan2_2" + }, + "switch_threshold" : 875, + "guidance_scale" : 4, + "guidance2_scale" : 3, + "flow_shift" : 12 + +} \ No newline at end of file diff --git a/defaults/vace_14B_cocktail_2_2.json b/defaults/vace_14B_cocktail_2_2.json new file mode 100644 index 0000000..84fc989 --- /dev/null +++ b/defaults/vace_14B_cocktail_2_2.json @@ -0,0 +1,25 @@ +{ + "model": { + "name": "Wan2.2 Vace Cocktail 14B", + "architecture": "vace_14B", + "modules": [ + "vace_14B" + ], + "description": "This model has been created on the fly using the Wan 2.2 text 2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. Copy the model def into the finetune folder to change the Cocktail composition.", + "URLs": "t2v_2_2", + "URLs2": "t2v_2_2", + "loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors" + ], + "loras_multipliers": [1, 0.2, 0.5, 0.5], + "group": "wan2_2" + }, + "num_inference_steps": 10, + "guidance_scale": 1, + "guidance2_scale": 2, + "flow_shift": 2, + "switch_threshold" : 875 +} \ No newline at end of file diff --git a/defaults/vace_14B_cocktail_2_2_light.json b/defaults/vace_14B_cocktail_2_2_light.json new file mode 100644 index 0000000..d6d0330 --- /dev/null +++ b/defaults/vace_14B_cocktail_2_2_light.json @@ -0,0 +1,19 @@ +{ + "model": { + "name": "Wan2.2 Vace Cocktail Light 14B", + "architecture": "vace_14B", + "modules": [ + "vace_14B" + ], + "description": "This model has been created on the fly using the Wan 2.2 text 2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. 
Only the high noise part of the v2.2 model is used to reduce RAM usage.", + "URLs": "t2v_2_2", + "URLs2": "t2v_2_2", + "loras": "vace_14B_cocktail_2_2", + "loras_multipliers": "vace_14B_cocktail_2_2", + "group": "wan2_2" + }, + "num_inference_steps": 10, + "guidance_scale": 1, + "guidance2_scale": 2, + "flow_shift": 2 +} \ No newline at end of file diff --git a/flux/flux_main.py b/flux/flux_main.py index cbff9b9..17b5405 100644 --- a/flux/flux_main.py +++ b/flux/flux_main.py @@ -17,6 +17,20 @@ from flux.util import ( save_image, ) +from PIL import Image + +def stitch_images(img1, img2): + # Resize img2 to match img1's height + width1, height1 = img1.size + width2, height2 = img2.size + new_width2 = int(width2 * height1 / height2) + img2_resized = img2.resize((new_width2, height1), Image.Resampling.LANCZOS) + + stitched = Image.new('RGB', (width1 + new_width2, height1)) + stitched.paste(img1, (0, 0)) + stitched.paste(img2_resized, (width1, 0)) + return stitched + class model_factory: def __init__( self, @@ -72,6 +86,7 @@ class model_factory: callback = None, loras_slists = None, batch_size = 1, + video_prompt_type = "", **bbargs ): @@ -79,19 +94,30 @@ class model_factory: return None device="cuda" - if input_ref_images != None and len(input_ref_images) > 0: - image_ref = input_ref_images[0] - w, h = image_ref.size - height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) + if "I" in video_prompt_type and input_ref_images != None and len(input_ref_images) > 0: + if "K" in video_prompt_type and False : + # image latents tiling method + w, h = input_ref_images[0].size + height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) + else: + # image stitching method + stitched = input_ref_images[0] + if "K" in video_prompt_type : + w, h = input_ref_images[0].size + height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) + + for new_img in input_ref_images[1:]: + stitched = stitch_images(stitched, new_img) + input_ref_images = [stitched] else: - image_ref = None + input_ref_images = None inp, height, width = prepare_kontext( t5=self.t5, clip=self.clip, prompt=input_prompt, ae=self.vae, - img_cond=image_ref, + img_cond_list=input_ref_images, target_width=width, target_height=height, bs=batch_size, @@ -99,7 +125,6 @@ class model_factory: device=device, ) - inp.pop("img_cond_orig") timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell")) def unpack_latent(x): return unpack(x.float(), height, width) diff --git a/flux/sampling.py b/flux/sampling.py index 97cdc5c..7f14b09 100644 --- a/flux/sampling.py +++ b/flux/sampling.py @@ -214,7 +214,7 @@ def prepare_kontext( clip: HFEmbedder, prompt: str | list[str], ae: AutoEncoder, - img_cond: str, + img_cond_list: list, seed: int, device: torch.device, target_width: int | None = None, @@ -225,7 +225,10 @@ def prepare_kontext( if bs == 1 and not isinstance(prompt, str): bs = len(prompt) - if img_cond != None: + img_cond_seq = None + img_cond_seq_ids = None + + for cond_no, img_cond in enumerate(img_cond_list or []): width, height = img_cond.size aspect_ratio = width / height @@ -239,20 +242,19 @@ def prepare_kontext( img_cond = np.array(img_cond) img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0 img_cond = rearrange(img_cond, "h w c -> 1 c h w") - img_cond_orig = img_cond.clone() - with torch.no_grad(): - img_cond = ae.encode(img_cond.to(device)) + img_cond_latents = ae.encode(img_cond.to(device)) - img_cond = img_cond.to(torch.bfloat16) - img_cond = 
rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + img_cond_latents = img_cond_latents.to(torch.bfloat16) + img_cond_latents = rearrange(img_cond_latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) if img_cond.shape[0] == 1 and bs > 1: - img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs) + img_cond_latents = repeat(img_cond_latents, "1 ... -> bs ...", bs=bs) + img_cond = None # image ids are the same as base image with the first dimension set to 1 # instead of 0 img_cond_ids = torch.zeros(height // 2, width // 2, 3) - img_cond_ids[..., 0] = 1 + img_cond_ids[..., 0] = cond_no + 1 img_cond_ids[..., 1] = img_cond_ids[..., 1] + torch.arange(height // 2)[:, None] img_cond_ids[..., 2] = img_cond_ids[..., 2] + torch.arange(width // 2)[None, :] img_cond_ids = repeat(img_cond_ids, "h w c -> b (h w) c", b=bs) @@ -262,10 +264,10 @@ def prepare_kontext( if target_height is None: target_height = 8 * height img_cond_ids = img_cond_ids.to(device) - else: - img_cond = None - img_cond_ids = None - img_cond_orig = None + if cond_no == 0: + img_cond_seq, img_cond_seq_ids = img_cond_latents, img_cond_ids + else: + img_cond_seq, img_cond_seq_ids = torch.cat([img_cond_seq, img_cond_latents], dim=1), torch.cat([img_cond_seq_ids, img_cond_ids], dim=1) img = get_noise( bs, @@ -277,9 +279,8 @@ def prepare_kontext( ) return_dict = prepare(t5, clip, img, prompt) - return_dict["img_cond_seq"] = img_cond - return_dict["img_cond_seq_ids"] = img_cond_ids - return_dict["img_cond_orig"] = img_cond_orig + return_dict["img_cond_seq"] = img_cond_seq + return_dict["img_cond_seq_ids"] = img_cond_seq_ids return return_dict, target_height, target_width diff --git a/wan/any2video.py b/wan/any2video.py index 951d274..20e3756 100644 --- a/wan/any2video.py +++ b/wan/any2video.py @@ -78,6 +78,8 @@ class WanAny2V: self.num_train_timesteps = config.num_train_timesteps self.param_dtype = config.param_dtype self.model_def = model_def + self.model2 = None + self.transformer_switch = model_def.get("URLs2", None) is not None self.text_encoder = T5EncoderModel( text_len=config.text_len, dtype=config.t5_dtype, @@ -101,24 +103,31 @@ class WanAny2V: vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype, device=self.device) - xmodel_filename = "c:/temp/wan2.1_text2video_1.3B_bf16.safetensors" # config_filename= "configs/t2v_1.3B.json" # import json # with open(config_filename, 'r', encoding='utf-8') as f: # config = json.load(f) # sd = safetensors2.torch_load_file(xmodel_filename) - # model_filename = "c:/temp/vace1_3B.safetensors" + # model_filename = "c:/temp/wan2.2t2v/high/diffusion_pytorch_model-00001-of-00006.safetensors" base_config_file = f"configs/{base_model_type}.json" forcedConfigPath = base_config_file if len(model_filename) > 1 else None # forcedConfigPath = base_config_file = f"configs/flf2v_720p.json" # model_filename[1] = xmodel_filename + model_filename2 = None + if self.transformer_switch: + model_filename2 = model_filename[1:] + model_filename = model_filename[:1] + model_filename[2:] self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + if model_filename2 is not None: + self.model2 = offload.fast_load_transformers_model(model_filename2, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , 
forcedConfigPath= forcedConfigPath) + # self.model = offload.load_model_data(self.model, xmodel_filename ) # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth") self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) offload.change_dtype(self.model, dtype, True) # offload.save_model(self.model, "wan2.1_text2video_1.3B_mbf16.safetensors", do_quantize= False, config_file_path=base_config_file, filter_sd=sd) - # offload.save_model(self.model, "wan2.1_text2video_14B_quanto_mfp16_int8.safetensors", do_quantize=True, config_file_path=base_config_file) + # offload.save_model(self.model, "wan2.2_text2video_14B_high_mbf16.safetensors", config_file_path=base_config_file) + # offload.save_model(self.model, "wan2.2_text2video_14B_high_quanto_mfp16_int8.safetensors", do_quantize=True, config_file_path=base_config_file) self.model.eval().requires_grad_(False) if save_quantized: from wgp import save_quantized_model @@ -136,7 +145,8 @@ class WanAny2V: seq_len=32760, keep_last=True) - self.adapt_vace_model() + self.adapt_vace_model(self.model) + if self.model2 is not None: self.adapt_vace_model(self.model2) self.num_timesteps = 1000 self.use_timestep_transform = True @@ -353,6 +363,8 @@ class WanAny2V: sample_solver='unipc', sampling_steps=50, guide_scale=5.0, + guide2_scale = 5.0, + switch_threshold = 0, n_prompt="", seed=-1, callback = None, @@ -384,6 +396,7 @@ class WanAny2V: color_correction_strength = 1, prefix_frames_count = 0, image_mode = 0, + **bbargs ): @@ -699,9 +712,18 @@ class WanAny2V: apg_norm_threshold = 55 text_momentumbuffer = MomentumBuffer(apg_momentum) audio_momentumbuffer = MomentumBuffer(apg_momentum) + + guidance_switch_done = False + # denoising + trans = self.model for i, t in enumerate(tqdm(timesteps)): - offload.set_step_no_for_lora(self.model, i) + if not guidance_switch_done and t <= switch_threshold: + guide_scale = guide2_scale + if self.model2 is not None: trans = self.model2 + guidance_switch_done = True + + offload.set_step_no_for_lora(trans, i) timestep = torch.stack([t]) kwargs.update({"t": timestep, "current_step": i}) kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None @@ -760,7 +782,7 @@ class WanAny2V: } if joint_pass and guide_scale > 1: - ret_values = self.model( **gen_args , **kwargs) + ret_values = trans( **gen_args , **kwargs) if self._interrupt: return None else: @@ -768,7 +790,7 @@ class WanAny2V: ret_values = [None] * size for x_id in range(size): sub_gen_args = {k : [v[x_id]] for k, v in gen_args.items() } - ret_values[x_id] = self.model( **sub_gen_args, x_id= x_id , **kwargs)[0] + ret_values[x_id] = trans( **sub_gen_args, x_id= x_id , **kwargs)[0] if self._interrupt: return None sub_gen_args = None @@ -870,8 +892,7 @@ class WanAny2V: return { "x" : videos, "latent_slice" : latent_slice } return videos - def adapt_vace_model(self): - model = self.model + def adapt_vace_model(self, model): modules_dict= { k: m for k, m in model.named_modules()} for model_layer, vace_layer in model.vace_layers_mapping.items(): module = modules_dict[f"vace_blocks.{vace_layer}"] @@ -880,4 +901,7 @@ class WanAny2V: delattr(model, "vace_blocks") def query_model_def(model_type, model_def): - return None \ No newline at end of file + if "URLs2" in model_def: + return { "no_steps_skipping":True} + else: + return None \ No newline at end of file diff --git a/wan/diffusion_forcing.py b/wan/diffusion_forcing.py index 8155023..d477402 100644 --- a/wan/diffusion_forcing.py +++ 
b/wan/diffusion_forcing.py @@ -211,7 +211,7 @@ class DTT2V: guide_scale: float = 5.0, seed: float = 0.0, overlap_noise: int = 0, - ar_step: int = 5, + model_mode: int = 5, causal_block_size: int = 5, causal_attention: bool = True, fps: int = 24, @@ -231,7 +231,7 @@ class DTT2V: if frame_num > 1: frame_num = max(17, frame_num) # must match causal_block_size for value of 5 frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 ) - + ar_step = model_mode if ar_step == 0: causal_block_size = 1 causal_attention = False diff --git a/wgp.py b/wgp.py index 1902359..712222a 100644 --- a/wgp.py +++ b/wgp.py @@ -52,7 +52,7 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.5.1" WanGP_version = "7.3" -settings_version = 2.22 +settings_version = 2.23 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -178,6 +178,7 @@ def process_prompt_and_add_tasks(state, model_choice): return get_queue_table(queue) model_def = get_model_def(model_type) image_outputs = model_def.get("image_outputs", False) + no_steps_skipping = model_def.get("no_steps_skipping", False) model_type = get_base_model_type(model_type) inputs["model_filename"] = model_filename @@ -278,8 +279,12 @@ def process_prompt_and_add_tasks(state, model_choice): skip_steps_cache_type= inputs["skip_steps_cache_type"] MMAudio_setting = inputs["MMAudio_setting"] image_mode = inputs["image_mode"] + switch_threshold = inputs["switch_threshold"] - + if no_steps_skipping: skip_steps_cache_type = "" + if switch_threshold is not None and switch_threshold != 0 and len(skip_steps_cache_type) > 0: + gr.Info("Steps skipping is not yet supported if Switch Threshold is not null") + return if not model_def.get("lock_inference_steps", False) and model_type in ["ltxv_13B"] and num_inference_steps < 20: gr.Info("The minimum number of steps should be 20") return @@ -495,7 +500,8 @@ def process_prompt_and_add_tasks(state, model_choice): "denoising_strength": denoising_strength, "image_prompt_type": image_prompt_type, "video_prompt_type": video_prompt_type, - "audio_prompt_type": audio_prompt_type, + "audio_prompt_type": audio_prompt_type, + "skip_steps_cache_type": skip_steps_cache_type } if inputs["multi_prompts_gen_type"] == 0: @@ -1691,7 +1697,7 @@ for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion print(f"Removing old version of model '{path}'. 
A new version of this model will be downloaded next time you use it.") os.remove( os.path.join("ckpts" , path)) -families_infos = {"wan":(0, "Wan2.1"), "ltxv":(1, "LTX Video"), "hunyuan":(2, "Hunyuan Video"), "flux":(3, "Flux 1"), "unknown": (100, "Unknown") } +families_infos = {"wan":(0, "Wan2.1"), "wan2_2":(1, "Wan2.2"), "ltxv":(10, "LTX Video"), "hunyuan":(20, "Hunyuan Video"), "flux":(30, "Flux 1"), "unknown": (100, "Unknown") } models_def = {} @@ -1764,15 +1770,22 @@ def get_model_type(model_filename): return None # raise Exception("Unknown model:" + model_filename) -def get_model_family(model_type): - model_type = get_base_model_type(model_type) - if model_type == None: +def get_model_family(model_type, for_ui = False): + base_model_type = get_base_model_type(model_type) + if base_model_type is None: return "unknown" - if "hunyuan" in model_type : + + if for_ui : + model_def = get_model_def(model_type) + model_family = model_def.get("group", None) + if model_family is not None and model_family in families_infos: + return model_family + + if "hunyuan" in base_model_type : return "hunyuan" - elif "ltxv" in model_type: + elif "ltxv" in base_model_type: return "ltxv" - elif "flux" in model_type: + elif "flux" in base_model_type: return "flux" else: return "wan" @@ -1855,17 +1868,19 @@ def get_model_recursive_prop(model_type, prop = "URLs", return_list = True, sta raise Exception(f"Unknown model type '{model_type}'") -def get_model_filename(model_type, quantization ="int8", dtype_policy = "", is_module = False, stack=[]): +def get_model_filename(model_type, quantization ="int8", dtype_policy = "", is_module = False, submodel_no = 1, stack=[]): if is_module: choices = modules_files.get(model_type, None) if choices == None: raise Exception(f"Invalid Module Id '{model_type}'") else: + key_name = "URLs" if submodel_no <= 1 else f"URLs{submodel_no}" + model_def = models_def.get(model_type, None) if model_def == None: return "" - URLs = model_def["URLs"] + URLs = model_def[key_name] if isinstance(URLs, str): - if len(stack) > 10: raise Exception(f"Circular Reference in Model URLs dependencies: {stack}") - return get_model_filename(URLs, quantization=quantization, dtype_policy=dtype_policy, stack = stack + [URLs]) + if len(stack) > 10: raise Exception(f"Circular Reference in Model {key_name} dependencies: {stack}") + return get_model_filename(URLs, quantization=quantization, dtype_policy=dtype_policy, submodel_no = submodel_no, stack = stack + [URLs]) else: choices = [ ("ckpts/" + os.path.basename(path) if path.startswith("http") else path) for path in URLs ] if len(quantization) == 0: @@ -1953,6 +1968,8 @@ def fix_settings(model_type, ui_defaults): if model_type in ["hunyuan"]: video_prompt_type = video_prompt_type.replace("I", "") + if model_type in ["flux"] and video_settings_version < 2.23: + video_prompt_type = video_prompt_type.replace("K", "").replace("I", "KI") remove_background_images_ref = ui_defaults.get("remove_background_images_ref", 1) if video_settings_version < 2.22: @@ -2037,7 +2054,7 @@ def get_default_settings(model_type): }) if model_def.get("reference_image", False): ui_defaults.update({ - "video_prompt_type": "I", + "video_prompt_type": "KI", }) elif base_model_type in ["sky_df_1.3B", "sky_df_14B"]: ui_defaults.update({ @@ -2124,9 +2141,9 @@ def get_model_query_handler(model_type): model_family= get_model_family(base_model_type) if model_family == "wan": if base_model_type in ("sky_df_1.3B", "sky_df_14B"): - from wan.any2video import query_model_def - else: from 
wan.diffusion_forcing import query_model_def + else: + from wan.any2video import query_model_def elif model_family == "hunyuan": from hyvideo.hunyuan import query_model_def elif model_family == "ltxv": @@ -2340,7 +2357,7 @@ def download_mmaudio(): } process_files_def(**enhancer_def) -def download_models(model_filename, model_type): +def download_models(model_filename, model_type, submodel_no = 1): def computeList(filename): if filename == None: return [] @@ -2397,23 +2414,27 @@ def download_models(model_filename, model_type): model_family = get_model_family(model_type) model_def = get_model_def(model_type) - if model_def != None and not model_type in modules_files: + + key_name = "URLs" if submodel_no <= 1 else f"URLs{submodel_no}" + if not model_type in modules_files: if not os.path.isfile(model_filename ): - URLs = get_model_recursive_prop(model_type, "URLs", return_list= False) - if not isinstance(URLs, str): # dont download anything right now if a base type is referenced as the download will occur just after - use_url = model_filename - for url in URLs: - if os.path.basename(model_filename) in url: - use_url = url - break - if not url.startswith("http"): - raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it. Please add an URL in the model definition file.") - try: - download_file(use_url, model_filename) - except Exception as e: - if os.path.isfile(model_filename): os.remove(model_filename) - raise Exception(f"URL '{use_url}' is invalid for Model '{model_filename}' : {str(e)}'") - model_filename = None + URLs = get_model_recursive_prop(model_type, key_name, return_list= False) + if isinstance(URLs, str): + raise Exception("Missing model " + URLs) + use_url = model_filename + for url in URLs: + if os.path.basename(model_filename) in url: + use_url = url + break + if not url.startswith("http"): + raise Exception(f"Model '{model_filename}' in field '{key_name}' was not found locally and no URL was provided to download it. 
Please add an URL in the model definition file.") + try: + download_file(use_url, model_filename) + except Exception as e: + if os.path.isfile(model_filename): os.remove(model_filename) + raise Exception(f"{key_name} '{use_url}' is invalid for Model '{model_filename}' : {str(e)}'") + + model_filename = None preload_URLs = get_model_recursive_prop(model_type, "preload_URLs", return_list= True) for url in preload_URLs: @@ -2609,6 +2630,9 @@ def load_wan_model(model_filename, model_type, base_model_type, model_def, quant ) pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model } + if wan_model.model2 is not None: + pipe["transformer2"] = wan_model.model2 + if hasattr(wan_model, "clip"): pipe["text_encoder_2"] = wan_model.clip.model return wan_model, pipe @@ -2689,9 +2713,16 @@ def load_hunyuan_model(model_filename, model_type = None, base_model_type = Non return hunyuan_model, pipe -def get_transformer_model(model): +def get_transformer_model(model, submodel_no = 1): + if submodel_no > 1: + model_key = f"model{submodel_no}" + if not hasattr(model, model_key): return None + if hasattr(model, "model"): - return model.model + if submodel_no > 1: + return getattr(model, f"model{submodel_no}") + else: + return model.model elif hasattr(model, "transformer"): return model.transformer else: @@ -2705,6 +2736,10 @@ def load_models(model_type): preload =int(args.preload) save_quantized = args.save_quantized and model_def != None model_filename = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy) + if "URLs2" in model_def: + model_filename2 = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy, submodel_no=2) # !!!! 
+ else: + model_filename2 = None modules = get_model_recursive_prop(model_type, "modules", return_list= True) if save_quantized and "quanto" in model_filename: save_quantized = False @@ -2726,7 +2761,9 @@ def load_models(model_type): preload = server_config.get("preload_in_VRAM", 0) model_file_list = [model_filename] model_type_list = [model_type] - new_transformer_filename = model_file_list[-1] + if model_filename2 != None: + model_file_list += [model_filename2] + model_type_list += [model_type] for module_type in modules: model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype, is_module= True)) model_type_list.append(module_type) @@ -2750,14 +2787,22 @@ def load_models(model_type): elif model_family == "hunyuan": wan_model, pipe = load_hunyuan_model(model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) else: - raise Exception(f"Model '{new_transformer_filename}' not supported.") - wan_model._model_file_name = new_transformer_filename + raise Exception(f"Model '{model_filename}' not supported.") kwargs = { "extraModelsToQuantize": None } + loras_transformer = ["transformer"] if profile in (2, 4, 5): - kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100 if preload == 0 else preload, "*" : max(1000 if profile==5 else 3000 , preload) } + budgets = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100 if preload == 0 else preload, "*" : max(1000 if profile==5 else 3000 , preload) } + if "transformer2" in pipe: + budgets["transformer2"] = 100 if preload == 0 else preload + loras_transformer += ["transformer2"] + kwargs["budgets"] = budgets elif profile == 3: kwargs["budgets"] = { "*" : "70%" } - + + if "transformer2" in pipe and profile in [2,4]: + kwargs["pinnedMemory"] = ["transformer", "transformer2"] + + global prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer if server_config.get("enhancer_enabled", 0) == 1: from transformers import ( AutoModelForCausalLM, AutoProcessor, AutoTokenizer, LlamaForCausalLM ) @@ -2777,21 +2822,21 @@ def load_models(model_type): prompt_enhancer_llm_tokenizer = None - offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs) + offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = loras_transformer, coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs) if len(args.gpu) > 0: torch.set_default_device(args.gpu) transformer_type = model_type - return wan_model, offloadobj, pipe["transformer"] + return wan_model, offloadobj if not "P" in preload_model_policy: wan_model, offloadobj, transformer = None, None, None reload_needed = True else: - wan_model, offloadobj, transformer = load_models(transformer_type) + wan_model, offloadobj = load_models(transformer_type) if check_loras: + transformer = get_transformer_model(wan_model) setup_loras(transformer_type, transformer, get_lora_dir(transformer_type), "", None) exit() - del transformer gen_in_progress = False @@ -2929,6 +2974,7 @@ def 
apply_changes( state, model_filename = get_model_filename(transformer_type, transformer_quantization, transformer_dtype_policy) state["model_filename"] = model_filename if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant", "notification_sound_enabled", "notification_sound_volume", "mmaudio_enabled", "max_frames_multiplier", "display_stats"] for change in changes ): + model_family = gr.Dropdown() model_choice = gr.Dropdown() else: reload_needed = True @@ -3303,7 +3349,7 @@ def select_video(state, input_file_list, event_data: gr.EventData): values += ["unipc" if len(video_sample_solver) ==0 else video_sample_solver] labels += ["Sampler Solver"] values += [video_resolution, video_length_summary, video_seed, video_guidance_scale, video_flow_shift, video_num_inference_steps] - labels += [ "Resolution", video_length_label, "Seed", video_guidance_label, "Flow Shift", "Num Inference steps"] + labels += [ "Resolution", video_length_label, "Seed", video_guidance_label, "Shift Scale", "Num Inference steps"] video_negative_prompt = configs.get("negative_prompt", "") if len(video_negative_prompt) > 0: values += [video_negative_prompt] @@ -3973,6 +4019,8 @@ def generate_video( force_fps, num_inference_steps, guidance_scale, + guidance2_scale, + switch_threshold, audio_guidance_scale, flow_shift, sample_solver, @@ -4094,7 +4142,7 @@ def generate_video( offloadobj = None gc.collect() send_cmd("status", f"Loading model {get_model_name(model_type)}...") - wan_model, offloadobj, trans = load_models(model_type) + wan_model, offloadobj = load_models(model_type) send_cmd("status", "Model loaded") reload_needed= False @@ -4120,6 +4168,7 @@ def generate_video( VAE_tile_size = wan_model.vae.get_VAE_tile_size(vae_config, device_mem_capacity, server_config.get("vae_precision", "16") == "32") trans = get_transformer_model(wan_model) + trans2 = get_transformer_model(wan_model, 2) audio_sampling_rate = 16000 base_model_type = get_base_model_type(model_type) @@ -4179,11 +4228,15 @@ def generate_video( loras_selected = transformer_loras_filenames + loras_selected loras_list_mult_choices_nums = transformer_loras_multipliers + loras_list_mult_choices_nums loras_slists = transformer_loras_multipliers + loras_slists - offload.load_loras_into_model(trans, loras_selected, loras_list_mult_choices_nums, activate_all_loras=True, preprocess_sd=get_loras_preprocessor(trans, base_model_type), pinnedLora=pinnedLora, split_linear_modules_map = split_linear_modules_map) - errors = trans._loras_errors - if len(errors) > 0: - error_files = [msg for _ , msg in errors] - raise gr.Error("Error while loading Loras: " + ", ".join(error_files)) + trans_list = [trans] + if trans2 is not None: trans_list += [trans2] + for trans_item in trans_list: + offload.load_loras_into_model(trans_item, loras_selected, loras_list_mult_choices_nums, activate_all_loras=True, preprocess_sd=get_loras_preprocessor(trans, base_model_type), pinnedLora=pinnedLora, split_linear_modules_map = split_linear_modules_map) + errors = trans._loras_errors + if len(errors) > 0: + error_files = [msg for _ , msg in errors] + raise gr.Error("Error while loading Loras: " + ", ".join(error_files)) + trans_item = trans_list = None seed = None if seed == -1 else seed # negative_prompt = "" # not applicable in the inference original_filename = model_filename @@ -4243,7 +4296,7 @@ def generate_video( any_background_ref = False outpainting_dims = None if video_guide_outpainting== None or 
len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")] - if image_refs != None and len(image_refs) > 0 and (hunyuan_custom or phantom or hunyuan_avatar or vace): + if image_refs != None and len(image_refs) > 0 and (hunyuan_custom or phantom or hunyuan_avatar or vace or flux): frames_positions_list = [ int(pos)-1 for pos in frames_positions.split(" ")] if frames_positions !=None and len(frames_positions)> 0 else [] frames_positions_list = frames_positions_list[:len(image_refs)] nb_frames_positions = len(frames_positions_list) @@ -4265,11 +4318,13 @@ def generate_video( send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg") from wan.utils.utils import resize_and_remove_background - image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (vace or hunyuan_avatar) ) # no fit for vace ref images as it is done later + image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (vace or hunyuan_avatar or flux) ) # no fit for vace ref images as it is done later update_task_thumbnails(task, locals()) send_cmd("output") joint_pass = boost ==1 #and profile != 1 and profile != 3 trans.enable_cache = None if len(skip_steps_cache_type) == 0 else skip_steps_cache_type + if trans2 is not None: + trans2.enable_cache = None if trans.enable_cache != None: trans.cache_multiplier = skip_steps_multiplier @@ -4660,6 +4715,8 @@ def generate_video( sample_solver=sample_solver, sampling_steps=num_inference_steps, guide_scale=guidance_scale, + guide2_scale = guidance2_scale, + switch_threshold = switch_threshold, embedded_guidance_scale=embedded_guidance_scale, n_prompt=negative_prompt, seed=seed, @@ -4680,7 +4737,7 @@ def generate_video( audio_scale= audio_scale, audio_context_lens= audio_context_lens, context_scale = context_scale, - ar_step = model_mode, #5 + model_mode = model_mode, causal_block_size = 5, causal_attention = True, fps = fps, @@ -4698,6 +4755,7 @@ def generate_video( NAG_alpha = NAG_alpha, speakers_bboxes =speakers_bboxes, image_mode = image_mode, + video_prompt_type= video_prompt_type, offloadobj = offloadobj, ) except Exception as e: @@ -4915,6 +4973,9 @@ def generate_video( seed = set_seed(-1) clear_status(state) offload.unload_loras_from_model(trans) + if not trans2 is None: + offload.unload_loras_from_model(trans2) + if len(control_audio_tracks) > 0: cleanup_temp_audio_files(control_audio_tracks) @@ -5793,6 +5854,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None ltxv = base_model_type in ["ltxv_13B"] recammaster = base_model_type in ["recam_1.3B"] phantom = base_model_type in ["phantom_1.3B", "phantom_14B"] + flux = base_model_type in ["flux"] hunyuan_video_custom = base_model_type in ["hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit"] model_family = get_model_family(base_model_type) if target == "settings": @@ -5818,7 +5880,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None if not server_config.get("enhancer_enabled", 0) == 1: pop += ["prompt_enhancer"] - if not recammaster and not diffusion_forcing: + if not recammaster and 
not diffusion_forcing and not flux: pop += ["model_mode"] if not vace and not phantom and not hunyuan_video_custom: @@ -5837,14 +5899,14 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None if not base_model_type in ["fantasy", "multitalk", "vace_multitalk_14B"]: pop += ["audio_guidance_scale", "speakers_locations"] - if not model_family in ["hunyuan", "flux"]: + if not model_family in ["hunyuan", "flux"] or model_def.get("no_guidance", False): pop += ["embedded_guidance_scale"] if not model_family in ["hunyuan", "wan"]: pop += ["skip_steps_cache_type", "skip_steps_multiplier", "skip_steps_start_step_perc"] - if model_def.get("no_guidance", False) or ltxv: - pop += ["guidance_scale", "audio_guidance_scale", "embedded_guidance_scale"] + if model_def.get("no_guidance", False) or ltxv or model_family in ["hunyuan", "flux"] : + pop += ["guidance_scale", "guidance2_scale", "switch_threshold", "audio_guidance_scale"] if model_def.get("image_outputs", False) or ltxv: pop += ["flow_shift"] @@ -6171,6 +6233,8 @@ def save_inputs( force_fps, num_inference_steps, guidance_scale, + guidance2_scale, + switch_threshold, audio_guidance_scale, flow_shift, sample_solver, @@ -6334,7 +6398,7 @@ def change_model(state, model_choice): model_filename = get_model_filename(model_choice, transformer_quantization, transformer_dtype_policy) state["model_filename"] = model_filename last_model_per_family = state["last_model_per_family"] - last_model_per_family[get_model_family(model_choice)] = model_choice + last_model_per_family[get_model_family(model_choice, for_ui= True)] = model_choice server_config["last_model_per_family"] = last_model_per_family server_config["last_model_type"] = model_choice @@ -6366,7 +6430,7 @@ def preload_model_when_switching(state): gc.collect() model_filename = get_model_name(model_type) yield f"Loading model {model_filename}..." 
- wan_model, offloadobj, _ = load_models(model_type) + wan_model, offloadobj = load_models(model_type) yield f"Model loaded" reload_needed= False return @@ -6760,6 +6824,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non diffusion_forcing = "diffusion_forcing" in model_filename ltxv = "ltxv" in model_filename lock_inference_steps = model_def.get("lock_inference_steps", False) + model_reference_image = model_def.get("reference_image", False) + no_steps_skipping = model_def.get("no_steps_skipping", False) recammaster = base_model_type in ["recam_1.3B"] vace = test_vace_module(base_model_type) phantom = base_model_type in ["phantom_1.3B", "phantom_14B"] @@ -6889,7 +6955,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non keep_frames_video_source = gr.Text(visible=False) any_video_source = False - with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or ltxv or flux and model_def.get("reference_image", False) ) as video_prompt_column: + with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or ltxv or flux and model_reference_image) as video_prompt_column: video_prompt_type_value= ui_defaults.get("video_prompt_type","") video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False) any_control_video = True @@ -6947,7 +7013,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ], value=filter_letters(video_prompt_type_value, "PDSLCMUV"), label="Image to Image" if image_outputs else "Video to Video", scale = 3, visible= True, show_label= True, - ) + ) else: any_control_video = False any_control_image = False @@ -7012,6 +7078,20 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non visible = True, label="Reference Images", scale = 2 ) + + + elif flux and model_reference_image: + video_prompt_type_image_refs = gr.Dropdown( + choices=[ + ("None", ""), + ("Inject only People / Objects", "I"), + ("Inject Main Subject / Landscape and then People / Objects", "KI"), + ], + value=filter_letters(video_prompt_type_value, "KFI"), + visible = True, + show_label=False, + label="Reference Images Combination Method", scale = 2 + ) else: video_prompt_type_image_refs = gr.Dropdown( choices=[ ("Start / Ref Image", "I")], @@ -7052,11 +7132,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames separated by Spaces (1=first, no position for Objects / People)" ) remove_background_images_ref = gr.Dropdown( choices=[ - ("Keep Backgrounds behind People / Objects", 0), - ("Remove Backgrounds behind People / Objects", 1), + ("Keep Backgrounds behind all Reference Images", 0), + ("Remove Backgrounds only behind People / Objects", 1), ], value=ui_defaults.get("remove_background_images_ref",1), - label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not hunyuan_video_avatar and not flux + label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not hunyuan_video_avatar ) any_audio_voices_support = any_audio_track(base_model_type) @@ -7174,6 +7254,10 @@ def generate_video_tab(update_form = False, state_dict = None, 
ui_defaults = Non audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale", 5 if fantasy else 4), step=0.5, label="Audio Guidance", visible=(fantasy or multitalk) and not no_guidance) embedded_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("embedded_guidance", 2.5 if flux else 6.0), step=0.5, label="Embedded Guidance Scale", visible=(hunyuan_t2v or hunyuan_i2v or flux) and not no_guidance) flow_shift = gr.Slider(1.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale", visible = not image_outputs) + with gr.Row(visible = not ltxv and not (no_guidance and image_outputs)) as guidance_row2: + guidance2_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance2_scale",5), step=0.5, label="Guidance2 (CFG)", visible=not (hunyuan_t2v or hunyuan_i2v or flux) and not no_guidance) + switch_threshold = gr.Slider(0, 1000, value=ui_defaults.get("switch_threshold",0), step=1, label="Guidance / Model Switch Threshold", visible=not (hunyuan_t2v or hunyuan_i2v or flux) and not no_guidance) + with gr.Row(visible = get_model_family(model_type) == "wan" and not diffusion_forcing ) as sample_solver_row: sample_solver = gr.Dropdown( value=ui_defaults.get("sample_solver",""), choices=[ @@ -7214,7 +7298,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non label="Activated Loras" ) loras_multipliers = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by Space chars or CR, lines that start with # are ignored", value=launch_multis_str) - with gr.Tab("Steps Skipping", visible = not (ltxv or image_outputs)) as speed_tab: + with gr.Tab("Steps Skipping", visible = not (ltxv or image_outputs) and not no_steps_skipping) as speed_tab: with gr.Column(): gr.Markdown("Tea Cache and Mag Cache accelerate the Video Generation by skipping intelligently some steps, the more steps are skipped the lower the quality of the video.") gr.Markdown("Steps Skipping consumes also VRAM. 
It is recommended not to skip at least the first 10% steps.") @@ -7577,7 +7661,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right, video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row, video_buttons_row, image_buttons_row, video_postprocessing_tab, video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn, - NAG_col, speakers_locations_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs, min_frames_if_references_col] # presets_column, + NAG_col, speakers_locations_row, guidance_row, guidance_row2, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs, min_frames_if_references_col] # presets_column, if update_form: locals_dict = locals() gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs @@ -8265,7 +8349,7 @@ def compact_name(family_name, model_name): return model_name def get_sorted_dropdown(dropdown_types, current_model_family): - models_families = [get_model_family(type) for type in dropdown_types] + models_families = [get_model_family(type, for_ui= True) for type in dropdown_types] families = {} for family in models_families: if family not in families: families[family] = 1 @@ -8285,7 +8369,7 @@ def generate_dropdown_model_list(current_model_type): dropdown_types= transformer_types if len(transformer_types) > 0 else displayed_model_types if current_model_type not in dropdown_types: dropdown_types.append(current_model_type) - current_model_family = get_model_family(current_model_type) + current_model_family = get_model_family(current_model_type, for_ui= True) sorted_familes, dropdown_choices = get_sorted_dropdown(dropdown_types, current_model_family) dropdown_families = gr.Dropdown( @@ -8308,12 +8392,12 @@ def generate_dropdown_model_list(current_model_type): def change_model_family(state, current_model_family): dropdown_types= transformer_types if len(transformer_types) > 0 else displayed_model_types current_family_name = families_infos[current_model_family][1] - models_families = [get_model_family(type) for type in dropdown_types] + models_families = [get_model_family(type, for_ui= True) for type in dropdown_types] dropdown_choices = [ (compact_name(current_family_name, get_model_name(model_type)), model_type) for model_type, family in zip(dropdown_types, models_families) if family == current_model_family ] dropdown_choices = sorted(dropdown_choices, key=lambda c: c[0]) last_model_per_family = state.get("last_model_per_family", {}) model_type = last_model_per_family.get(current_model_family, "") - if len(model_type) == "" or model_type not in displayed_model_types: model_type = dropdown_choices[0][1] + if len(model_type) == 0 or model_type not in [choice[1] for choice in dropdown_choices] : model_type = dropdown_choices[0][1] return gr.Dropdown(choices= dropdown_choices, value = model_type ) def set_new_tab(tab_state, new_tab_no):
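A closing note on the `URLs` / `URLs2` convention that the new model definitions rely on: a string value is an indirection to another model type, a list holds the actual checkpoint choices, and the `URLs2` set feeds the second transformer that is loaded when a model switch is configured. Below is a minimal sketch of that recursive resolution, using an illustrative `MODEL_DEFS` dictionary and a `resolve_checkpoint` helper rather than WanGP's actual `get_model_filename`:

```python
# Illustrative resolution of the URLs / URLs2 convention used by the new
# Wan 2.2 model definitions (mirrors the logic of get_model_filename in wgp.py).
MODEL_DEFS = {
    "t2v_2_2": {
        "URLs":  ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_mbf16.safetensors"],
        "URLs2": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_mbf16.safetensors"],
    },
    "vace_14B_cocktail_2_2": {
        "URLs": "t2v_2_2",    # a string is a reference to another model type
        "URLs2": "t2v_2_2",
    },
}

def resolve_checkpoint(model_type, submodel_no=1, stack=()):
    key = "URLs" if submodel_no <= 1 else f"URLs{submodel_no}"
    urls = MODEL_DEFS[model_type][key]
    if isinstance(urls, str):                          # follow the indirection
        if len(stack) > 10:
            raise RuntimeError(f"Circular reference in model definitions: {stack}")
        return resolve_checkpoint(urls, submodel_no, stack + (urls,))
    return "ckpts/" + urls[0].rsplit("/", 1)[-1]       # keep the first (non-quantized) choice

print(resolve_checkpoint("vace_14B_cocktail_2_2", submodel_no=2))
# ckpts/wan2.2_text2video_14B_low_mbf16.safetensors
```

This indirection is why the Cocktail definitions can stay small: they only carry the loras and the generation defaults, while both transformer checkpoints are inherited from `t2v_2_2`.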