diff --git a/README.md b/README.md
index f09a33b..1fd4a91 100644
--- a/README.md
+++ b/README.md
@@ -20,14 +20,19 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
 **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
 
 ## 🔥 Latest Updates
-### June 12 2025: WanGP v5.6
-👋 *Finetune models*: You find the 20 models supported by WanGP not sufficient ? Too impatient to wait for the next release to get the support for a newly released model ? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add yourself the support for this model in WanGP by just creating Finetune model definition. You can then store this model in the cloud (for instance in Huggingface) and the very light finetune definition file can be easily shared with other users. WanGP will download automatically the finetuned model for them.
+### June 12 2025: WanGP v6.0
+👋 *Finetune models*: Do you find the 20 models supported by WanGP not sufficient? Too impatient to wait for the next release to get support for a newly released model? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add support for it yourself by simply creating a finetune model definition. You can then store the model in the cloud (for instance on Huggingface), and the very light finetune definition file can easily be shared with other users. WanGP will automatically download the finetuned model for them.
 
-To celebrate this new feature, I have provided 4 finetuned model definitions:
+To celebrate the new finetune support, here are a few finetune gifts (directly accessible from the model selection menu):
 - *Fast Hunyuan Video*: generates t2v videos in only 6 steps
 - *Hunyuan Video AccVideo*: generates t2v videos in only 5 steps
 - *Wan FusioniX*: a combo of AccVideo / CausVid and other models that can generate high quality Wan videos in only 8 steps
-- *Vace FusioniX*: the ultimate Vace model, as it is a combo of Vace / AccVideo / CausVid ans other models and can generate high quality Wan Controled videos in only 10 steps
+
+One more thing...
+
+The new finetune system can be used to combine complementary models: what happens when you combine FusioniX Text2Video and Vace ControlNet?
+
+You get **Vace FusioniX**: the ultimate Vace model, fast (10 steps, no need for guidance) and with much better video quality than the original, slower model (even though the original is the best ControlNet out there). Here goes one more finetune...
 
 Check the *Finetune Guide* to create finetune model definitions and share them on the WanGP discord server.
 
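As an illustration of the feature announced above, here is a minimal sketch of what such a finetune definition file could look like, modelled on the definition files added in this commit (the name, description and download URL below are hypothetical placeholders; see docs/FINETUNES.md further down in this diff for the full list of supported fields):

```json
{
    "model": {
        "name": "My Wan t2v Finetune 14B",
        "architecture": "t2v",
        "description": "Hypothetical finetune of the Wan 2.1 text 2 video architecture.",
        "URLs": [
            "https://huggingface.co/your-account/your-repo/resolve/main/my_finetune_fp16.safetensors"
        ],
        "auto_quantize": true
    }
}
```

Saved as *finetunes/t2v_my_finetune.json*, its id becomes *t2v_my_finetune* and the model appears in the selection menu after a restart.
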
diff --git a/configs/fantasy.json b/configs/fantasy.json
new file mode 100644
index 0000000..2c9b9b9
--- /dev/null
+++ b/configs/fantasy.json
@@ -0,0 +1,15 @@
+{
+  "_class_name": "WanModel",
+  "_diffusers_version": "0.30.0",
+  "dim": 5120,
+  "eps": 1e-06,
+  "ffn_dim": 13824,
+  "freq_dim": 256,
+  "in_dim": 36,
+  "model_type": "i2v",
+  "num_heads": 40,
+  "num_layers": 40,
+  "out_dim": 16,
+  "text_len": 512,
+  "fantasytalking_dim": 2048
+}
diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md
index dc15df2..ce532f9 100644
--- a/docs/CHANGELOG.md
+++ b/docs/CHANGELOG.md
@@ -1,6 +1,22 @@
 # Changelog
 
 ## 🔥 Latest News
+### June 12 2025: WanGP v5.6
+👋 *Finetune models*: Do you find the 20 models supported by WanGP not sufficient? Too impatient to wait for the next release to get support for a newly released model? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add support for it yourself by simply creating a finetune model definition. You can then store the model in the cloud (for instance on Huggingface), and the very light finetune definition file can easily be shared with other users. WanGP will automatically download the finetuned model for them.
+
+To celebrate the new finetune support, here are a few finetune gifts (directly accessible from the model selection menu):
+- *Fast Hunyuan Video*: generates t2v videos in only 6 steps
+- *Hunyuan Video AccVideo*: generates t2v videos in only 5 steps
+- *Wan FusioniX*: a combo of AccVideo / CausVid and other models that can generate high quality Wan videos in only 8 steps
+
+One more thing...
+
+The new finetune system can be used to combine complementary models: what happens when you combine FusioniX Text2Video and Vace ControlNet?
+
+You get **Vace FusioniX**: the ultimate Vace model, fast (10 steps, no need for guidance) and with much better video quality than the original, slower model (even though the original is the best ControlNet out there). Here goes one more finetune...
+
+Check the *Finetune Guide* to create finetune model definitions and share them on the WanGP discord server.
+
 ### June 11 2025: WanGP v5.5
 👋 *Hunyuan Video Custom Audio*: it is similar to Hunyuan Video Avatar except there isn't any lower limit on the number of frames and you can use your reference images in a different context than the image itself\
 *Hunyuan Video Custom Edit*: Hunyuan Video Controlnet, use it to do inpainting and replace a person in a video while still keeping their poses. Similar to Vace but less restricted than the Wan models in terms of content...
diff --git a/docs/FINETUNES.md b/docs/FINETUNES.md
index ef2eeb5..1fbfdbc 100644
--- a/docs/FINETUNES.md
+++ b/docs/FINETUNES.md
@@ -25,8 +25,8 @@ Here are steps:
 3) Save this file in the subfolder **finetunes**. The name used for the file will be used as its id. It is a good practice to prefix the name of this file with the base model. For instance, for a finetune named **Fast** based on the Hunyuan Text 2 Video model: *hunyuan_t2v_fast.json*. In this example the Id is *hunyuan_t2v_fast*.
 4) Restart WanGP
 
-## Base Models Ids
-A finetune is derived from a base model and will inherit all the user interface and corresponding model capabilities, here are the Ids:
+## Architecture Models Ids
+A finetune is derived from a base model and will inherit all the user interface and corresponding model capabilities, here are the Architecture Ids:
 - *t2v*: Wan 2.1 Video text 2
 - *i2v*: Wan 2.1 Video image 2 480p
 - *i2v_720p*: Wan 2.1 Video image 2 720p
@@ -36,9 +36,10 @@ A finetune is derived from a base model and will inherit all the user interface
 ## The Model Subtree
 - *name* : name of the finetune used to select
-- *base* : Id of the base model of the finetune (see previous section)
+- *architecture* : architecture Id of the base model of the finetune (see previous section)
 - *description*: description of the finetune that will appear at the top
 - *URLs*: URLs of all the finetune versions (quantized / non quantized). WanGP will pick the version that is the closest to the user preferences. You will need to follow a naming convention to help WanGP identify the content of each version (see next section). Right now WanGP supports only 8 bit quantized models that have been quantized using **quanto**. WanGP offers a command switch to easily build such a quantized model (see below). *URLs* can also contain paths to local files to allow testing.
+- *modules*: this is a list of modules to be combined with the models referenced by the URLs. A module is a model extension that is merged with a model to expand its capabilities. So far the only module supported is Vace 14B (its id is *vace_14B*). For instance the full Vace model is the fusion of a Wan text 2 video model and the Vace module (see the sketch after the example below).
 - *preload_URLs* : URLs of files to download no matter what (used to load quantization maps for instance)
 - *auto_quantize*: if set to True and no quantized model URL is provided, WanGP will perform on the fly quantization if the user expects a quantized model
 
 Example of **model** subtree
 
     "model": {
         "name": "Wan text2video FusioniX 14B",
-        "base" : "t2v",
+        "architecture" : "t2v",
         "description": "A powerful merged text-to-video model based on the original WAN 2.1 T2V model, enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
         "URLs": [
             "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors",
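To make the new *modules* field concrete, here is a sketch of a definition that combines a base architecture with the Vace module. It mirrors the structure of the *finetunes/vace_14B_fusionix.json* file further down in this diff; the quantized URL variants and the rest of the file (which becomes the finetune's default settings) are omitted for brevity:

```json
{
    "model": {
        "name": "Vace FusioniX 14B",
        "architecture": "vace_14B",
        "modules": ["vace_14B"],
        "description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors"
        ],
        "auto_quantize": true
    }
}
```

At load time WanGP downloads the module checkpoint in addition to the weights listed in *URLs* and merges the two, as can be seen in the updated `load_models` function in wgp.py later in this diff.
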
diff --git a/finetunes/hunyuan_t2v_accvideo.json b/finetunes/hunyuan_t2v_accvideo.json
index 73fe516..2164744 100644
--- a/finetunes/hunyuan_t2v_accvideo.json
+++ b/finetunes/hunyuan_t2v_accvideo.json
@@ -1,7 +1,7 @@
 {
     "model": {
         "name": "Hunyuan AccVideo 720p 13B",
-        "base": "hunyuan",
+        "architecture": "hunyuan",
         "description": " AccVideo is a novel efficient distillation method to accelerate video diffusion models with synthetic datset. Our method is 8.5x faster than HunyuanVideo.",
         "URLs": [
             "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/accvideo_hunyuan_video_720_quanto_int8.safetensors"
diff --git a/finetunes/hunyuan_t2v_fast.json b/finetunes/hunyuan_t2v_fast.json
index d0c3594..a7721fd 100644
--- a/finetunes/hunyuan_t2v_fast.json
+++ b/finetunes/hunyuan_t2v_fast.json
@@ -1,7 +1,7 @@
 {
     "model": {
         "name": "Hunyuan Fast Video 720p 13B",
-        "base": "hunyuan",
+        "architecture": "hunyuan",
         "description": "Fast Hunyuan is an accelerated HunyuanVideo model. 
It can sample high quality videos with 6 diffusion steps.", "URLs": [ "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/fast_hunyuan_video_720_quanto_int8.safetensors" diff --git a/finetunes/t2v_fusionix.json b/finetunes/t2v_fusionix.json index 5603a09..e159398 100644 --- a/finetunes/t2v_fusionix.json +++ b/finetunes/t2v_fusionix.json @@ -2,7 +2,7 @@ "model": { "name": "Wan text2video FusioniX 14B", - "base" : "t2v", + "architecture" : "t2v", "description": "A powerful merged text-to-video model based on the original WAN 2.1 T2V model, enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.", "URLs": [ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors", diff --git a/finetunes/vace_14B_fusionix.json b/finetunes/vace_14B_fusionix.json index 7fff0ba..3d1ee49 100644 --- a/finetunes/vace_14B_fusionix.json +++ b/finetunes/vace_14B_fusionix.json @@ -2,12 +2,13 @@ "model": { "name": "Vace FusioniX 14B", - "base" : "vace_14B", + "architecture" : "vace_14B", + "modules" : ["vace_14B"], "description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.", "URLs": [ - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Vace_FusioniX_14B_mfp16.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Vace_FusioniX_14B_quanto_mfp16_int8.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Vace_FusioniX_14B_quanto_mbf16_int8.safetensors" + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors" ], "auto_quantize": true }, diff --git a/hyvideo/hunyuan.py b/hyvideo/hunyuan.py index b307851..ef35894 100644 --- a/hyvideo/hunyuan.py +++ b/hyvideo/hunyuan.py @@ -387,7 +387,7 @@ class Inference(object): # model = Inference.load_state_dict(args, model, model_filepath) # model_filepath ="c:/temp/hc/mp_rank_00_model_states_video.pt" - offload.load_model_data(model, model_filepath, quantizeTransformer = quantizeTransformer and not save_quantized, pinToMemory = pinToMemory, partialPinning = partialPinning) + offload.load_model_data(model, model_filepath, do_quantize= quantizeTransformer and not save_quantized, pinToMemory = pinToMemory, partialPinning = partialPinning) pass # offload.save_model(model, "hunyuan_video_avatar_edit_720_bf16.safetensors") # offload.save_model(model, "hunyuan_video_avatar_edit_720_quanto_bf16_int8.safetensors", do_quantize= True) diff --git a/hyvideo/modules/models.py b/hyvideo/modules/models.py index 65075cc..626748b 100644 --- a/hyvideo/modules/models.py +++ b/hyvideo/modules/models.py @@ -493,8 +493,8 @@ class MMSingleStreamBlock(nn.Module): return img, txt class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): - def preprocess_loras(self, model_filename, sd): - if not "i2v" in model_filename: + def preprocess_loras(self, model_type, sd): + if model_type != "i2v" : return sd new_sd = {} for k,v in sd.items(): diff --git a/requirements.txt b/requirements.txt index c6f5045..ced4da7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,7 +17,7 @@ gradio==5.23.0 numpy>=1.23.5,<2 einops moviepy==1.0.3 -mmgp==3.4.8 +mmgp==3.4.9 peft==0.14.0 mutagen 
pydantic==2.10.6 diff --git a/wan/diffusion_forcing.py b/wan/diffusion_forcing.py index 28fc93b..7c93216 100644 --- a/wan/diffusion_forcing.py +++ b/wan/diffusion_forcing.py @@ -64,7 +64,8 @@ class DTT2V: # model_filename = "model.safetensors" # model_filename = "c:/temp/diffusion_pytorch_model-00001-of-00006.safetensors" base_config_file = f"configs/{base_model_type}.json" - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) # , forcedConfigPath="c:/temp/config _df720.json") + forcedConfigPath = base_config_file if len(model_filename) > 1 else None + self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False , forcedConfigPath=forcedConfigPath) # offload.load_model_data(self.model, "recam.ckpt") # self.model.cpu() # dtype = torch.float16 diff --git a/wan/image2video.py b/wan/image2video.py index 46f05a8..1e93e82 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -104,7 +104,8 @@ class WanI2V: # model_filename = "c:/temp/i2v480p/diffusion_pytorch_model-00001-of-00007.safetensors" # dtype = torch.float16 base_config_file = f"configs/{base_model_type}.json" - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath= base_config_file) #, forcedConfigPath= "c:/temp/i2v720p/config.json") + forcedConfigPath = base_config_file if len(model_filename) > 1 else None + self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath= base_config_file, forcedConfigPath= forcedConfigPath) self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) offload.change_dtype(self.model, dtype, True) # offload.save_model(self.model, "wan2.1_image2video_720p_14B_mbf16.safetensors", config_file_path="c:/temp/i2v720p/config.json") diff --git a/wan/modules/model.py b/wan/modules/model.py index 39f5796..a749a33 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -589,7 +589,7 @@ class MLPProj(torch.nn.Module): class WanModel(ModelMixin, ConfigMixin): - def preprocess_loras(self, model_filename, sd): + def preprocess_loras(self, model_type, sd): first = next(iter(sd), None) if first == None: @@ -634,7 +634,7 @@ class WanModel(ModelMixin, ConfigMixin): new_sd.update(new_alphas) sd = new_sd from wgp import test_class_i2v - if not test_class_i2v(model_filename): + if not test_class_i2v(model_type): new_sd = {} # convert loras for i2v to t2v for k,v in sd.items(): diff --git a/wan/text2video.py b/wan/text2video.py index edd2ed5..c3b2651 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -84,10 +84,11 @@ class WanT2V: from mmgp import offload # model_filename = "c:/temp/vace1.3/diffusion_pytorch_model.safetensors" # model_filename = "Vacefusionix_quanto_fp16_int8.safetensors" - # model_filename = "c:/temp/phantom/Phantom_Wan_14B-00001-of-00006.safetensors" - # config_filename= "c:/temp/phantom/config.json" + # model_filename = "c:/temp/t2v/diffusion_pytorch_model-00001-of-00006.safetensors" + # config_filename= "c:/temp/t2v/t2v.json" base_config_file = f"configs/{base_model_type}.json" - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= 
False, defaultConfigPath=base_config_file)#, forcedConfigPath= config_filename) + forcedConfigPath = base_config_file if len(model_filename) > 1 else None + self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth") # self.model.to(torch.bfloat16) # self.model.cpu() @@ -95,8 +96,8 @@ class WanT2V: # dtype = torch.bfloat16 # offload.load_model_data(self.model, "ckpts/Wan14BT2VFusioniX_fp16.safetensors") offload.change_dtype(self.model, dtype, True) - # offload.save_model(self.model, "wanfusionix_fp16.safetensors", config_file_path=base_config_file) - # offload.save_model(self.model, "wanfusionix_quanto_fp16_int8.safetensors", do_quantize=True, config_file_path=base_config_file) + # offload.save_model(self.model, "wan2.1_text2video_14B_mbf16.safetensors", config_file_path=base_config_file) + # offload.save_model(self.model, "wan2.1_text2video_14B_quanto_mfp16_int8.safetensors", do_quantize=True, config_file_path=base_config_file) self.model.eval().requires_grad_(False) if save_quantized: from wan.utils.utils import save_quantized_model diff --git a/wgp.py b/wgp.py index d4b00fb..f16f958 100644 --- a/wgp.py +++ b/wgp.py @@ -44,7 +44,7 @@ AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 target_mmgp_version = "3.4.9" -WanGP_version = "5.6" +WanGP_version = "6.0" prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None from importlib.metadata import version @@ -136,13 +136,13 @@ def process_prompt_and_add_tasks(state, model_choice): state["validate_success"] = 0 model_filename = state["model_filename"] - - model_type = get_model_type(model_filename) + model_type = state["model_type"] inputs = state.get(model_type, None) if model_choice != model_type or inputs ==None: raise gr.Error("Webform can not be used as the App has been restarted since the form was displayed. 
Please refresh the page") inputs["state"] = state + inputs["model_type"] = model_type inputs.pop("lset_name") if inputs == None: gr.Warning("Internal state error: Could not retrieve inputs for the model.") @@ -161,7 +161,7 @@ def process_prompt_and_add_tasks(state, model_choice): return inputs["model_filename"] = model_filename - model_filename = get_base_model_filename(model_filename) + model_filename = get_model_filename(get_base_model_type(model_type)) prompts = prompt.replace("\r", "").split("\n") prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")] if len(prompts) ==0: @@ -173,7 +173,7 @@ def process_prompt_and_add_tasks(state, model_choice): resolution = inputs["resolution"] width, height = resolution.split("x") width, height = int(width), int(height) - # if test_class_i2v(model_filename): + # if test_class_i2v(model_type): # if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480: # gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P") # return @@ -373,7 +373,7 @@ def process_prompt_and_add_tasks(state, model_choice): } inputs.update(extra_inputs) add_video_task(**inputs) - elif test_class_i2v(model_filename) : + elif test_class_i2v(model_type) : image_prompt_type = inputs["image_prompt_type"] image_start = inputs["image_start"] @@ -1435,10 +1435,9 @@ def _parse_args(): return args -def get_lora_dir(model_filename): - model_filename = get_base_model_filename(model_filename) - model_family = get_model_family(model_filename) - i2v = test_class_i2v(model_filename) +def get_lora_dir(model_type): + model_family = get_model_family(model_type) + i2v = test_class_i2v(model_type) if model_family == "wan": lora_dir =args.lora_dir if i2v and len(lora_dir)==0: @@ -1447,7 +1446,7 @@ def get_lora_dir(model_filename): return lora_dir root_lora_dir = "loras_i2v" if i2v else "loras" - if "1.3B" in model_filename : + if "1.3B" in model_type : lora_dir_1_3B = os.path.join(root_lora_dir, "1.3B") if os.path.isdir(lora_dir_1_3B ): return lora_dir_1_3B @@ -1544,19 +1543,20 @@ else: # Deprecated models for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors","sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors", "sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors", "sky_reels2_diffusion_forcing_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_480p_14B_bf16.safetensors", "wan2.1_image2video_480p_14B_quanto_int8.safetensors", -"wan2.1_image2video_720p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_720p_14B_bf16.safetensors" +"wan2.1_image2video_720p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_720p_14B_bf16.safetensors", +"wan2.1_text2video_1.3B_bf16.safetensors", "wan2.1_text2video_14B_bf16.safetensors", "wan2.1_text2video_14B_quanto_int8.safetensors", +"wan2.1_Vace_1.3B_mbf16.safetensors", "wan2.1_Vace_14B_mbf16.safetensors", "wan2.1_Vace_14B_quanto_mbf16_int8.safetensors" ]: if Path(os.path.join("ckpts" , path)).is_file(): print(f"Removing old version of model '{path}'. 
A new version of this model will be downloaded next time you use it.") os.remove( os.path.join("ckpts" , path)) finetunes = {} -finetunes_filemap = {} -wan_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", +wan_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_mbf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_mbf16.safetensors","ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors", - "ckpts/wan2.1_Vace_1.3B_mbf16.safetensors", "ckpts/wan2.1_Vace_14B_mbf16.safetensors", "ckpts/wan2.1_Vace_14B_quanto_mbf16_int8.safetensors", + "ckpts/wan2.1_Vace_1.3B_mbf16.safetensors", "ckpts/wan2.1_Vace_14B_module_mbf16.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_moviigen1.1_14B_mbf16.safetensors", "ckpts/wan2.1_moviigen1.1_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2_1_phantom_1.3B_mbf16.safetensors", "ckpts/wan2.1_phantom_14B_mbf16.safetensors", "ckpts/wan2.1_phantom_14B_quanto_mbf16_int8.safetensors", ] @@ -1574,13 +1574,17 @@ hunyuan_choices= ["ckpts/hunyuan_video_720_bf16.safetensors", "ckpts/hunyuan_vid ] transformer_choices = wan_choices_t2v + wan_choices_i2v + ltxv_choices + hunyuan_choices -def get_dependent_models(model_filename, quantization, dtype_policy ): - if "fantasy" in model_filename: - return [get_model_filename("i2v_720p", quantization, dtype_policy)] - elif "ltxv_0.9.7_13B_distilled_lora128" in model_filename: - return [get_model_filename("ltxv_13B", quantization, dtype_policy)] +def get_dependent_models(model_type, quantization, dtype_policy ): + if model_type == "fantasy": + dependent_model_type = "i2v_720p" + elif model_type == "ltxv_13B_distilled": + dependent_model_type = "ltxv_13B" + elif model_type == "vace_14B": + dependent_model_type = "t2v" else: - return [] + return [], [] + return [get_model_filename(dependent_model_type, quantization, dtype_policy)], [] + model_types = [ "t2v_1.3B", "t2v", "i2v", "i2v_720p", "flf2v_720p", "vace_1.3B","vace_14B","moviigen", "phantom_1.3B", "phantom_14B", "fantasy", "fun_inp_1.3B", "fun_inp", "recam_1.3B", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "ltxv_13B", "ltxv_13B_distilled", "hunyuan", "hunyuan_i2v", "hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_avatar"] model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B", "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "vace_14B" : "Vace_14B","recam_1.3B": "recammaster_1.3B", @@ -1590,51 +1594,47 @@ model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", " "hunyuan" : "hunyuan_video_720", "hunyuan_i2v" : "hunyuan_video_i2v_720", "hunyuan_custom" : "hunyuan_video_custom_720", "hunyuan_custom_audio" : "hunyuan_video_custom_audio", "hunyuan_custom_edit" : "hunyuan_video_custom_edit", "hunyuan_avatar" : "hunyuan_video_avatar" } -def get_model_finetune_def(model_filename): - model_type = finetunes_filemap.get(model_filename, None ) - if model_type == None: - return None +def 
get_model_finetune_def(model_type): return finetunes.get(model_type, None ) -def get_base_model_filename(model_filename): - finetune_def = get_model_finetune_def(model_filename) +def get_base_model_type(model_type): + finetune_def = get_model_finetune_def(model_type) if finetune_def == None: - return model_filename + return model_type else: - return finetune_def["base_filename"] + return finetune_def["architecture"] + def get_model_type(model_filename): - model_type = finetunes_filemap.get(model_filename, None ) - if model_type != None: - return model_type for model_type, signature in model_signatures.items(): if signature in model_filename: - return model_type - raise Exception("Unknown model:" + model_filename) + return model_type + return None + # raise Exception("Unknown model:" + model_filename) -def get_model_family(model_filename): - finetune_def = get_model_finetune_def(model_filename) - if finetune_def != None: - return finetune_def["model_family"] - if "wan" in model_filename or "sky" in model_filename: - return "wan" - elif "ltxv" in model_filename: - return "ltxv" - elif "hunyuan" in model_filename: +def get_model_family(model_type): + model_type = get_base_model_type(model_type) + if "hunyuan" in model_type : return "hunyuan" + elif "ltxv" in model_type: + return "ltxv" else: - raise Exception(f"Unknown model family for model'{model_filename}'") - -def test_class_i2v(model_filename): - model_filename = get_base_model_filename(model_filename) - return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename or "fantasy" in model_filename or "hunyuan_video_i2v" in model_filename + return "wan" -def get_model_name(model_filename, description_container = [""]): - finetune_def = get_model_finetune_def(model_filename) + +def test_class_i2v(model_type): + model_type = get_base_model_type(model_type) + return model_type in ["i2v", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "hunyuan_i2v" ] + +def get_model_name(model_type, description_container = [""]): + finetune_def = get_model_finetune_def(model_type) if finetune_def != None: model_name = finetune_def["name"] description = finetune_def["description"] - elif "Fun" in model_filename: + description_container[0] = description + return model_name + model_filename = get_model_filename(model_type) + if "Fun" in model_filename: model_name = "Fun InP image2video" model_name += " 14B" if "14B" in model_filename else " 1.3B" description = "The Fun model is an alternative image 2 video that supports out the box End Image fixing (contrary to the original Wan image 2 video model). The 1.3B adds also image 2 to video capability to the 1.3B model." 
@@ -1725,7 +1725,7 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = ""): if len(quantization) == 0: quantization = "bf16" - model_family = get_model_family(choices[0]) + model_family = get_model_family(model_type) dtype = get_transformer_dtype(model_family, dtype_policy) if len(choices) <= 1: raw_filename = choices[0] @@ -1743,7 +1743,7 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = ""): else: raw_filename = choices[0] - if dtype == torch.float16 and not "fp16" in raw_filename and model_family == "wan" and finetune_def != None : + if dtype == torch.float16 and not "fp16" in raw_filename and model_family == "wan" and finetune_def == None : if "quanto_int8" in raw_filename: raw_filename = raw_filename.replace("quanto_int8", "quanto_fp16_int8") elif "quanto_bf16_int8" in raw_filename: @@ -1767,19 +1767,19 @@ def get_transformer_dtype(model_family, transformer_dtype_policy): else: return torch.bfloat16 -def get_settings_file_name(model_filename): - return os.path.join(args.settings, get_model_type(model_filename) + "_settings.json") +def get_settings_file_name(model_type): + return os.path.join(args.settings, model_type + "_settings.json") -def get_default_settings(filename): +def get_default_settings(model_type): def get_default_prompt(i2v): if i2v: return "Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field." else: return "A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect." 
- i2v = test_class_i2v(filename) - defaults_filename = get_settings_file_name(filename) + i2v = test_class_i2v(model_type) + defaults_filename = get_settings_file_name(model_type) if not Path(defaults_filename).is_file(): - finetune_def = get_model_finetune_def(filename) + finetune_def = get_model_finetune_def(model_type) if finetune_def != None: ui_defaults = finetune_def["settings"] if len(ui_defaults.get("prompt","")) == 0: @@ -1787,7 +1787,7 @@ def get_default_settings(filename): else: ui_defaults = { "prompt": get_default_prompt(i2v), - "resolution": "1280x720" if "720p" in filename else "832x480", + "resolution": "1280x720" if "720" in model_type else "832x480", "video_length": 81, "num_inference_steps": 30, "seed": -1, @@ -1796,7 +1796,7 @@ def get_default_settings(filename): "guidance_scale": 5.0, "embedded_guidance_scale" : 6.0, "audio_guidance_scale": 5.0, - "flow_shift": get_default_flow(filename, i2v), + "flow_shift": 7.0 if not "720" in model_type and i2v else 5.0, "negative_prompt": "", "activated_loras": [], "loras_multipliers": "", @@ -1809,25 +1809,25 @@ def get_default_settings(filename): "slg_end_perc": 90 } - if get_model_type(filename) in ("hunyuan","hunyuan_i2v"): + if model_type in ("hunyuan","hunyuan_i2v"): ui_defaults.update({ "guidance_scale": 7.0, }) - if get_model_type(filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"): + if model_type in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"): ui_defaults.update({ "guidance_scale": 6.0, "flow_shift": 8, "sliding_window_discard_last_frames" : 0, - "resolution": "1280x720" if "720p" in filename else "960x544", - "sliding_window_size" : 121 if "720p" in filename else 97, + "resolution": "1280x720" if "720" in model_type else "960x544", + "sliding_window_size" : 121 if "720" in model_type else 97, "RIFLEx_setting": 2, "guidance_scale": 6, "flow_shift": 8, }) - if get_model_type(filename) in ("phantom_1.3B", "phantom_14B"): + if model_type in ("phantom_1.3B", "phantom_14B"): ui_defaults.update({ "guidance_scale": 7.5, "flow_shift": 5, @@ -1835,20 +1835,20 @@ def get_default_settings(filename): # "resolution": "1280x720" }) - elif get_model_type(filename) in ("hunyuan_custom"): + elif model_type in ("hunyuan_custom"): ui_defaults.update({ "guidance_scale": 7.5, "flow_shift": 13, "resolution": "1280x720", }) - elif get_model_type(filename) in ("hunyuan_custom_edit"): + elif model_type in ("hunyuan_custom_edit"): ui_defaults.update({ "guidance_scale": 7.5, "flow_shift": 13, "video_prompt_type": "MV", "sliding_window_size": 129, }) - elif get_model_type(filename) in ("hunyuan_avatar"): + elif model_type in ("hunyuan_avatar"): ui_defaults.update({ "guidance_scale": 7.5, "flow_shift": 5, @@ -1856,7 +1856,7 @@ def get_default_settings(filename): "video_length": 129, "video_prompt_type": "I", }) - elif get_model_type(filename) in ("vace_14B"): + elif model_type in ("vace_14B"): ui_defaults.update({ "sliding_window_discard_last_frames": 0, }) @@ -1898,13 +1898,7 @@ for file_path in finetunes_paths: finetune_def = json_def["model"] del json_def["model"] finetune_def["settings"] = json_def - base_filename = get_model_filename(finetune_def["base"]) - finetune_def["base_filename"] = base_filename - finetune_def["model_family"] = get_model_family(base_filename) finetunes[finetune_id] = finetune_def - for url in finetune_def["URLs"]: - url = url.split("/")[-1] - finetunes_filemap["ckpts/" + url] = finetune_id model_types += finetunes.keys() @@ -1919,7 +1913,6 @@ if args.fp16: transformer_dtype_policy = "fp16" if args.bf16: 
transformer_dtype_policy = "bf16" -transformer_filename = get_model_filename(transformer_type, transformer_quantization, transformer_dtype_policy) text_encoder_quantization =server_config.get("text_encoder_quantization", "int8") attention_mode = server_config["attention_mode"] if len(args.attention)> 0: @@ -1943,29 +1936,28 @@ preload_model_policy = server_config.get("preload_model_policy", []) if args.t2v_14B or args.t2v: - transformer_filename = get_model_filename("t2v", transformer_quantization, transformer_dtype_policy) + transformer_type = "t2v" if args.i2v_14B or args.i2v: - transformer_filename = get_model_filename("i2v", transformer_quantization, transformer_dtype_policy) + transformer_type = "i2v" if args.t2v_1_3B: - transformer_filename = get_model_filename("t2v_1.3B", transformer_quantization, transformer_dtype_policy) + transformer_type = "t2v_1.3B" if args.i2v_1_3B: - transformer_filename = get_model_filename("fun_inp_1.3B", transformer_quantization, transformer_dtype_policy) + transformer_type = "fun_inp_1.3B" if args.vace_1_3B: - transformer_filename = get_model_filename("vace_1.3B", transformer_quantization, transformer_dtype_policy) + transformer_type = "vace_1.3B" only_allow_edit_in_advanced = False lora_preselected_preset = args.lora_preset -lora_preset_model = transformer_filename +lora_preset_model = transformer_type if args.compile: #args.fastest or compile="transformer" lock_ui_compile = True -model_filename = "" #attention_mode="sage" #attention_mode="sage2" #attention_mode="flash" @@ -1973,14 +1965,13 @@ model_filename = "" #attention_mode="xformers" # compile = "transformer" -def get_loras_preprocessor(transformer, model_filename): - model_filename = get_base_model_filename(model_filename) +def get_loras_preprocessor(transformer, model_type): preprocessor = getattr(transformer, "preprocess_loras", None) if preprocessor == None: return None def preprocessor_wrapper(sd): - return preprocessor(model_filename, sd) + return preprocessor(model_type, sd) return preprocessor_wrapper @@ -2015,7 +2006,7 @@ def get_hunyuan_text_encoder_filename(text_encoder_quantization): return text_encoder_filename -def download_models(transformer_filename): +def download_models(model_filename, model_type): def computeList(filename): if filename == None: return [] @@ -2058,22 +2049,22 @@ def download_models(transformer_filename): process_files_def(**enhancer_def) - model_family = get_model_family(transformer_filename) - finetune_def = get_model_finetune_def(transformer_filename) + model_family = get_model_family(model_type) + finetune_def = get_model_finetune_def(model_type) if finetune_def != None: from urllib.request import urlretrieve from wan.utils.utils import create_progress_hook - if not os.path.isfile(transformer_filename ): + if not os.path.isfile(model_filename ): for url in finetune_def["URLs"]: - if transformer_filename in url: + if model_filename in url: break if not url.startswith("http"): - raise Exception(f"Model '{transformer_filename}' was not found locally and no URL was provided to download it. Please add an URL in the finetune definition file.") + raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it. 
Please add an URL in the finetune definition file.") try: - urlretrieve(url,transformer_filename, create_progress_hook(transformer_filename)) + urlretrieve(url,model_filename, create_progress_hook(model_filename)) except Exception as e: - if os.path.isfile(filename): os.remove(transformer_filename) - raise Exception(f"URL '{url}' is invalid for Model '{transformer_filename}' : {str(e)}'") + if os.path.isfile(model_filename): os.remove(model_filename) + raise Exception(f"URL '{url}' is invalid for Model '{model_filename}' : {str(e)}'") for url in finetune_def.get("preload_URLs", []): filename = "ckpts/" + url.split("/")[-1] if not os.path.isfile(filename ): @@ -2084,20 +2075,20 @@ def download_models(transformer_filename): except Exception as e: if os.path.isfile(filename): os.remove(filename) raise Exception(f"Preload URL '{url}' is invalid: {str(e)}'") - transformer_filename = None + model_filename = None if model_family == "wan": text_encoder_filename = get_wan_text_encoder_filename(text_encoder_quantization) model_def = { "repoId" : "DeepBeepMeep/Wan2.1", "sourceFolderList" : ["xlm-roberta-large", "umt5-xxl", "" ], - "fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(transformer_filename) ] + "fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(model_filename) ] } elif model_family == "ltxv": text_encoder_filename = get_ltxv_text_encoder_filename(text_encoder_quantization) model_def = { "repoId" : "DeepBeepMeep/LTX_Video", "sourceFolderList" : ["T5_xxl_1.1", "" ], - "fileList" : [ ["added_tokens.json", "special_tokens_map.json", "spiece.model", "tokenizer_config.json"] + computeList(text_encoder_filename), ["ltxv_0.9.7_VAE.safetensors", "ltxv_0.9.7_spatial_upscaler.safetensors", "ltxv_scheduler.json"] + computeList(transformer_filename) ] + "fileList" : [ ["added_tokens.json", "special_tokens_map.json", "spiece.model", "tokenizer_config.json"] + computeList(text_encoder_filename), ["ltxv_0.9.7_VAE.safetensors", "ltxv_0.9.7_spatial_upscaler.safetensors", "ltxv_scheduler.json"] + computeList(model_filename) ] } elif model_family == "hunyuan": text_encoder_filename = get_hunyuan_text_encoder_filename(text_encoder_quantization) @@ -2108,13 +2099,13 @@ def download_models(transformer_filename): ["config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"], ["config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json"], ["detface.pt"], - [ "hunyuan_video_720_quanto_int8_map.json", "hunyuan_video_custom_VAE_fp32.safetensors", "hunyuan_video_custom_VAE_config.json", "hunyuan_video_VAE_fp32.safetensors", "hunyuan_video_VAE_config.json" , "hunyuan_video_720_quanto_int8_map.json" ] + computeList(transformer_filename) + [ "hunyuan_video_720_quanto_int8_map.json", 
"hunyuan_video_custom_VAE_fp32.safetensors", "hunyuan_video_custom_VAE_config.json", "hunyuan_video_VAE_fp32.safetensors", "hunyuan_video_VAE_config.json" , "hunyuan_video_720_quanto_int8_map.json" ] + computeList(model_filename) ] } else: model_manager = get_model_manager(model_family) - model_def = model_manager.get_files_def(transformer_filename, text_encoder_quantization) + model_def = model_manager.get_files_def(model_filename, text_encoder_quantization) process_files_def(**model_def) @@ -2125,14 +2116,14 @@ offload.default_verboseLevel = verbose_level def sanitize_file_name(file_name, rep =""): return file_name.replace("/",rep).replace("\\",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep).replace("\n",rep).replace("\r",rep) -def extract_preset(model_filename, lset_name, loras): +def extract_preset(model_type, lset_name, loras): loras_choices = [] loras_choices_files = [] loras_mult_choices = "" prompt ="" full_prompt ="" lset_name = sanitize_file_name(lset_name) - lora_dir = get_lora_dir(model_filename) + lora_dir = get_lora_dir(model_type) if not lset_name.endswith(".lset"): lset_name_filename = os.path.join(lora_dir, lset_name + ".lset" ) else: @@ -2166,7 +2157,7 @@ def extract_preset(model_filename, lset_name, loras): -def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset, split_linear_modules_map = None): +def setup_loras(model_type, transformer, lora_dir, lora_preselected_preset, split_linear_modules_map = None): loras =[] loras_names = [] default_loras_choices = [] @@ -2177,7 +2168,7 @@ def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset, from pathlib import Path - lora_dir = get_lora_dir(model_filename) + lora_dir = get_lora_dir(model_type) if lora_dir != None : if not os.path.isdir(lora_dir): raise Exception("--lora-dir should be a path to a directory that contains Loras") @@ -2193,7 +2184,7 @@ def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset, loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets] if transformer !=None: - loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, preprocess_sd=get_loras_preprocessor(transformer, model_filename), split_linear_modules_map = split_linear_modules_map) #lora_multiplier, + loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, preprocess_sd=get_loras_preprocessor(transformer, model_type), split_linear_modules_map = split_linear_modules_map) #lora_multiplier, if len(loras) > 0: loras_names = [ Path(lora).stem for lora in loras ] @@ -2202,23 +2193,20 @@ def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset, if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")): raise Exception(f"Unknown preset '{lora_preselected_preset}'") default_lora_preset = lora_preselected_preset - default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, _ , error = extract_preset(model_filename, default_lora_preset, loras) + default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, _ , error = extract_preset(model_type, default_lora_preset, loras) if len(error) > 0: print(error[:200]) return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset def load_wan_model(model_filename, base_model_type, quantizeTransformer = False, dtype = 
torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): - filename = model_filename[-1] - print(f"Loading '{filename}' model...") - - if test_class_i2v(model_filename[0]): + if test_class_i2v(base_model_type): cfg = WAN_CONFIGS['i2v-14B'] model_factory = wan.WanI2V else: cfg = WAN_CONFIGS['t2v-14B'] # cfg = WAN_CONFIGS['t2v-1.3B'] - if get_model_type(filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"): + if base_model_type in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"): model_factory = wan.DTT2V else: model_factory = wan.WanT2V @@ -2241,9 +2229,7 @@ def load_wan_model(model_filename, base_model_type, quantizeTransformer = False, pipe["text_encoder_2"] = wan_model.clip.model return wan_model, pipe -def load_ltxv_model(model_filename, quantizeTransformer = False, base_model_type = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): - filename = model_filename[-1] - print(f"Loading '{filename}' model...") +def load_ltxv_model(model_filename, base_model_type, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): from ltx_video.ltxv import LTXV ltxv_model = LTXV( @@ -2261,8 +2247,6 @@ def load_ltxv_model(model_filename, quantizeTransformer = False, base_model_type return ltxv_model, pipe def load_hunyuan_model(model_filename, base_model_type = None, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): - filename = model_filename[-1] - print(f"Loading '{filename}' model...") from hyvideo.hunyuan import HunyuanVideoSampler hunyuan_model = HunyuanVideoSampler.from_pretrained( @@ -2304,25 +2288,33 @@ def get_transformer_model(model): raise Exception("no transformer found") -def load_models(model_filename): - global transformer_filename, transformer_loras_filenames - base_filename = get_base_model_filename(model_filename) - base_model_type = get_model_type(base_filename) - finetune_def = get_model_finetune_def(model_filename) +def load_models(model_type): + global transformer_type, transformer_loras_filenames + model_filename = get_model_filename(model_type=model_type, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy) + base_model_type = get_base_model_type(model_type) + finetune_def = get_model_finetune_def(model_type) quantizeTransformer = finetune_def !=None and finetune_def.get("auto_quantize", False) and not "quanto" in model_filename - model_family = get_model_family(model_filename) + model_family = get_model_family(model_type) perc_reserved_mem_max = args.perc_reserved_mem_max preload =int(args.preload) save_quantized = args.save_quantized if preload == 0: preload = server_config.get("preload_in_VRAM", 0) new_transformer_loras_filenames = None - dependent_models = get_dependent_models(model_filename, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy) + dependent_models, dependent_models_types = get_dependent_models(model_type, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy) new_transformer_loras_filenames = [model_filename] if "_lora" in model_filename else None - model_filelist = dependent_models + [model_filename] - for filename in model_filelist: - download_models(filename) + + model_file_list = dependent_models + [model_filename] + model_type_list = dependent_models_types + [model_type] + 
new_transformer_filename = model_file_list[-1] + if finetune_def != None: + for module_type in finetune_def.get("modules", []): + model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype_policy)) + model_type_list.append(module_type) + + for filename, file_model_type in zip(model_file_list, model_type_list): + download_models(filename, file_model_type) transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy) if quantizeTransformer: transformer_dtype = torch.bfloat16 if "bf16" in model_filename else transformer_dtype @@ -2331,13 +2323,19 @@ def load_models(model_filename): mixed_precision_transformer = server_config.get("mixed_precision","0") == "1" transformer_filename = None transformer_loras_filenames = None - new_transformer_filename = model_filelist[-1] + transformer_type = None + for i, filename in enumerate(model_file_list): + if i==0: + print(f"Loading Model '{filename}' ...") + elif "_lora" not in filename: + print(f"Loading Module '{filename}' ...") + if model_family == "wan" : - wan_model, pipe = load_wan_model(model_filelist, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) + wan_model, pipe = load_wan_model(model_file_list, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) elif model_family == "ltxv": - wan_model, pipe = load_ltxv_model(model_filelist, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) + wan_model, pipe = load_ltxv_model(model_file_list, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) elif model_family == "hunyuan": - wan_model, pipe = load_hunyuan_model(model_filelist, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) + wan_model, pipe = load_hunyuan_model(model_file_list, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) else: raise Exception(f"Model '{new_transformer_filename}' not supported.") wan_model._model_file_name = new_transformer_filename @@ -2369,17 +2367,17 @@ def load_models(model_filename): offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs) if len(args.gpu) > 0: torch.set_default_device(args.gpu) - transformer_filename = new_transformer_filename transformer_loras_filenames = new_transformer_loras_filenames + transformer_type = model_type return wan_model, offloadobj, pipe["transformer"] if not "P" in preload_model_policy: wan_model, offloadobj, transformer = None, None, None reload_needed = True else: - wan_model, offloadobj, transformer = load_models(transformer_filename) + 
wan_model, offloadobj, transformer = load_models(transformer_type) if check_loras: - setup_loras(model_filename, transformer, get_lora_dir(transformer_filename), "", None) + setup_loras(transformer_type, transformer, get_lora_dir(transformer_type), "", None) exit() del transformer @@ -2391,17 +2389,11 @@ def get_auto_attention(): return attn return "sdpa" -def get_default_flow(filename, i2v): - return 7.0 if "480p" in filename and i2v else 5.0 - - - - - -def generate_header(model_filename, compile, attention_mode): +def generate_header(model_type, compile, attention_mode): description_container = [""] - get_model_name(model_filename, description_container) + get_model_name(model_type, description_container) + model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) description = description_container[0] header = "