finetuned models support

This commit is contained in:
DeepBeepMeep 2025-06-12 10:00:47 +02:00
parent 43aa414eaf
commit 3749d23d44
23 changed files with 779 additions and 121 deletions

View File

@ -20,6 +20,17 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates ## 🔥 Latest Updates
### June 12 2025: WanGP v5.6
👋 *Finetune models*: Do you find the 20 models supported by WanGP insufficient? Too impatient to wait for the next release to get support for a newly released model? Your prayers have been answered: if a new model is compatible with a model architecture already supported by WanGP, you can add support for it yourself by simply creating a finetune model definition. You can then store the model in the cloud (for instance on Huggingface) and the very light finetune definition file can be easily shared with other users; WanGP will automatically download the finetuned model for them.
To celebrate this new feature, I have provided 4 finetuned model definitions:
- *Fast Hunyuan Video*: a t2v model that generates videos in only 6 steps
- *Hunyuan Video AccVideo*: a t2v model that generates videos in only 5 steps
- *Wan FusioniX*: a combo of AccVideo / CausVid and other models that can generate high quality Wan videos in only 8 steps
- *Vace FusioniX*: the ultimate Vace model, a combo of Vace / AccVideo / CausVid and other models that can generate high quality Wan controlled videos in only 10 steps
Check the *Finetune Guide* to create finetune model definitions and share them on the WanGP discord server.
### June 11 2025: WanGP v5.5 ### June 11 2025: WanGP v5.5
👋 *Hunyuan Video Custom Audio*: it is similar to Hunyuan Video Avatar except there isn't any lower limit on the number of frames and you can use your reference images in a different context than the image itself\ 👋 *Hunyuan Video Custom Audio*: it is similar to Hunyuan Video Avatar except there isn't any lower limit on the number of frames and you can use your reference images in a different context than the image itself\
*Hunyuan Video Custom Edit*: Hunyuan Video Controlnet, use it to do inpainting and replace a person in a video while still keeping his poses. Similar to Vace but less restricted than the Wan models in terms of content... *Hunyuan Video Custom Edit*: Hunyuan Video Controlnet, use it to do inpainting and replace a person in a video while still keeping his poses. Similar to Vace but less restricted than the Wan models in terms of content...
@ -96,6 +107,7 @@ For detailed installation instructions for different GPU generations:
### Advanced Features ### Advanced Features
- **[Loras Guide](docs/LORAS.md)** - Using and managing Loras for customization - **[Loras Guide](docs/LORAS.md)** - Using and managing Loras for customization
- **[Finetunes](docs/FINETUNES.md)** - Manually add new models to WanGP
- **[VACE ControlNet](docs/VACE.md)** - Advanced video control and manipulation - **[VACE ControlNet](docs/VACE.md)** - Advanced video control and manipulation
- **[Command Line Reference](docs/CLI.md)** - All available command line options - **[Command Line Reference](docs/CLI.md)** - All available command line options

14
configs/flf2v_720p.json Normal file
View File

@ -0,0 +1,14 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.30.0",
"dim": 5120,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_dim": 36,
"model_type": "i2v",
"num_heads": 40,
"num_layers": 40,
"out_dim": 16,
"text_len": 512
}

14
configs/i2v.json Normal file
View File

@ -0,0 +1,14 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.30.0",
"dim": 5120,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_dim": 36,
"model_type": "i2v",
"num_heads": 40,
"num_layers": 40,
"out_dim": 16,
"text_len": 512
}

14
configs/i2v_720p.json Normal file
View File

@ -0,0 +1,14 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.30.0",
"dim": 5120,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_dim": 36,
"model_type": "i2v",
"num_heads": 40,
"num_layers": 40,
"out_dim": 16,
"text_len": 512
}

14
configs/phantom_1.3B.json Normal file
View File

@ -0,0 +1,14 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.30.0",
"dim": 1536,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_dim": 16,
"model_type": "t2v",
"num_heads": 12,
"num_layers": 30,
"out_dim": 16,
"text_len": 512
}

14
configs/phantom_14B.json Normal file
View File

@ -0,0 +1,14 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.30.0",
"dim": 5120,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_dim": 16,
"model_type": "t2v",
"num_heads": 40,
"num_layers": 40,
"out_dim": 16,
"text_len": 512
}

14
configs/sky_df_1.3.json Normal file
View File

@ -0,0 +1,14 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.30.0",
"dim": 1536,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_dim": 16,
"model_type": "t2v",
"num_heads": 12,
"num_layers": 30,
"out_dim": 16,
"text_len": 512
}

14
configs/sky_df_14B.json Normal file
View File

@ -0,0 +1,14 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.30.0",
"dim": 5120,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_dim": 16,
"model_type": "t2v",
"num_heads": 40,
"num_layers": 40,
"out_dim": 16,
"text_len": 512
}

14
configs/t2v.json Normal file
View File

@ -0,0 +1,14 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.30.0",
"dim": 5120,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_dim": 16,
"model_type": "t2v",
"num_heads": 40,
"num_layers": 40,
"out_dim": 16,
"text_len": 512
}

14
configs/t2v_1.3B.json Normal file
View File

@ -0,0 +1,14 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.30.0",
"dim": 1536,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_dim": 16,
"model_type": "t2v",
"num_heads": 12,
"num_layers": 30,
"out_dim": 16,
"text_len": 512
}

16
configs/vace_1.3B.json Normal file
View File

@ -0,0 +1,16 @@
{
"_class_name": "VaceWanModel",
"_diffusers_version": "0.30.0",
"dim": 1536,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_dim": 16,
"model_type": "t2v",
"num_heads": 12,
"num_layers": 30,
"out_dim": 16,
"text_len": 512,
"vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
"vace_in_dim": 96
}

16
configs/vace_14B.json Normal file
View File

@ -0,0 +1,16 @@
{
"_class_name": "VaceWanModel",
"_diffusers_version": "0.30.0",
"dim": 5120,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_dim": 16,
"model_type": "t2v",
"num_heads": 40,
"num_layers": 40,
"out_dim": 16,
"text_len": 512,
"vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
"vace_in_dim": 96
}

78
docs/FINETUNES.md Normal file
View File

@ -0,0 +1,78 @@
# FINETUNES
A finetuned model is a model that shares the architecture of a specific base model but whose weights are derived from it. Some finetuned models have been created by combining multiple finetuned models.
As there is potentially an infinite number of finetunes, WanGP does not know about specific finetuned models by default. However, you can create a finetuned model definition that tells WanGP about the existence of a finetune, and WanGP will do all the usual work for you: automatically download the model and build the user interface.
Finetune model definitions are light json files that can be easily shared. You can find some of them on the WanGP *discord* server https://discord.gg/g7efUW9jGV
Finetuned models have been tested so far with Wan2.1 text2video, Wan2.1 image2video and Hunyuan Video text2video. There is currently no support for LTX Video finetunes.
## Create a new Finetune Model Definition
All the finetune model definitions are json files stored in the **finetunes** subfolder. The corresponding finetune model weights are stored in the **ckpts** subfolder and sit next to the base models.
WanGP comes with a few prebuilt finetune definitions that you can use as starting points and to get an idea of the structure of a definition file.
A definition is built from a *settings file* that can contain all the default parameters for a video generation. On top of this file, a subtree named **model** contains all the information specific to the finetune (URLs to download the model, id of the corresponding base model, ...).
You can obtain a settings file in several ways:
- In the **settings** subfolder, get the json file that corresponds to the base model of your finetune (see the next section for the list of base model ids)
- From the user interface, go to the base model and click **export settings**
Here are the steps:
1) Create a *settings file*
2) Add a **model** subtree with the finetune description
3) Save this file in the **finetunes** subfolder. The name of the file will be used as its id. It is good practice to prefix the file name with the base model: for instance, a finetune named **Fast** based on the Hunyuan Text 2 Video model could be saved as *hunyuan_t2v_fast.json*. In this example the id is *hunyuan_t2v_fast* (see the loading sketch right after these steps).
4) Restart WanGP
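To see how the file name becomes the id, here is a simplified, illustrative sketch of what WanGP does with the **finetunes** folder at startup. It mirrors the loading code added to *wgp.py* in this commit and is not the exact implementation:
```
# Minimal sketch of the finetune definition loading performed at startup
import glob, json, os

finetunes = {}
for file_path in sorted(glob.glob(os.path.join("finetunes", "*.json"))):
    finetune_id = os.path.basename(file_path)[:-5]  # file name without ".json" becomes the id
    with open(file_path, "r", encoding="utf-8") as f:
        json_def = json.load(f)
    finetune_def = json_def.pop("model")   # the "model" subtree describes the finetune itself
    finetune_def["settings"] = json_def    # the remaining keys are the default generation settings
    finetunes[finetune_id] = finetune_def
```
The id derived from the file name is how WanGP refers to the finetune internally.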
## Base Models Ids
A finetune is derived from a base model and will inherit its user interface and model capabilities. Here are the ids:
- *t2v*: Wan 2.1 text 2 video
- *i2v*: Wan 2.1 image 2 video 480p
- *i2v_720p*: Wan 2.1 image 2 video 720p
- *vace_14B*: Wan 2.1 Vace 14B
- *hunyuan*: Hunyuan Video text 2 video
- *hunyuan_i2v*: Hunyuan Video image 2 video
## The Model Subtree
- *name*: name of the finetune, used to select it in the user interface
- *base* : Id of the base model of the finetune (see previous section)
- *description*: description of the finetune that will be displayed at the top of the user interface
- *URLs*: URLs of all the versions of the finetune (quantized and non quantized). WanGP will pick the version that is closest to the user preferences (a sketch of this selection logic follows the example below). You will need to follow a naming convention to help WanGP identify the content of each version (see the next section). Right now WanGP supports only 8-bit quantized models that have been quantized using **quanto**; WanGP offers a command line switch to build such a quantized model easily (see below). *URLs* can also contain paths to local files to allow testing.
- *preload_URLs*: URLs of files to download in all cases (used for instance to load quantization maps)
- *auto_quantize*: if set to *true* and no quantized model URL is provided, WanGP will perform on-the-fly quantization when the user expects a quantized model
Example of a **model** subtree:
```
"model":
{
"name": "Wan text2video FusioniX 14B",
"base" : "t2v",
"description": "A powerful merged text-to-video model based on the original WAN 2.1 T2V model, enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail. multiple open-source models and LoRAs to boost temporal quality, expressiveness, and motion realism.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors"
],
"preload_URLs": [
],
"auto_quantize": true
},
```
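For reference, here is a simplified sketch of how WanGP chooses among the *URLs* entries depending on the user's quantization and data type preferences. It mirrors the *get_model_filename* logic added to *wgp.py* in this commit; the real function handles more cases:
```
import os

def pick_model_file(urls, quantization="int8", dtype_str="bf16"):
    # Each URL maps to a local file stored in the ckpts subfolder
    choices = ["ckpts/" + os.path.basename(url) for url in urls]
    if quantization in ("int8", "fp8"):
        sub_choices = [name for name in choices if quantization in name]
    else:
        sub_choices = [name for name in choices if "quanto" not in name]
    if sub_choices:
        # Prefer the file whose name matches the requested data type (fp16 / bf16)
        dtype_matches = [name for name in sub_choices if dtype_str in name]
        return dtype_matches[0] if dtype_matches else sub_choices[0]
    return choices[0]

urls = [
    "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors",
    "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors",
    "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors",
]
print(pick_model_file(urls, quantization="int8", dtype_str="fp16"))
# -> ckpts/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors
```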
## Finetune Model Naming Convention
If a model is not quantized, it is assumed to be mostly 16 bits (with maybe a few 32-bit weights), so *bf16* or *fp16* should appear somewhere in its name. If you need examples, just look at the **ckpts** subfolder: the naming convention for the base models is the same.
If a model is quantized, the term *quanto* should also be included, since for the moment WanGP only supports *quanto* quantized models. More specifically, you should replace *fp16* with *quanto_fp16_int8* or *bf16* with *quanto_bf16_int8*.
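As an illustration, here is a minimal sketch of this renaming, based on the *save_quantized_model* helper added to *wan/utils/utils.py* in this commit (simplified, the dtype handling is omitted):
```
def quantized_name(model_filename: str) -> str:
    # Derive the quanto int8 file name from the 16-bit file name
    if "_fp16" in model_filename:
        return model_filename.replace("_fp16", "_quanto_fp16_int8")
    if "_bf16" in model_filename:
        return model_filename.replace("_bf16", "_quanto_bf16_int8")
    # Fallback when the data type is not part of the name
    pos = model_filename.rfind(".")
    return model_filename[:pos] + "_quanto_int8" + model_filename[pos:]

print(quantized_name("Wan14BT2VFusioniX_fp16.safetensors"))
# -> Wan14BT2VFusioniX_quanto_fp16_int8.safetensors
```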
## Creating a Quanto Quantized file
If you launch the app with the *--save-quantized* switch, WanGP will create a quantized file in the **ckpts** subfolder just after the model has been loaded. Please note that the model will be *bf16* or *fp16* quantized depending on what you chose in the configuration menu.
1) Make sure that the finetune definition json file contains only one URL or file path, pointing to the non quantized model
2) Launch WanGP: *python wgp.py --save-quantized*
3) In the configuration menu, set the *Transformer Data Type* property to either *BF16* or *FP16*
4) Launch a generation (the settings used do not matter). As soon as the model is loaded, a new quantized model will be created in the **ckpts** subfolder if it doesn't already exist.
5) To test that this works properly, set the local path in the *URLs* key of the finetune definition file. For instance *URLs = ["ckpts/finetune_quanto_fp16_int8.safetensors"]*
6) Restart WanGP and select *Scaled Int8 Quantization* in the *Transformer Model Quantization* property
7) Launch a new generation and verify in the terminal window that the right quantized model is loaded
8) In order to share the finetune definition file, you will need to store the finetuned model weights in the cloud, for instance by uploading them to *Huggingface*. You can then replace the local path in the definition file with a URL (on Huggingface, to get the URL of the model file, click *Copy download link* when accessing the model properties).

View File

@ -0,0 +1,30 @@
{
"model": {
"name": "Hunyuan AccVideo 720p 13B",
"base": "hunyuan",
"description": " AccVideo is a novel efficient distillation method to accelerate video diffusion models with synthetic datset. Our method is 8.5x faster than HunyuanVideo.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/accvideo_hunyuan_video_720_quanto_int8.safetensors"
],
"preload_URLs": [
],
"auto_quantize": true
},
"negative_prompt": "",
"resolution": "832x480",
"video_length": 81,
"seed": 42,
"num_inference_steps": 5,
"flow_shift": 7,
"embedded_guidance_scale": 6,
"repeat_generation": 1,
"loras_multipliers": "",
"temporal_upsampling": "",
"spatial_upsampling": "",
"RIFLEx_setting": 0,
"slg_start_perc": 10,
"slg_end_perc": 90,
"prompt_enhancer": "",
"activated_loras": [
]
}

View File

@ -0,0 +1,31 @@
{
"model": {
"name": "Hunyuan Fast Video 720p 13B",
"base": "hunyuan",
"description": "Fast Hunyuan is an accelerated HunyuanVideo model. It can sample high quality videos with 6 diffusion steps.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/fast_hunyuan_video_720_quanto_int8.safetensors"
],
"preload_URLs": [
"https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/fast_hunyuan_video_720_quanto_int8_map.json"
],
"auto_quantize": true
},
"negative_prompt": "",
"resolution": "832x480",
"video_length": 81,
"seed": 42,
"num_inference_steps": 6,
"flow_shift": 17,
"embedded_guidance_scale": 6,
"repeat_generation": 1,
"loras_multipliers": "",
"temporal_upsampling": "",
"spatial_upsampling": "",
"RIFLEx_setting": 0,
"slg_start_perc": 10,
"slg_end_perc": 90,
"prompt_enhancer": "",
"activated_loras": [
]
}

View File

@ -0,0 +1,38 @@
{
"model":
{
"name": "Wan text2video FusioniX 14B",
"base" : "t2v",
"description": "A powerful merged text-to-video model based on the original WAN 2.1 T2V model, enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors"
],
"auto_quantize": true
},
"negative_prompt": "",
"prompt": "",
"resolution": "832x480",
"video_length": 81,
"seed": -1,
"num_inference_steps": 8,
"guidance_scale": 1,
"flow_shift": 5,
"embedded_guidance_scale": 6,
"repeat_generation": 1,
"multi_images_gen_type": 0,
"tea_cache_setting": 0,
"tea_cache_start_step_perc": 0,
"loras_multipliers": "",
"temporal_upsampling": "",
"spatial_upsampling": "",
"RIFLEx_setting": 0,
"slg_switch": 0,
"slg_start_perc": 10,
"slg_end_perc": 90,
"cfg_star_switch": 0,
"cfg_zero_step": -1,
"prompt_enhancer": "",
"activated_loras": []
}

View File

@ -0,0 +1,38 @@
{
"model":
{
"name": "Vace FusioniX 14B",
"base" : "vace_14B",
"description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Vace_FusioniX_14B_mfp16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Vace_FusioniX_14B_quanto_mfp16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Vace_FusioniX_14B_quanto_mbf16_int8.safetensors"
],
"auto_quantize": true
},
"negative_prompt": "",
"prompt": "",
"resolution": "832x480",
"video_length": 81,
"seed": -1,
"num_inference_steps": 10,
"guidance_scale": 1,
"flow_shift": 5,
"embedded_guidance_scale": 6,
"repeat_generation": 1,
"multi_images_gen_type": 0,
"tea_cache_setting": 0,
"tea_cache_start_step_perc": 0,
"loras_multipliers": "",
"temporal_upsampling": "",
"spatial_upsampling": "",
"RIFLEx_setting": 0,
"slg_switch": 0,
"slg_start_perc": 10,
"slg_end_perc": 90,
"cfg_star_switch": 0,
"cfg_zero_step": -1,
"prompt_enhancer": "",
"activated_loras": []
}

View File

@ -315,7 +315,7 @@ class Inference(object):
@classmethod @classmethod
def from_pretrained(cls, model_filepath, text_encoder_filepath, dtype = torch.bfloat16, VAE_dtype = torch.float16, mixed_precision_transformer =torch.bfloat16 , **kwargs): def from_pretrained(cls, model_filepath, base_model_type, text_encoder_filepath, dtype = torch.bfloat16, VAE_dtype = torch.float16, mixed_precision_transformer =torch.bfloat16 , quantizeTransformer = False, save_quantized = False, **kwargs):
device = "cuda" device = "cuda"
@ -331,23 +331,25 @@ class Inference(object):
embedded_cfg_scale = 6 embedded_cfg_scale = 6
filepath = model_filepath[0] filepath = model_filepath[0]
i2v_condition_type = None i2v_condition_type = None
i2v_mode = "i2v" in filepath i2v_mode = False
custom = False custom = False
custom_audio = False custom_audio = False
avatar = False avatar = False
if i2v_mode: if base_model_type == "hunyuan_i2v":
model_id = "HYVideo-T/2" model_id = "HYVideo-T/2"
i2v_condition_type = "token_replace" i2v_condition_type = "token_replace"
elif "custom" in filepath: i2v_mode = True
if "audio" in filepath: elif base_model_type == "hunyuan_custom":
model_id = "HYVideo-T/2-custom-audio" model_id = "HYVideo-T/2-custom"
custom_audio = True
elif "edit" in filepath:
model_id = "HYVideo-T/2-custom-edit"
else:
model_id = "HYVideo-T/2-custom"
custom = True custom = True
elif "avatar" in filepath : elif base_model_type == "hunyuan_custom_audio":
model_id = "HYVideo-T/2-custom-audio"
custom_audio = True
custom = True
elif base_model_type == "hunyuan_custom_edit":
model_id = "HYVideo-T/2-custom-edit"
custom = True
elif base_model_type == "hunyuan_avatar":
model_id = "HYVideo-T/2-avatar" model_id = "HYVideo-T/2-avatar"
text_len = 256 text_len = 256
avatar = True avatar = True
@ -385,11 +387,14 @@ class Inference(object):
# model = Inference.load_state_dict(args, model, model_filepath) # model = Inference.load_state_dict(args, model, model_filepath)
# model_filepath ="c:/temp/hc/mp_rank_00_model_states_video.pt" # model_filepath ="c:/temp/hc/mp_rank_00_model_states_video.pt"
offload.load_model_data(model, model_filepath, pinToMemory = pinToMemory, partialPinning = partialPinning) offload.load_model_data(model, model_filepath, quantizeTransformer = quantizeTransformer and not save_quantized, pinToMemory = pinToMemory, partialPinning = partialPinning)
pass pass
# offload.save_model(model, "hunyuan_video_avatar_edit_720_bf16.safetensors") # offload.save_model(model, "hunyuan_video_avatar_edit_720_bf16.safetensors")
# offload.save_model(model, "hunyuan_video_avatar_edit_720_quanto_bf16_int8.safetensors", do_quantize= True) # offload.save_model(model, "hunyuan_video_avatar_edit_720_quanto_bf16_int8.safetensors", do_quantize= True)
if save_quantized:
from wan.utils.utils import save_quantized_model
save_quantized_model(model, filepath, dtype, None)
model.mixed_precision = mixed_precision_transformer model.mixed_precision = mixed_precision_transformer
if model.mixed_precision : if model.mixed_precision :

View File

@ -29,6 +29,8 @@ class DTT2V:
checkpoint_dir, checkpoint_dir,
rank=0, rank=0,
model_filename = None, model_filename = None,
base_model_type = None,
save_quantized = False,
text_encoder_filename = None, text_encoder_filename = None,
quantizeTransformer = False, quantizeTransformer = False,
dtype = torch.bfloat16, dtype = torch.bfloat16,
@ -61,6 +63,7 @@ class DTT2V:
from mmgp import offload from mmgp import offload
# model_filename = "model.safetensors" # model_filename = "model.safetensors"
# model_filename = "c:/temp/diffusion_pytorch_model-00001-of-00006.safetensors" # model_filename = "c:/temp/diffusion_pytorch_model-00001-of-00006.safetensors"
base_config_file = f"configs/{base_model_type}.json"
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) # , forcedConfigPath="c:/temp/config _df720.json") self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) # , forcedConfigPath="c:/temp/config _df720.json")
# offload.load_model_data(self.model, "recam.ckpt") # offload.load_model_data(self.model, "recam.ckpt")
# self.model.cpu() # self.model.cpu()
@ -72,6 +75,9 @@ class DTT2V:
# offload.save_model(self.model, "rtfp16_int8.safetensors", do_quantize= "config.json") # offload.save_model(self.model, "rtfp16_int8.safetensors", do_quantize= "config.json")
self.model.eval().requires_grad_(False) self.model.eval().requires_grad_(False)
if save_quantized:
from wan.utils.utils import save_quantized_model
save_quantized_model(self.model, model_filename[-1], dtype, base_config_file)
self.scheduler = FlowUniPCMultistepScheduler() self.scheduler = FlowUniPCMultistepScheduler()

View File

@ -48,11 +48,13 @@ class WanI2V:
self, self,
config, config,
checkpoint_dir, checkpoint_dir,
model_filename ="", model_filename = None,
text_encoder_filename="", base_model_type= None,
text_encoder_filename= None,
quantizeTransformer = False, quantizeTransformer = False,
dtype = torch.bfloat16, dtype = torch.bfloat16,
VAE_dtype = torch.float32, VAE_dtype = torch.float32,
save_quantized = False,
mixed_precision_transformer = False mixed_precision_transformer = False
): ):
self.device = torch.device(f"cuda") self.device = torch.device(f"cuda")
@ -101,7 +103,8 @@ class WanI2V:
# model_filename = [model_filename, "audio_processor_bf16.safetensors"] # model_filename = [model_filename, "audio_processor_bf16.safetensors"]
# model_filename = "c:/temp/i2v480p/diffusion_pytorch_model-00001-of-00007.safetensors" # model_filename = "c:/temp/i2v480p/diffusion_pytorch_model-00001-of-00007.safetensors"
# dtype = torch.float16 # dtype = torch.float16
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath= "c:/temp/i2v720p/config.json") base_config_file = f"configs/{base_model_type}.json"
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath= base_config_file) #, forcedConfigPath= "c:/temp/i2v720p/config.json")
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
offload.change_dtype(self.model, dtype, True) offload.change_dtype(self.model, dtype, True)
# offload.save_model(self.model, "wan2.1_image2video_720p_14B_mbf16.safetensors", config_file_path="c:/temp/i2v720p/config.json") # offload.save_model(self.model, "wan2.1_image2video_720p_14B_mbf16.safetensors", config_file_path="c:/temp/i2v720p/config.json")
@ -110,6 +113,9 @@ class WanI2V:
# offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors") # offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
self.model.eval().requires_grad_(False) self.model.eval().requires_grad_(False)
if save_quantized:
from wan.utils.utils import save_quantized_model
save_quantized_model(self.model, model_filename[-1], dtype, base_config_file)
self.sample_neg_prompt = config.sample_neg_prompt self.sample_neg_prompt = config.sample_neg_prompt

View File

@ -50,8 +50,10 @@ class WanT2V:
checkpoint_dir, checkpoint_dir,
rank=0, rank=0,
model_filename = None, model_filename = None,
base_model_type = None,
text_encoder_filename = None, text_encoder_filename = None,
quantizeTransformer = False, quantizeTransformer = False,
save_quantized = False,
dtype = torch.bfloat16, dtype = torch.bfloat16,
VAE_dtype = torch.float32, VAE_dtype = torch.float32,
mixed_precision_transformer = False mixed_precision_transformer = False
@ -81,21 +83,24 @@ class WanT2V:
logging.info(f"Creating WanModel from {model_filename[-1]}") logging.info(f"Creating WanModel from {model_filename[-1]}")
from mmgp import offload from mmgp import offload
# model_filename = "c:/temp/vace1.3/diffusion_pytorch_model.safetensors" # model_filename = "c:/temp/vace1.3/diffusion_pytorch_model.safetensors"
# model_filename = "vace14B_quanto_bf16_int8.safetensors" # model_filename = "Vacefusionix_quanto_fp16_int8.safetensors"
# model_filename = "c:/temp/phantom/Phantom_Wan_14B-00001-of-00006.safetensors" # model_filename = "c:/temp/phantom/Phantom_Wan_14B-00001-of-00006.safetensors"
# config_filename= "c:/temp/phantom/config.json" # config_filename= "c:/temp/phantom/config.json"
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)#, forcedConfigPath= config_filename) base_config_file = f"configs/{base_model_type}.json"
# offload.load_model_data(self.model, "e:/vace.safetensors") self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file)#, forcedConfigPath= config_filename)
# offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth") # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
# self.model.to(torch.bfloat16) # self.model.to(torch.bfloat16)
# self.model.cpu() # self.model.cpu()
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
# dtype = torch.bfloat16 # dtype = torch.bfloat16
# offload.load_model_data(self.model, "ckpts/Wan14BT2VFusioniX_fp16.safetensors")
offload.change_dtype(self.model, dtype, True) offload.change_dtype(self.model, dtype, True)
# offload.save_model(self.model, "wan2.1_phantom_14B_mbf16.safetensors", config_file_path=config_filename) # offload.save_model(self.model, "wanfusionix_fp16.safetensors", config_file_path=base_config_file)
# offload.save_model(self.model, "wan2.1_phantom_14B_quanto_fp16_int8.safetensors", do_quantize= True, config_file_path=config_filename) # offload.save_model(self.model, "wanfusionix_quanto_fp16_int8.safetensors", do_quantize=True, config_file_path=base_config_file)
self.model.eval().requires_grad_(False) self.model.eval().requires_grad_(False)
if save_quantized:
from wan.utils.utils import save_quantized_model
save_quantized_model(self.model, model_filename[-1], dtype, base_config_file)
self.sample_neg_prompt = config.sample_neg_prompt self.sample_neg_prompt = config.sample_neg_prompt

View File

@ -226,3 +226,131 @@ def str2bool(v):
return False return False
else: else:
raise argparse.ArgumentTypeError('Boolean value expected (True/False)') raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
import sys, time
# Global variables to track download progress
_start_time = None
_last_time = None
_last_downloaded = 0
_speed_history = []
_update_interval = 0.5 # Update speed every 0.5 seconds
def progress_hook(block_num, block_size, total_size, filename=None):
"""
Simple progress bar hook for urlretrieve
Args:
block_num: Number of blocks downloaded so far
block_size: Size of each block in bytes
total_size: Total size of the file in bytes
filename: Name of the file being downloaded (optional)
"""
global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval
current_time = time.time()
downloaded = block_num * block_size
# Initialize timing on first call
if _start_time is None or block_num == 0:
_start_time = current_time
_last_time = current_time
_last_downloaded = 0
_speed_history = []
# Calculate download speed only at specified intervals
speed = 0
if current_time - _last_time >= _update_interval:
if _last_time > 0:
current_speed = (downloaded - _last_downloaded) / (current_time - _last_time)
_speed_history.append(current_speed)
# Keep only last 5 speed measurements for smoothing
if len(_speed_history) > 5:
_speed_history.pop(0)
# Average the recent speeds for smoother display
speed = sum(_speed_history) / len(_speed_history)
_last_time = current_time
_last_downloaded = downloaded
elif _speed_history:
# Use the last calculated average speed
speed = sum(_speed_history) / len(_speed_history)
# Format file sizes and speed
def format_bytes(bytes_val):
for unit in ['B', 'KB', 'MB', 'GB']:
if bytes_val < 1024:
return f"{bytes_val:.1f}{unit}"
bytes_val /= 1024
return f"{bytes_val:.1f}TB"
file_display = filename if filename else "Unknown file"
if total_size <= 0:
# If total size is unknown, show downloaded bytes
speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}"
# Clear any trailing characters by padding with spaces
sys.stdout.write(line.ljust(80))
sys.stdout.flush()
return
downloaded = block_num * block_size
percent = min(100, (downloaded / total_size) * 100)
# Create progress bar (40 characters wide to leave room for other info)
bar_length = 40
filled = int(bar_length * percent / 100)
bar = '█' * filled + '░' * (bar_length - filled)
# Format file sizes and speed
def format_bytes(bytes_val):
for unit in ['B', 'KB', 'MB', 'GB']:
if bytes_val < 1024:
return f"{bytes_val:.1f}{unit}"
bytes_val /= 1024
return f"{bytes_val:.1f}TB"
speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
# Display progress with filename first
line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}"
# Clear any trailing characters by padding with spaces
sys.stdout.write(line.ljust(100))
sys.stdout.flush()
# Print newline when complete
if percent >= 100:
print()
# Wrapper function to include filename in progress hook
def create_progress_hook(filename):
"""Creates a progress hook with the filename included"""
global _start_time, _last_time, _last_downloaded, _speed_history
# Reset timing variables for new download
_start_time = None
_last_time = None
_last_downloaded = 0
_speed_history = []
def hook(block_num, block_size, total_size):
return progress_hook(block_num, block_size, total_size, filename)
return hook
def save_quantized_model(model, model_filename, dtype, config_file):
from mmgp import offload
if dtype == torch.bfloat16:
model_filename = model_filename.replace("fp16", "bf16")
elif dtype == torch.float16:
model_filename = model_filename.replace("bf16", "fp16")
if "_fp16" in model_filename:
model_filename = model_filename.replace("_fp16", "_quanto_fp16_int8")
elif "_bf16" in model_filename:
model_filename = model_filename.replace("_bf16", "_quanto_bf16_int8")
else:
pos = model_filename.rfind(".")
model_filename = model_filename[:pos] + "_quanto_int8" + model_filename[pos+1:]
if not os.path.isfile(model_filename):
offload.save_model(model, model_filename, do_quantize= True, config_file_path=config_file)

319
wgp.py
View File

@ -43,8 +43,8 @@ global_queue_ref = []
AUTOSAVE_FILENAME = "queue.zip" AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10 PROMPT_VARS_MAX = 10
target_mmgp_version = "3.4.8" target_mmgp_version = "3.4.9"
WanGP_version = "5.5" WanGP_version = "5.6"
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
from importlib.metadata import version from importlib.metadata import version
@ -161,6 +161,7 @@ def process_prompt_and_add_tasks(state, model_choice):
return return
inputs["model_filename"] = model_filename inputs["model_filename"] = model_filename
model_filename = get_base_model_filename(model_filename)
prompts = prompt.replace("\r", "").split("\n") prompts = prompt.replace("\r", "").split("\n")
prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")] prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
if len(prompts) ==0: if len(prompts) ==0:
@ -1159,6 +1160,12 @@ def _parse_args():
help="Prevent switch models" help="Prevent switch models"
) )
parser.add_argument(
"--save-quantized",
action="store_true",
help="Save a quantized version of the current model"
)
parser.add_argument( parser.add_argument(
"--preload", "--preload",
type=str, type=str,
@ -1429,7 +1436,7 @@ def _parse_args():
return args return args
def get_lora_dir(model_filename): def get_lora_dir(model_filename):
model_filename = get_base_model_filename(model_filename)
model_family = get_model_family(model_filename) model_family = get_model_family(model_filename)
i2v = test_class_i2v(model_filename) i2v = test_class_i2v(model_filename)
if model_family == "wan": if model_family == "wan":
@ -1543,6 +1550,8 @@ for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion
print(f"Removing old version of model '{path}'. A new version of this model will be downloaded next time you use it.") print(f"Removing old version of model '{path}'. A new version of this model will be downloaded next time you use it.")
os.remove( os.path.join("ckpts" , path)) os.remove( os.path.join("ckpts" , path))
finetunes = {}
finetunes_filemap = {}
wan_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", wan_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors",
"ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors",
@ -1581,14 +1590,32 @@ model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "
"hunyuan" : "hunyuan_video_720", "hunyuan_i2v" : "hunyuan_video_i2v_720", "hunyuan_custom" : "hunyuan_video_custom_720", "hunyuan_custom_audio" : "hunyuan_video_custom_audio", "hunyuan_custom_edit" : "hunyuan_video_custom_edit", "hunyuan" : "hunyuan_video_720", "hunyuan_i2v" : "hunyuan_video_i2v_720", "hunyuan_custom" : "hunyuan_video_custom_720", "hunyuan_custom_audio" : "hunyuan_video_custom_audio", "hunyuan_custom_edit" : "hunyuan_video_custom_edit",
"hunyuan_avatar" : "hunyuan_video_avatar" } "hunyuan_avatar" : "hunyuan_video_avatar" }
def get_model_finetune_def(model_filename):
model_type = finetunes_filemap.get(model_filename, None )
if model_type == None:
return None
return finetunes.get(model_type, None )
def get_base_model_filename(model_filename):
finetune_def = get_model_finetune_def(model_filename)
if finetune_def == None:
return model_filename
else:
return finetune_def["base_filename"]
def get_model_type(model_filename): def get_model_type(model_filename):
model_type = finetunes_filemap.get(model_filename, None )
if model_type != None:
return model_type
for model_type, signature in model_signatures.items(): for model_type, signature in model_signatures.items():
if signature in model_filename: if signature in model_filename:
return model_type return model_type
raise Exception("Unknown model:" + model_filename) raise Exception("Unknown model:" + model_filename)
def get_model_family(model_filename): def get_model_family(model_filename):
finetune_def = get_model_finetune_def(model_filename)
if finetune_def != None:
return finetune_def["model_family"]
if "wan" in model_filename or "sky" in model_filename: if "wan" in model_filename or "sky" in model_filename:
return "wan" return "wan"
elif "ltxv" in model_filename: elif "ltxv" in model_filename:
@ -1599,10 +1626,15 @@ def get_model_family(model_filename):
raise Exception(f"Unknown model family for model'{model_filename}'") raise Exception(f"Unknown model family for model'{model_filename}'")
def test_class_i2v(model_filename): def test_class_i2v(model_filename):
model_filename = get_base_model_filename(model_filename)
return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename or "fantasy" in model_filename or "hunyuan_video_i2v" in model_filename return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename or "fantasy" in model_filename or "hunyuan_video_i2v" in model_filename
def get_model_name(model_filename, description_container = [""]): def get_model_name(model_filename, description_container = [""]):
if "Fun" in model_filename: finetune_def = get_model_finetune_def(model_filename)
if finetune_def != None:
model_name = finetune_def["name"]
description = finetune_def["description"]
elif "Fun" in model_filename:
model_name = "Fun InP image2video" model_name = "Fun InP image2video"
model_name += " 14B" if "14B" in model_filename else " 1.3B" model_name += " 14B" if "14B" in model_filename else " 1.3B"
description = "The Fun model is an alternative image 2 video that supports out the box End Image fixing (contrary to the original Wan image 2 video model). The 1.3B adds also image 2 to video capability to the 1.3B model." description = "The Fun model is an alternative image 2 video that supports out the box End Image fixing (contrary to the original Wan image 2 video model). The 1.3B adds also image 2 to video capability to the 1.3B model."
@ -1684,8 +1716,12 @@ def get_model_name(model_filename, description_container = [""]):
def get_model_filename(model_type, quantization ="int8", dtype_policy = ""): def get_model_filename(model_type, quantization ="int8", dtype_policy = ""):
signature = model_signatures[model_type] finetune_def = finetunes.get(model_type, None)
choices = [ name for name in transformer_choices if signature in name] if finetune_def != None:
choices = [ "ckpts/" + os.path.basename(path) for path in finetune_def["URLs"] ]
else:
signature = model_signatures[model_type]
choices = [ name for name in transformer_choices if signature in name]
if len(quantization) == 0: if len(quantization) == 0:
quantization = "bf16" quantization = "bf16"
@ -1694,7 +1730,11 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = ""):
if len(choices) <= 1: if len(choices) <= 1:
raw_filename = choices[0] raw_filename = choices[0]
else: else:
sub_choices = [ name for name in choices if quantization in name] if quantization in ("int8", "fp8"):
sub_choices = [ name for name in choices if quantization in name]
else:
sub_choices = [ name for name in choices if "quanto" not in name]
if len(sub_choices) > 0: if len(sub_choices) > 0:
dtype_str = "fp16" if dtype == torch.float16 else "bf16" dtype_str = "fp16" if dtype == torch.float16 else "bf16"
new_sub_choices = [ name for name in sub_choices if dtype_str in name] new_sub_choices = [ name for name in sub_choices if dtype_str in name]
@ -1703,7 +1743,7 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = ""):
else: else:
raw_filename = choices[0] raw_filename = choices[0]
if dtype == torch.float16 and not "fp16" in raw_filename and model_family == "wan" : if dtype == torch.float16 and not "fp16" in raw_filename and model_family == "wan" and finetune_def != None :
if "quanto_int8" in raw_filename: if "quanto_int8" in raw_filename:
raw_filename = raw_filename.replace("quanto_int8", "quanto_fp16_int8") raw_filename = raw_filename.replace("quanto_int8", "quanto_fp16_int8")
elif "quanto_bf16_int8" in raw_filename: elif "quanto_bf16_int8" in raw_filename:
@ -1739,81 +1779,87 @@ def get_default_settings(filename):
i2v = test_class_i2v(filename) i2v = test_class_i2v(filename)
defaults_filename = get_settings_file_name(filename) defaults_filename = get_settings_file_name(filename)
if not Path(defaults_filename).is_file(): if not Path(defaults_filename).is_file():
ui_defaults = { finetune_def = get_model_finetune_def(filename)
"prompt": get_default_prompt(i2v), if finetune_def != None:
"resolution": "1280x720" if "720p" in filename else "832x480", ui_defaults = finetune_def["settings"]
"video_length": 81, if len(ui_defaults.get("prompt","")) == 0:
"num_inference_steps": 30, ui_defaults["prompt"]= get_default_prompt(i2v)
"seed": -1, else:
"repeat_generation": 1, ui_defaults = {
"multi_images_gen_type": 0, "prompt": get_default_prompt(i2v),
"guidance_scale": 5.0, "resolution": "1280x720" if "720p" in filename else "832x480",
"embedded_guidance_scale" : 6.0, "video_length": 81,
"audio_guidance_scale": 5.0, "num_inference_steps": 30,
"flow_shift": get_default_flow(filename, i2v), "seed": -1,
"negative_prompt": "", "repeat_generation": 1,
"activated_loras": [], "multi_images_gen_type": 0,
"loras_multipliers": "", "guidance_scale": 5.0,
"tea_cache": 0.0, "embedded_guidance_scale" : 6.0,
"tea_cache_start_step_perc": 0, "audio_guidance_scale": 5.0,
"RIFLEx_setting": 0, "flow_shift": get_default_flow(filename, i2v),
"slg_switch": 0, "negative_prompt": "",
"slg_layers": [9], "activated_loras": [],
"slg_start_perc": 10, "loras_multipliers": "",
"slg_end_perc": 90 "tea_cache": 0.0,
} "tea_cache_start_step_perc": 0,
"RIFLEx_setting": 0,
"slg_switch": 0,
"slg_layers": [9],
"slg_start_perc": 10,
"slg_end_perc": 90
}
if get_model_type(filename) in ("hunyuan","hunyuan_i2v"): if get_model_type(filename) in ("hunyuan","hunyuan_i2v"):
ui_defaults.update({ ui_defaults.update({
"guidance_scale": 7.0, "guidance_scale": 7.0,
}) })
if get_model_type(filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"): if get_model_type(filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
ui_defaults.update({ ui_defaults.update({
"guidance_scale": 6.0, "guidance_scale": 6.0,
"flow_shift": 8, "flow_shift": 8,
"sliding_window_discard_last_frames" : 0, "sliding_window_discard_last_frames" : 0,
"resolution": "1280x720" if "720p" in filename else "960x544", "resolution": "1280x720" if "720p" in filename else "960x544",
"sliding_window_size" : 121 if "720p" in filename else 97, "sliding_window_size" : 121 if "720p" in filename else 97,
"RIFLEx_setting": 2, "RIFLEx_setting": 2,
"guidance_scale": 6, "guidance_scale": 6,
"flow_shift": 8, "flow_shift": 8,
}) })
if get_model_type(filename) in ("phantom_1.3B", "phantom_14B"): if get_model_type(filename) in ("phantom_1.3B", "phantom_14B"):
ui_defaults.update({ ui_defaults.update({
"guidance_scale": 7.5, "guidance_scale": 7.5,
"flow_shift": 5, "flow_shift": 5,
"remove_background_images_ref": 0, "remove_background_images_ref": 0,
# "resolution": "1280x720" # "resolution": "1280x720"
}) })
elif get_model_type(filename) in ("hunyuan_custom"): elif get_model_type(filename) in ("hunyuan_custom"):
ui_defaults.update({ ui_defaults.update({
"guidance_scale": 7.5, "guidance_scale": 7.5,
"flow_shift": 13, "flow_shift": 13,
"resolution": "1280x720", "resolution": "1280x720",
}) })
elif get_model_type(filename) in ("hunyuan_custom_edit"): elif get_model_type(filename) in ("hunyuan_custom_edit"):
ui_defaults.update({ ui_defaults.update({
"guidance_scale": 7.5, "guidance_scale": 7.5,
"flow_shift": 13, "flow_shift": 13,
"video_prompt_type": "MV", "video_prompt_type": "MV",
"sliding_window_size": 129, "sliding_window_size": 129,
}) })
elif get_model_type(filename) in ("hunyuan_avatar"): elif get_model_type(filename) in ("hunyuan_avatar"):
ui_defaults.update({ ui_defaults.update({
"guidance_scale": 7.5, "guidance_scale": 7.5,
"flow_shift": 5, "flow_shift": 5,
"tea_cache_start_step_perc": 25, "tea_cache_start_step_perc": 25,
"video_length": 129, "video_length": 129,
"video_prompt_type": "I", "video_prompt_type": "I",
}) })
elif get_model_type(filename) in ("vace_14B"): elif get_model_type(filename) in ("vace_14B"):
ui_defaults.update({ ui_defaults.update({
"sliding_window_discard_last_frames": 0, "sliding_window_discard_last_frames": 0,
}) })
with open(defaults_filename, "w", encoding="utf-8") as f: with open(defaults_filename, "w", encoding="utf-8") as f:
@ -1839,6 +1885,30 @@ def get_default_settings(filename):
ui_defaults["num_inference_steps"] = default_number_steps ui_defaults["num_inference_steps"] = default_number_steps
return ui_defaults return ui_defaults
finetunes_paths = glob.glob( os.path.join("finetunes", "*.json") )
finetunes_paths.sort()
for file_path in finetunes_paths:
finetune_id = os.path.basename(file_path)[:-5]
with open(file_path, "r", encoding="utf-8") as f:
try:
json_def = json.load(f)
except Exception as e:
raise Exception(f"Error while parsing Finetune Definition File '{file_path}': {str(e)}")
finetune_def = json_def["model"]
del json_def["model"]
finetune_def["settings"] = json_def
base_filename = get_model_filename(finetune_def["base"])
finetune_def["base_filename"] = base_filename
finetune_def["model_family"] = get_model_family(base_filename)
finetunes[finetune_id] = finetune_def
for url in finetune_def["URLs"]:
url = url.split("/")[-1]
finetunes_filemap["ckpts/" + url] = finetune_id
model_types += finetunes.keys()
transformer_types = server_config.get("transformer_types", []) transformer_types = server_config.get("transformer_types", [])
transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0] transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
@ -1904,6 +1974,7 @@ model_filename = ""
# compile = "transformer" # compile = "transformer"
def get_loras_preprocessor(transformer, model_filename): def get_loras_preprocessor(transformer, model_filename):
model_filename = get_base_model_filename(model_filename)
preprocessor = getattr(transformer, "preprocess_loras", None) preprocessor = getattr(transformer, "preprocess_loras", None)
if preprocessor == None: if preprocessor == None:
return None return None
@ -1946,6 +2017,8 @@ def get_hunyuan_text_encoder_filename(text_encoder_quantization):
def download_models(transformer_filename): def download_models(transformer_filename):
def computeList(filename): def computeList(filename):
if filename == None:
return []
pos = filename.rfind("/") pos = filename.rfind("/")
filename = filename[pos+1:] filename = filename[pos+1:]
return [filename] return [filename]
@ -1986,6 +2059,32 @@ def download_models(transformer_filename):
model_family = get_model_family(transformer_filename) model_family = get_model_family(transformer_filename)
finetune_def = get_model_finetune_def(transformer_filename)
if finetune_def != None:
from urllib.request import urlretrieve
from wan.utils.utils import create_progress_hook
if not os.path.isfile(transformer_filename ):
for url in finetune_def["URLs"]:
if transformer_filename in url:
break
if not url.startswith("http"):
raise Exception(f"Model '{transformer_filename}' was not found locally and no URL was provided to download it. Please add an URL in the finetune definition file.")
try:
urlretrieve(url,transformer_filename, create_progress_hook(transformer_filename))
except Exception as e:
if os.path.isfile(filename): os.remove(transformer_filename)
raise Exception(f"URL '{url}' is invalid for Model '{transformer_filename}' : {str(e)}'")
for url in finetune_def.get("preload_URLs", []):
filename = "ckpts/" + url.split("/")[-1]
if not os.path.isfile(filename ):
if not url.startswith("http"):
raise Exception(f"File '{filename}' to preload was not found locally and no URL was provided to download it. Please add an URL in the finetune definition file.")
try:
urlretrieve(url,filename, create_progress_hook(filename))
except Exception as e:
if os.path.isfile(filename): os.remove(filename)
raise Exception(f"Preload URL '{url}' is invalid: {str(e)}'")
transformer_filename = None
if model_family == "wan": if model_family == "wan":
text_encoder_filename = get_wan_text_encoder_filename(text_encoder_quantization) text_encoder_filename = get_wan_text_encoder_filename(text_encoder_quantization)
model_def = { model_def = {
@ -2023,8 +2122,6 @@ def download_models(transformer_filename):
offload.default_verboseLevel = verbose_level offload.default_verboseLevel = verbose_level
# download_models(transformer_filename)
def sanitize_file_name(file_name, rep =""): def sanitize_file_name(file_name, rep =""):
return file_name.replace("/",rep).replace("\\",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep).replace("\n",rep).replace("\r",rep) return file_name.replace("/",rep).replace("\\",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep).replace("\n",rep).replace("\r",rep)
@ -2111,7 +2208,7 @@ def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset,
return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
def load_wan_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False): def load_wan_model(model_filename, base_model_type, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False):
filename = model_filename[-1] filename = model_filename[-1]
print(f"Loading '{filename}' model...") print(f"Loading '{filename}' model...")
@ -2130,11 +2227,13 @@ def load_wan_model(model_filename, quantizeTransformer = False, dtype = torch.bf
config=cfg, config=cfg,
checkpoint_dir="ckpts", checkpoint_dir="ckpts",
model_filename=model_filename, model_filename=model_filename,
base_model_type=base_model_type,
text_encoder_filename= get_wan_text_encoder_filename(text_encoder_quantization), text_encoder_filename= get_wan_text_encoder_filename(text_encoder_quantization),
quantizeTransformer = quantizeTransformer, quantizeTransformer = quantizeTransformer,
dtype = dtype, dtype = dtype,
VAE_dtype = VAE_dtype, VAE_dtype = VAE_dtype,
mixed_precision_transformer = mixed_precision_transformer mixed_precision_transformer = mixed_precision_transformer,
save_quantized = save_quantized
) )
pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model } pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model }
@ -2142,7 +2241,7 @@ def load_wan_model(model_filename, quantizeTransformer = False, dtype = torch.bf
pipe["text_encoder_2"] = wan_model.clip.model pipe["text_encoder_2"] = wan_model.clip.model
return wan_model, pipe return wan_model, pipe
def load_ltxv_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False): def load_ltxv_model(model_filename, quantizeTransformer = False, base_model_type = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
filename = model_filename[-1] filename = model_filename[-1]
print(f"Loading '{filename}' model...") print(f"Loading '{filename}' model...")
from ltx_video.ltxv import LTXV from ltx_video.ltxv import LTXV
@ -2161,18 +2260,20 @@ def load_ltxv_model(model_filename, quantizeTransformer = False, dtype = torch.b
return ltxv_model, pipe return ltxv_model, pipe
def load_hunyuan_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False): def load_hunyuan_model(model_filename, base_model_type = None, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
filename = model_filename[-1] filename = model_filename[-1]
print(f"Loading '{filename}' model...") print(f"Loading '{filename}' model...")
from hyvideo.hunyuan import HunyuanVideoSampler from hyvideo.hunyuan import HunyuanVideoSampler
hunyuan_model = HunyuanVideoSampler.from_pretrained( hunyuan_model = HunyuanVideoSampler.from_pretrained(
model_filepath = model_filename, model_filepath = model_filename,
base_model_type = base_model_type,
text_encoder_filepath = get_hunyuan_text_encoder_filename(text_encoder_quantization), text_encoder_filepath = get_hunyuan_text_encoder_filename(text_encoder_quantization),
dtype = dtype, dtype = dtype,
# quantizeTransformer = quantizeTransformer, quantizeTransformer = quantizeTransformer,
VAE_dtype = VAE_dtype, VAE_dtype = VAE_dtype,
mixed_precision_transformer = mixed_precision_transformer mixed_precision_transformer = mixed_precision_transformer,
save_quantized = save_quantized
) )
pipe = { "transformer" : hunyuan_model.model, "text_encoder" : hunyuan_model.text_encoder, "text_encoder_2" : hunyuan_model.text_encoder_2, "vae" : hunyuan_model.vae } pipe = { "transformer" : hunyuan_model.model, "text_encoder" : hunyuan_model.text_encoder, "text_encoder_2" : hunyuan_model.text_encoder_2, "vae" : hunyuan_model.vae }
@ -2205,9 +2306,15 @@ def get_transformer_model(model):
def load_models(model_filename): def load_models(model_filename):
global transformer_filename, transformer_loras_filenames global transformer_filename, transformer_loras_filenames
base_filename = get_base_model_filename(model_filename)
base_model_type = get_model_type(base_filename)
finetune_def = get_model_finetune_def(model_filename)
quantizeTransformer = finetune_def !=None and finetune_def.get("auto_quantize", False) and not "quanto" in model_filename
model_family = get_model_family(model_filename) model_family = get_model_family(model_filename)
perc_reserved_mem_max = args.perc_reserved_mem_max perc_reserved_mem_max = args.perc_reserved_mem_max
preload =int(args.preload) preload =int(args.preload)
save_quantized = args.save_quantized
if preload == 0: if preload == 0:
preload = server_config.get("preload_in_VRAM", 0) preload = server_config.get("preload_in_VRAM", 0)
new_transformer_loras_filenames = None new_transformer_loras_filenames = None
@@ -2217,17 +2324,20 @@ def load_models(model_filename):
     for filename in model_filelist:
         download_models(filename)
     transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
+    if quantizeTransformer:
+        transformer_dtype = torch.bfloat16 if "bf16" in model_filename else transformer_dtype
+        transformer_dtype = torch.float16 if "fp16" in model_filename else transformer_dtype
     VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
     mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
     transformer_filename = None
     transformer_loras_filenames = None
     new_transformer_filename = model_filelist[-1]
     if model_family == "wan" :
-        wan_model, pipe = load_wan_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
+        wan_model, pipe = load_wan_model(model_filelist, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized)
     elif model_family == "ltxv":
-        wan_model, pipe = load_ltxv_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
+        wan_model, pipe = load_ltxv_model(model_filelist, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized)
     elif model_family == "hunyuan":
-        wan_model, pipe = load_hunyuan_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
+        wan_model, pipe = load_hunyuan_model(model_filelist, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized)
     else:
         raise Exception(f"Model '{new_transformer_filename}' not supported.")
     wan_model._model_file_name = new_transformer_filename
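Two things are threaded through the loaders above: the base model type, so a finetune reuses the loading path of the architecture it derives from, and save_quantized, so an auto-quantized finetune can be written back to disk. The dtype override that precedes the loader calls simply trusts filename tokens when auto-quantizing; below is a minimal standalone sketch of that selection, with the "bf16"/"fp16" token convention assumed:

import torch

def pick_transformer_dtype(model_filename, default_dtype, quantize):
    # Sketch of the selection added in load_models(): when auto-quantizing a finetune,
    # "bf16" / "fp16" in the checkpoint name override the configured dtype policy.
    dtype = default_dtype
    if quantize:
        if "bf16" in model_filename:
            dtype = torch.bfloat16
        if "fp16" in model_filename:
            dtype = torch.float16
    return dtype

# pick_transformer_dtype("fast_hunyuan_fp16.safetensors", torch.bfloat16, True) -> torch.float16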
@@ -2517,6 +2627,7 @@ def refresh_gallery(state): #, msg
         prompt = task["prompt"]
         params = task["params"]
         model_filename = params["model_filename"]
+        model_filename = get_base_model_filename(model_filename)
         onemorewindow_visible = "Vace" in model_filename or "diffusion_forcing" in model_filename or "ltxv" in model_filename
         enhanced = False
         if prompt.startswith("!enhanced!\n"):
@@ -2942,6 +3053,9 @@ def generate_video(
         raise gr.Error("Error while loading Loras: " + ", ".join(error_files))
     seed = None if seed == -1 else seed
     # negative_prompt = "" # not applicable in the inference
+    original_filename = model_filename
+    model_filename = get_base_model_filename(model_filename)
     image2video = test_class_i2v(model_filename)
     current_video_length = video_length
     enable_RIFLEx = RIFLEx_setting == 0 and current_video_length > (6* 16) or RIFLEx_setting == 1
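generate_video() now works with two names: the finetune checkpoint the user selected (kept as original_filename and written back into the saved settings further down) and the base-architecture checkpoint it maps to, which drives every capability check (i2v, Vace, sliding window, and so on). A self-contained sketch of that pattern, with the mapping purely illustrative:

def get_base_model_filename(model_filename):
    # Stand-in for the helper introduced by this commit: map a finetuned checkpoint to the
    # base-architecture checkpoint it derives from (the mapping below is made up).
    finetune_to_base = {"wan_fusionix.safetensors": "wan2.1_text2video_14B.safetensors"}
    return finetune_to_base.get(model_filename, model_filename)

def resolve_filenames(model_filename):
    # Keep the finetune name for metadata / exported settings, switch to the base name
    # for architecture-dependent branching.
    original_filename = model_filename
    base_filename = get_base_model_filename(model_filename)
    return original_filename, base_filename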
@@ -3198,7 +3312,11 @@ def generate_video(
                     fit_into_canvas = fit_canvas
                 )
             elif hunyuan_custom_edit:
-                progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")]
+                if "P" in video_prompt_type:
+                    progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")]
+                else:
+                    progress_args = [0, get_latest_status(state,"Extracting Video and Mask")]
                 send_cmd("progress", progress_args)
                 src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = fit_canvas, target_fps = fps, pose_enhance = "P" in video_prompt_type)
                 if window_no == 1:
@@ -3432,6 +3550,7 @@ def generate_video(
     inputs = get_function_arguments(generate_video, locals())
     inputs.pop("send_cmd")
     inputs.pop("task")
+    inputs["model_filename"] = original_filename
     configs = prepare_inputs_dict("metadata", inputs)
     configs["prompt"] = "\n".join(original_prompts)
     if prompt_enhancer_image_caption_model != None and prompt_enhancer !=None and len(prompt_enhancer)>0:
@@ -4197,7 +4316,8 @@ def prepare_inputs_dict(target, inputs ):
     if target == "settings":
         return inputs
+    model_filename = get_base_model_filename(model_filename)
     if not test_class_i2v(model_filename):
         inputs.pop("image_prompt_type")
@@ -4250,7 +4370,7 @@ def export_settings(state):
     settings["model_filename"] = model_filename
     text = json.dumps(settings, indent=4)
     text_base64 = base64.b64encode(text.encode('utf8')).decode('utf-8')
-    return text_base64
+    return text_base64, sanitize_file_name(model_type + "_" + datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss") + ".json")
 
 def use_video_settings(state, files):
     gen = get_gen_info(state)
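export_settings() now returns the download filename alongside the base64 payload, built from the model type and a timestamp instead of the former fixed settings.json. A rough sketch of what that second return value looks like (sanitize_file_name stands in for the existing helper, model_type is an illustrative value):

import re, time
from datetime import datetime

def sanitize_file_name(name):
    # Stand-in: strip characters that are unsafe in file names.
    return re.sub(r'[\\/:"*?<>|]+', "_", name)

model_type = "vace_14B"  # illustrative
filename = sanitize_file_name(model_type + "_" + datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss") + ".json")
# e.g. "vace_14B_2025-06-12-10h00m47s.json"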
@@ -4581,6 +4701,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
         gen["queue"] = []
         state_dict["gen"] = gen
+    model_filename = get_base_model_filename(model_filename)
     preset_to_load = lora_preselected_preset if lora_preset_model == model_filename else ""
     loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset = setup_loras(model_filename, None, get_lora_dir(model_filename), preset_to_load, None)
@@ -5001,7 +5122,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                             choices=[
                                 (str(i), i ) for i in range(40)
                             ],
-                            value=ui_defaults.get("slg_layers", ["9"]),
+                            value=ui_defaults.get("slg_layers", [9]),
                             multiselect= True,
                             label="Skip Layers",
                             scale= 3
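The slg_layers default is also fixed here: the dropdown's choices are (label, value) pairs whose values are ints, so the preset default should be the int 9 rather than the string "9". A minimal reproduction of the corrected widget (Gradio import assumed):

import gradio as gr

slg_layers = gr.Dropdown(
    choices=[(str(i), i) for i in range(40)],  # labels are strings, values are ints
    value=[9],                                 # was ["9"], which is not equal to any int value
    multiselect=True,
    label="Skip Layers",
)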
@@ -5024,7 +5145,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                             label="CFG Star"
                         )
                     with gr.Row():
-                        cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)")
+                        cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)", visible = not (hunyuan_i2v or hunyuan_video_avatar or hunyuan_i2v or hunyuan_video_custom))
                 with gr.Tab("Sliding Window", visible= sliding_window_enabled) as sliding_window_tab:
                     with gr.Column():
@@ -5071,6 +5192,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                 with gr.Row():
                     settings_file = gr.File(height=41,label="Load Settings From Video / Json")
                     settings_base64_output = gr.Text(interactive= False, visible=False, value = "")
+                    settings_filename = gr.Text(interactive= False, visible=False, value = "")
     if not update_form:
         with gr.Column():
             gen_status = gr.Text(interactive= False, label = "Status")
@@ -5214,10 +5337,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                     outputs= None
                 ).then(fn=export_settings,
                     inputs =[state],
-                    outputs= [settings_base64_output]
+                    outputs= [settings_base64_output, settings_filename]
                 ).then(
                     fn=None,
-                    inputs=[settings_base64_output],
+                    inputs=[settings_base64_output, settings_filename],
                     outputs=None,
                     js=trigger_settings_download_js
                 )
@@ -5799,7 +5922,7 @@ def get_js():
    """
    trigger_settings_download_js = """
-    (base64String) => {
+    (base64String, filename) => {
        if (!base64String) {
            console.log("No base64 settings data received, skipping download.");
            return;
@@ -5817,7 +5940,7 @@ def get_js():
        const a = document.createElement('a');
        a.style.display = 'none';
        a.href = url;
-        a.download = 'settings.json';
+        a.download = filename;
        document.body.appendChild(a);
        a.click();
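Taken together, these last hunks close the loop for the settings download: export_settings() now emits both the payload and a model-specific filename, the hidden settings_filename textbox carries that name through the .then() chain, and the JavaScript helper receives it as its second argument and assigns it to a.download, so the browser saves the exported settings under the timestamped, model-specific name instead of the generic settings.json.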