Flux Kontext and more

This commit is contained in:
deepbeepmeep 2025-07-15 22:26:56 +02:00
parent 37f41804a6
commit 64c59c15d9
21 changed files with 734 additions and 392 deletions

View File

@ -20,6 +20,26 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates
### July 15 2025: WanGP v7.0 is an AI Powered Photoshop
This release turns the Wan models into Image Generators. It goes far beyond simply generating a video made of a single frame:
- Multiple Images generated at the same time so that you can choose the one you like best. It is highly VRAM optimized: you can for instance generate four 720p Images at the same time with less than 10 GB of VRAM
- With *image2image* the original Wan text2video model in WanGP becomes an image upsampler / restorer
- *Vace image2image* comes out of the box with image outpainting, person / object replacement, ...
- In one click you can reuse a newly generated Image as the Start Image or a Reference Image for a Video generation
And to complete the full suite of AI Image Generators, Ladies and Gentlemen please welcome for the first time in WanGP: **Flux Kontext**.\
As a reminder, Flux Kontext is an image editor: give it an image and a prompt and it will make the change for you.\
This highly optimized version of Flux Kontext will make you feel you have been cheated all this time, as WanGP Flux Kontext requires only 8 GB of VRAM to generate 4 images at the same time with no need for quantization.
WanGP v7 comes with *Image2image* vanilla and *Vace FusioniX*. However you can build your own finetune where you combine a text2video or Vace model with any combination of Loras.
Also in the news:
- You can now enter the *Bbox* for each speaker in *Multitalk* to precisely locate who is speaking. And to save some headaches the *Image Mask generator* will give you the *Bbox* coordinates of an area you have selected (see the sketch after this list).
- *Film Grain* post processing to add a vintage look to your video
- *First Last Frame to Video* model should work much better now, as I recently discovered its implementation was not complete
- More power for the finetuners: you can now embed Loras directly in the finetune definition. You can also override the default models (titles, visibility, ...) with your own finetunes. Check the updated documentation.
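For reference, here is a minimal sketch (not actual WanGP code, the helper name is hypothetical) of the *Bbox* string format produced by the Image Mask generator and expected by the *Multitalk* speaker fields, as implemented in the matting tool further down in this commit: `Left:Top:Right:Bottom`, each value being an integer percentage of the image width or height.
```python
def bbox_to_multitalk_string(xmin, ymin, xmax, ymax, width, height):
    # Convert a pixel bounding box into the "Left:Top:Right:Bottom" string,
    # with each value expressed as an integer percentage of the image size.
    values = [xmin / width, ymin / height, xmax / width, ymax / height]
    return ":".join(str(int(v * 100)) for v in values)

# A speaker occupying the left half of a 1280x720 frame -> "0:0:50:100"
print(bbox_to_multitalk_string(0, 0, 640, 720, 1280, 720))
```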
### July 10 2025: WanGP v6.7, is NAG a game changer ? you tell me
Maybe you knew that already but most *Loras accelerators* we use today (Causvid, FusioniX) don't use *Guidance* at all (that is, *CFG* is set to 1). This helps to get much faster generations but the downside is that *Negative Prompts* are completely ignored (including the default ones set by the models). **NAG** (https://github.com/ChenDarYen/Normalized-Attention-Guidance) aims to solve that by injecting the *Negative Prompt* during the *attention* processing phase.

View File

@ -2,18 +2,15 @@
"model": { "model": {
"name": "Flux Dev Kontext 12B", "name": "Flux Dev Kontext 12B",
"architecture": "flux_dev_kontext", "architecture": "flux_dev_kontext",
"description": "FLUX.1 Kontext is a 12 billion parameter rectified flow transformer capable of editing images based on instructions stored in the Prompt. Please be aware that the output resolution is modified by Flux Kontext and may not be what you requested.", "description": "FLUX.1 Kontext is a 12 billion parameter rectified flow transformer capable of editing images based on instructions stored in the Prompt. Please be aware that Flux Kontext is picky on the resolution of the input image the output dimensions may not match the dimensions of the input image.",
"URLs": [ "URLs": [
"c:/temp/kontext/flux1_kontext_dev_bf16.safetensors",
"c:/temp/kontext/flux1_kontext_dev_quanto_bf16_int8.safetensors"
],
"URLs2": [
"https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors" "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors"
] ]
}, },
"prompt": "add a hat",
"resolution": "1280x720", "resolution": "1280x720",
"video_length": "1" "video_length": 1
} }

13
defaults/t2i.json Normal file
View File

@ -0,0 +1,13 @@
{
"model": {
"name": "Wan2.1 text2image 14B",
"architecture": "t2v",
"description": "The original Wan Text 2 Video model configured to generate an image instead of a video.",
"image_outputs": true,
"URLs": "t2v"
},
"video_length": 1,
"resolution": "1280x720"
}

View File

@ -15,7 +15,7 @@
"seed": -1, "seed": -1,
"num_inference_steps": 10, "num_inference_steps": 10,
"guidance_scale": 1, "guidance_scale": 1,
"flow_shift": 5, "flow_shift": 2,
"embedded_guidance_scale": 6, "embedded_guidance_scale": 6,
"repeat_generation": 1, "repeat_generation": 1,
"multi_images_gen_type": 0, "multi_images_gen_type": 0,

View File

@ -0,0 +1,16 @@
{
"model": {
"name": "Vace FusioniX image2image 14B",
"architecture": "vace_14B",
"modules": [
"vace_14B"
],
"image_outputs": true,
"description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
"URLs": "t2v_fusionix"
},
"resolution": "1280x720",
"guidance_scale": 1,
"num_inference_steps": 10,
"video_length": 1
}

View File

@ -2,22 +2,30 @@
A finetuned model is a model that shares the architecture of one specific base model but has weights derived from it. Some finetuned models have been created by combining multiple finetuned models.
As there are potentially an infinite number of finetunes, specific finetuned models are not known by default by WanGP. However you can create a finetuned model definition that will tell WanGP about the existence of this finetuned model and WanGP will do as usual all the work for you: autodownload the model and build the user interface.
The WanGP finetune system can also be used to tweak the default models: for instance you can add on top of an existing model some Loras that will always be applied transparently.
Finetune model definitions are lightweight json files that can be easily shared. You can find some of them on the WanGP *discord* server https://discord.gg/g7efUW9jGV
Finetuned models have been tested so far with Wan2.1 text2video, Wan2.1 image2video and Hunyuan Video text2video. There isn't currently any support for LTX Video finetunes.
## Create a new Finetune Model Definition
All the finetune model definitions are json files stored in the **finetunes/** subfolder. The corresponding finetune model weights, once downloaded, are stored in the *ckpts/* subfolder next to the base models.
All the models used by WanGP are also described using the finetunes json format and can be found in the **defaults/** subfolder. Please don't modify any file in the **defaults/** folder.
However you can use these files as starting points for new definition files and to get an idea of the structure of a definition file. If you want to change how a base model is handled (title, default settings, path to model weights, …) you may override any property of its default definition file by creating a file with the same name in the **finetunes/** folder. Everything happens as if the two definitions were merged property by property, with a higher priority given to the finetune definition.
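Below is a minimal sketch (not actual WanGP code) of the property-by-property merge described above; the file name and override keys are assumptions used for illustration only.
```python
import json

def merge_definitions(default_def: dict, finetune_def: dict) -> dict:
    # Properties set in the finetune definition win; everything else is inherited.
    merged = dict(default_def)
    for key, value in finetune_def.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_definitions(merged[key], value)  # e.g. the "model" subtree
        else:
            merged[key] = value
    return merged

# Hypothetical override: rename the default t2v model without touching anything else
with open("defaults/t2v.json") as f:
    default_def = json.load(f)
override = {"model": {"name": "Wan2.1 text2video 14B (my tweak)"}}
print(merge_definitions(default_def, override)["model"]["name"])
```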
A definition is built from a *settings file* that can contain all the default parameters for a video generation. On top of this file a subtree named **model** contains all the information regarding the finetune (URLs to download the model, corresponding base model id, ...).
You can obtain a settings file in several ways:
- In the subfolder **settings**, get the json file that corresponds to the base model of your finetune (see the next section for the list of ids of base models)
- From the user interface, select the base model for which you want to create a finetune and click **export settings**
Here are the steps:
1) Create a *settings file*
@ -26,22 +34,37 @@ Here are steps:
4) Restart WanGP
## Architecture Models Ids
A finetune is derived from a base model and will inherit all the user interface and corresponding model capabilities. Here are some Architecture Ids:
- *t2v*: Wan 2.1 Video text 2 video
- *i2v*: Wan 2.1 Video image 2 video 480p and 720p
- *vace_14B*: Wan 2.1 Vace 14B
- *hunyuan*: Hunyuan Video text 2 video
- *hunyuan_i2v*: Hunyuan Video image 2 video
Any file name in the **defaults/** subfolder (without the json extension) corresponds to an architecture id.
Please note that the weights of some architectures are a combination of the weights of another architecture completed by the weights of one or more modules.
A module is a set of weights that is not sufficient to be a model by itself but that can be added to an existing model to extend its capabilities.
For instance if one adds the module *vace_14B* on top of a model with the *t2v* architecture one gets a model with the *vace_14B* architecture. Here *vace_14B* stands for both an architecture name and a module name. The module system allows you to reuse weights shared between models.
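As a concrete illustration, mirroring the *defaults/vace_14B_fusionix_i2i.json* file added in this commit, the snippet below prints a definition whose *vace_14B* module is combined with weights reused from another definition (the individual properties are described in the next section).
```python
import json

model_subtree = {
    "name": "Vace FusioniX image2image 14B",
    "architecture": "vace_14B",   # architecture id of the resulting model
    "modules": ["vace_14B"],      # module weights merged on top of the base weights
    "URLs": "t2v_fusionix",       # reuse the URLs already defined by the t2v_fusionix definition
    "image_outputs": True,
    "description": "Vace control model enhanced with multiple open-source components and LoRAs.",
}
print(json.dumps({"model": model_subtree}, indent=4))
```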
## The Model Subtree
- *name* : name of the finetune, used to select it in the user interface
- *architecture* : architecture Id of the base model of the finetune (see previous section)
- *description*: description of the finetune that will appear at the top
- *URLs*: URLs of all the finetune versions (quantized / non quantized). WanGP will pick the version that is the closest to the user preferences. You will need to follow a naming convention to help WanGP identify the content of each version (see next section). Right now WanGP supports only 8-bit quantized models that have been quantized using **quanto**. WanGP offers a command switch to easily build such a quantized model (see below). *URLs* can also contain paths to local files to allow testing.
- *modules*: this is a list of modules to be combined with the models referenced by the URLs. A module is a model extension that is merged with a model to expand its capabilities. Supported modules so far are *vace_14B* and *multitalk*. For instance the full Vace model is the fusion of a Wan text 2 video model and the Vace module.
- *preload_URLs* : URLs of files to download no matter what (used to load quantization maps for instance)
- *loras* : URLs of Loras that will be applied before any other Lora specified by the user. These will quite often be Loras accelerators. For instance if you specify here the FusioniX Lora you will be able to reduce the number of generation steps to 8-10
- *loras_multipliers* : a list of float numbers that defines the weight of each Lora mentioned above
- *auto_quantize*: if set to True and no quantized model URL is provided, WanGP will perform on the fly quantization if the user expects a quantized model
- *visible* : assumed to be true by default. If set to false the model will no longer be visible. This can be useful if you create a finetune to override a default model and want to hide the original.
- *image_outputs* : turns any model that generates a video into a model that generates images. In practice it adapts the user interface for image generation and asks the model to generate a video with a single frame.
In order to favor reusability, the properties *URLs*, *modules*, *loras* and *preload_URLs* can contain, instead of a list of URLs, a single piece of text which corresponds to the id of a finetune or default model to reuse.
For example let's say you have defined a *t2v_fusionix.json* file which contains the URLs to download the finetune. In *vace_fusionix.json* you can write `"URLs": "t2v_fusionix"` to automatically reuse the URLs already defined in the corresponding file.
Example of **model** subtree
```

View File

@ -6,18 +6,19 @@ Loras (Low-Rank Adaptations) allow you to customize video generation models by a
Loras are organized in different folders based on the model they're designed for:
### Wan Text-to-Video Models
- `loras/` - General t2v loras
- `loras/1.3B/` - Loras specifically for 1.3B models
- `loras/14B/` - Loras specifically for 14B models
### Wan Image-to-Video Models
- `loras_i2v/` - Image-to-video loras
### Other Models
- `loras_hunyuan/` - Hunyuan Video t2v loras
- `loras_hunyuan_i2v/` - Hunyuan Video i2v loras
- `loras_ltxv/` - LTX Video loras
- `loras_flux/` - Flux loras
## Custom Lora Directory
@ -64,7 +65,7 @@ For dynamic effects over generation steps, use comma-separated values:
## Lora Presets
Lora Presets are combinations of loras with predefined multipliers and prompts.
### Creating Presets
1. Configure your loras and multipliers
@ -95,16 +96,36 @@ WanGP supports multiple lora formats:
- **Replicate** format
- **Standard PyTorch** (.pt, .pth)
## Loras Accelerators
Most Loras are used to apply a specific style or to alter the content of the generated video.
However some Loras have been designed to transform a model into a distilled model which requires fewer steps to generate a video.
You will find most *Loras Accelerators* here:
https://huggingface.co/DeepBeepMeep/Wan2.1/tree/main/loras_accelerators
### Setup Instructions
1. Download the Lora
2. Place it in your `loras/` directory if it is a t2v lora or in the `loras_i2v/` directory if it is an i2v lora
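As an optional alternative to a manual download, here is a hedged Python sketch of steps 1 and 2; the repository comes from the link above, the file name from the Safe-Forcing section below, and the target folder follows the placement rule of step 2.
```python
import os
import requests

url = ("https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/"
       "Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors")
os.makedirs("loras", exist_ok=True)
# Stream the safetensors file into the t2v lora folder
with requests.get(url, stream=True) as r:
    r.raise_for_status()
    with open(os.path.join("loras", os.path.basename(url)), "wb") as f:
        for chunk in r.iter_content(chunk_size=1 << 20):
            f.write(chunk)
```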
## FusioniX (or FusionX) Lora
If you need just one Lora accelerator use this one. It is a combination of multiple Loras accelerators (including CausVid below) and style loras. It will not only accelerate the video generation but also improve the quality. There are two versions of this Lora depending on whether you use it with t2v or i2v models.
### Usage
1. Select a Wan t2v model (e.g., Wan 2.1 text2video 13B or Vace 13B)
2. Enable Advanced Mode
3. In Advanced Generation Tab:
- Set Guidance Scale = 1
- Set Shift Scale = 2
4. In Advanced Lora Tab:
- Select the FusioniX Lora
- Set multiplier to 1
5. Set generation steps from 8-10
6. Generate!
## Safe-Forcing lightx2v Lora (Video Generation Accelerator)
The Safe-Forcing Lora has been created by Kijai from the Safe-Forcing lightx2v distilled Wan model. It can generate videos with only 2 steps and also offers a 2x speed improvement since it doesn't require classifier free guidance. It works on both t2v and i2v models.
You will find it under the name of *Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors*
### Usage
1. Select a Wan t2v or i2v model (e.g., Wan 2.1 text2video 13B or Vace 13B)
@ -118,17 +139,10 @@ Safeforcing Lora has been created by Kijai from the Safe-Forcing lightx2v distil
5. Set generation steps to 2-8
6. Generate!
## CausVid Lora (Video Generation Accelerator)
CausVid is a distilled Wan model that generates videos in 4-12 steps with 2x speed improvement.
### Setup Instructions
1. Download the CausVid Lora:
```
https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors
```
2. Place in your `loras/` directory
### Usage
1. Select a Wan t2v model (e.g., Wan 2.1 text2video 13B or Vace 13B)
2. Enable Advanced Mode
@ -149,25 +163,10 @@ CausVid is a distilled Wan model that generates videos in 4-12 steps with 2x spe
*Note: Lower steps = lower quality (especially motion)*
## AccVid Lora (Video Generation Accelerator)
AccVid is a distilled Wan model that generates videos with a 2x speed improvement since classifier free guidance is no longer needed (that is, cfg = 1).
### Setup Instructions
1. Download the AccVid Lora:
- for t2v models:
```
https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors
```
- for i2v models:
```
https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_AccVid_I2V_480P_14B_lora_rank32_fp16.safetensors
```
2. Place in your `loras/` directory or `loras_i2v/` directory
### Usage
1. Select a Wan t2v model (e.g., Wan 2.1 text2video 13B or Vace 13B) or Wan i2v model
@ -268,6 +267,7 @@ In the video, a man is presented. The man is in a city and looks at his watch.
--lora-dir-hunyuan path # Path to Hunyuan t2v loras
--lora-dir-hunyuan-i2v path # Path to Hunyuan i2v loras
--lora-dir-ltxv path # Path to LTX Video loras
--lora-dir-flux path # Path to Flux loras
--lora-preset preset # Load preset on startup
--check-loras # Filter incompatible loras
```

View File

@ -2,6 +2,8 @@
WanGP supports multiple video generation models, each optimized for different use cases and hardware configurations.
Most models can be combined with Loras Accelerators (check the Loras guide) to speed up video generation by 2x or 3x with little quality loss.
## Wan 2.1 Text2Video Models
Please note that the term *Text2Video* refers to the underlying Wan architecture; as it has been greatly improved over time, many derived Text2Video models can now generate videos using images.
@ -65,6 +67,12 @@ Please note that that the term *Text2Video* refers to the underlying Wan archite
## Wan 2.1 Specialized Models
#### Multitalk
- **Type**: Multi Talking head animation
- **Input**: Voice track + image
- **Works on**: People
- **Use case**: Lip-sync and voice-driven animation for up to two people
#### FantasySpeaking
- **Type**: Talking head animation
- **Input**: Voice track + image
@ -82,7 +90,7 @@ Please note that that the term *Text2Video* refers to the underlying Wan archite
- **Requirements**: 81+ frame input videos, 15+ denoising steps
- **Use case**: View same scene from different angles
#### Sky Reels v2 Diffusion
- **Type**: Diffusion Forcing model
- **Specialty**: "Infinite length" videos
- **Features**: High quality continuous generation
@ -107,22 +115,6 @@ Please note that that the term *Text2Video* refers to the underlying Wan archite
<BR>
## Wan Special Loras
### Safe-Forcing lightx2v Lora
- **Type**: Distilled model (Lora implementation)
- **Speed**: 4-8 steps generation, 2x faster (no classifier free guidance)
- **Compatible**: Works with t2v and i2v Wan 14B models
- **Setup**: Requires Safe-Forcing lightx2v Lora (see [LORAS.md](LORAS.md))
### Causvid Lora
- **Type**: Distilled model (Lora implementation)
- **Speed**: 4-12 steps generation, 2x faster (no classifier free guidance)
- **Compatible**: Works with Wan 14B models
- **Setup**: Requires CausVid Lora (see [LORAS.md](LORAS.md))
<BR>
## Hunyuan Video Models

View File

View File

@ -65,7 +65,7 @@ class model_factory:
fit_into_canvas = None, fit_into_canvas = None,
callback = None, callback = None,
loras_slists = None, loras_slists = None,
frame_num = 1, batch_size = 1,
**bbargs **bbargs
): ):
@ -89,7 +89,7 @@ class model_factory:
img_cond=image_ref, img_cond=image_ref,
target_width=width, target_width=width,
target_height=height, target_height=height,
bs=frame_num, bs=batch_size,
seed=seed, seed=seed,
device="cuda", device="cuda",
) )

View File

@ -0,0 +1,21 @@
# Thanks to https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/film_grain.py
import torch
def add_film_grain(images: torch.Tensor, grain_intensity: float = 0, saturation: float = 0.5):
device = images.device
images = images.permute(1, 2 ,3 ,0)
images.add_(1.).div_(2.)
grain = torch.randn_like(images, device=device)
grain[:, :, :, 0] *= 2
grain[:, :, :, 2] *= 3
grain = grain * saturation + grain[:, :, :, 1].unsqueeze(3).repeat(
1, 1, 1, 3
) * (1 - saturation)
# Blend the grain with the image
noised_images = images + grain_intensity * grain
noised_images.clamp_(0, 1)
noised_images.sub_(.5).mul_(2.)
noised_images = noised_images.permute(3, 0, 1 ,2)
return noised_images
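Not part of the commit: a minimal usage sketch, assuming the (C, T, H, W) layout and [-1, 1] value range implied by the permutes and the `add_(1.).div_(2.)` rescaling above. Note that the in-place ops also rescale the input tensor to [0, 1] as a side effect.
```python
import torch

# Fake clip: 16 frames of 480x832 RGB in [-1, 1]
video = torch.rand(3, 16, 480, 832) * 2 - 1

# Light vintage grain; the intensity value here is just an illustrative guess
grained = add_film_grain(video, grain_intensity=0.05, saturation=0.5)
print(grained.shape)  # torch.Size([3, 16, 480, 832]), values still in [-1, 1]
```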

View File

@ -65,6 +65,7 @@ def get_frames_from_image(image_input, image_state):
Return Return
[[0:nearest_frame], [nearest_frame:], nearest_frame] [[0:nearest_frame], [nearest_frame:], nearest_frame]
""" """
load_sam()
user_name = time.time() user_name = time.time()
frames = [image_input] * 2 # hardcode: mimic a video with 2 frames frames = [image_input] * 2 # hardcode: mimic a video with 2 frames
@ -89,7 +90,7 @@ def get_frames_from_image(image_input, image_state):
gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True),\ gr.update(visible=True), gr.update(visible=True),\
gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=False), \ gr.update(visible=True), gr.update(value="", visible=True), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=True), \ gr.update(visible=False), gr.update(visible=True), \
gr.update(visible=True) gr.update(visible=True)
@ -103,6 +104,8 @@ def get_frames_from_video(video_input, video_state):
[[0:nearest_frame], [nearest_frame:], nearest_frame] [[0:nearest_frame], [nearest_frame:], nearest_frame]
""" """
load_sam()
while model == None: while model == None:
time.sleep(1) time.sleep(1)
@ -273,6 +276,20 @@ def save_video(frames, output_path, fps):
return output_path return output_path
def mask_to_xyxy_box(mask):
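# returns the tight [xmin, ymin, xmax, ymax] pixel bounding box of the white (255) area of a binary mask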
rows, cols = np.where(mask == 255)
xmin = min(cols)
xmax = max(cols) + 1
ymin = min(rows)
ymax = max(rows) + 1
xmin = max(xmin, 0)
ymin = max(ymin, 0)
xmax = min(xmax, mask.shape[1])
ymax = min(ymax, mask.shape[0])
box = [xmin, ymin, xmax, ymax]
box = [int(x) for x in box]
return box
# image matting # image matting
def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, refine_iter): def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, refine_iter):
matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg) matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
@ -320,9 +337,17 @@ def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_si
foreground = output_frames foreground = output_frames
foreground_output = Image.fromarray(foreground[-1]) foreground_output = Image.fromarray(foreground[-1])
alpha_output = Image.fromarray(alpha[-1][:,:,0]) alpha_output = alpha[-1][:,:,0]
frame_temp = alpha_output.copy()
return foreground_output, gr.update(visible=True) alpha_output[frame_temp > 127] = 0
alpha_output[frame_temp <= 127] = 255
bbox_info = mask_to_xyxy_box(alpha_output)
h = alpha_output.shape[0]
w = alpha_output.shape[1]
bbox_info = [str(int(bbox_info[0]/ w * 100 )), str(int(bbox_info[1]/ h * 100 )), str(int(bbox_info[2]/ w * 100 )), str(int(bbox_info[3]/ h * 100 )) ]
bbox_info = ":".join(bbox_info)
alpha_output = Image.fromarray(alpha_output)
return foreground_output, alpha_output, bbox_info, gr.update(visible=True), gr.update(visible=True)
# video matting # video matting
def video_matting(video_state, end_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size): def video_matting(video_state, end_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size):
@ -469,6 +494,13 @@ def restart():
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False, choices=[], value=[]), "", gr.update(visible=False) gr.update(visible=False), gr.update(visible=False, choices=[], value=[]), "", gr.update(visible=False)
def load_sam():
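# move the already instantiated SAM controller and MatAnyone models onto the target device before they are used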
global model_loaded
global model
global matanyone_model
model.samcontroler.sam_controler.model.to(arg_device)
matanyone_model.to(arg_device)
def load_unload_models(selected): def load_unload_models(selected):
global model_loaded global model_loaded
global model global model
@ -476,8 +508,7 @@ def load_unload_models(selected):
if selected: if selected:
# print("Matanyone Tab Selected") # print("Matanyone Tab Selected")
if model_loaded: if model_loaded:
model.samcontroler.sam_controler.model.to(arg_device) load_sam()
matanyone_model.to(arg_device)
else: else:
# args, defined in track_anything.py # args, defined in track_anything.py
sam_checkpoint_url_dict = { sam_checkpoint_url_dict = {
@ -522,12 +553,16 @@ def export_to_vace_video_input(foreground_video_output):
def export_image(image_refs, image_output): def export_image(image_refs, image_output):
gr.Info("Masked Image transferred to Current Video") gr.Info("Masked Image transferred to Current Video")
# return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
if image_refs == None: if image_refs == None:
image_refs =[] image_refs =[]
image_refs.append( image_output) image_refs.append( image_output)
return image_refs return image_refs
def export_image_mask(image_input, image_mask):
gr.Info("Input Image & Mask transferred to Current Video")
return Image.fromarray(image_input), image_mask
def export_to_current_video_engine(model_type, foreground_video_output, alpha_video_output): def export_to_current_video_engine(model_type, foreground_video_output, alpha_video_output):
gr.Info("Original Video and Full Mask have been transferred") gr.Info("Original Video and Full Mask have been transferred")
# return "MV#" + str(time.time()), foreground_video_output, alpha_video_output # return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
@ -543,7 +578,7 @@ def teleport_to_video_tab(tab_state):
return gr.Tabs(selected="video_gen") return gr.Tabs(selected="video_gen")
def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, vace_image_refs): def display(tabs, tab_state, model_choice, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs):
# my_tab.select(fn=load_unload_models, inputs=[], outputs=[]) # my_tab.select(fn=load_unload_models, inputs=[], outputs=[])
media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/" media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/"
@ -677,7 +712,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
foreground_output_button = gr.Button(value="Black & White Video Output", visible=False, elem_classes="new_button") foreground_output_button = gr.Button(value="Black & White Video Output", visible=False, elem_classes="new_button")
with gr.Column(scale=2): with gr.Column(scale=2):
alpha_video_output = gr.Video(label="B & W Mask Video Output", visible=False, elem_classes="video") alpha_video_output = gr.Video(label="B & W Mask Video Output", visible=False, elem_classes="video")
alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button") export_image_mask_btn = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
with gr.Row(): with gr.Row():
with gr.Row(visible= False): with gr.Row(visible= False):
export_to_vace_video_14B_btn = gr.Button("Export to current Video Input Video For Inpainting", visible= False) export_to_vace_video_14B_btn = gr.Button("Export to current Video Input Video For Inpainting", visible= False)
@ -696,7 +731,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
], ],
outputs=[video_state, video_info, template_frame, outputs=[video_state, video_info, template_frame,
image_selection_slider, end_selection_slider, track_pause_number_slider, point_prompt, matting_type, clear_button_click, add_mask_button, matting_button, template_frame, image_selection_slider, end_selection_slider, track_pause_number_slider, point_prompt, matting_type, clear_button_click, add_mask_button, matting_button, template_frame,
foreground_video_output, alpha_video_output, foreground_output_button, alpha_output_button, mask_dropdown, step2_title] foreground_video_output, alpha_video_output, foreground_output_button, export_image_mask_btn, mask_dropdown, step2_title]
) )
# second step: select images from slider # second step: select images from slider
@ -755,7 +790,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
foreground_video_output, alpha_video_output, foreground_video_output, alpha_video_output,
template_frame, template_frame,
image_selection_slider, end_selection_slider, track_pause_number_slider,point_prompt, export_to_vace_video_14B_btn, export_to_current_video_engine_btn, matting_type, clear_button_click, image_selection_slider, end_selection_slider, track_pause_number_slider,point_prompt, export_to_vace_video_14B_btn, export_to_current_video_engine_btn, matting_type, clear_button_click,
add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, export_image_mask_btn, mask_dropdown, video_info, step2_title
], ],
queue=False, queue=False,
show_progress=False) show_progress=False)
@ -770,7 +805,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
foreground_video_output, alpha_video_output, foreground_video_output, alpha_video_output,
template_frame, template_frame,
image_selection_slider , end_selection_slider, track_pause_number_slider,point_prompt, export_to_vace_video_14B_btn, export_to_current_video_engine_btn, matting_type, clear_button_click, image_selection_slider , end_selection_slider, track_pause_number_slider,point_prompt, export_to_vace_video_14B_btn, export_to_current_video_engine_btn, matting_type, clear_button_click,
add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, export_image_mask_btn, mask_dropdown, video_info, step2_title
], ],
queue=False, queue=False,
show_progress=False) show_progress=False)
@ -872,15 +907,19 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
# output image # output image
with gr.Row(equal_height=True): with gr.Row(equal_height=True):
foreground_image_output = gr.Image(type="pil", label="Foreground Output", visible=False, elem_classes="image") foreground_image_output = gr.Image(type="pil", label="Foreground Output", visible=False, elem_classes="image")
alpha_image_output = gr.Image(type="pil", label="Mask", visible=False, elem_classes="image")
with gr.Row(equal_height=True):
bbox_info = gr.Text(label ="Mask BBox Info (Left:Top:Right:Bottom)", interactive= False)
with gr.Row(): with gr.Row():
with gr.Row(): # with gr.Row():
export_image_btn = gr.Button(value="Add to current Reference Images", visible=False, elem_classes="new_button") export_image_btn = gr.Button(value="Add to current Reference Images", visible=False, elem_classes="new_button")
with gr.Column(scale=2, visible= False): # with gr.Column(scale=2, visible= True):
alpha_image_output = gr.Image(type="pil", label="Alpha Output", visible=False, elem_classes="image") export_image_mask_btn = gr.Button(value="Set to Control Image & Mask", visible=False, elem_classes="new_button")
alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
export_image_btn.click( fn=export_image, inputs= [vace_image_refs, foreground_image_output], outputs= [vace_image_refs]).then( #video_prompt_video_guide_trigger, export_image_btn.click( fn=export_image, inputs= [vace_image_refs, foreground_image_output], outputs= [vace_image_refs]).then( #video_prompt_video_guide_trigger,
fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs]) fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])
export_image_mask_btn.click( fn=export_image_mask, inputs= [image_input, alpha_image_output], outputs= [vace_image_input, vace_image_mask]).then( #video_prompt_video_guide_trigger,
fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])
# first step: get the image information # first step: get the image information
extract_frames_button.click( extract_frames_button.click(
@ -890,9 +929,17 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
], ],
outputs=[image_state, image_info, template_frame, outputs=[image_state, image_info, template_frame,
image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, add_mask_button, matting_button, template_frame, image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, add_mask_button, matting_button, template_frame,
foreground_image_output, alpha_image_output, export_image_btn, alpha_output_button, mask_dropdown, step2_title] foreground_image_output, alpha_image_output, bbox_info, export_image_btn, export_image_mask_btn, mask_dropdown, step2_title]
) )
# points clear
clear_button_click.click(
fn = clear_click,
inputs = [image_state, click_state,],
outputs = [template_frame,click_state],
)
# second step: select images from slider # second step: select images from slider
image_selection_slider.release(fn=select_image_template, image_selection_slider.release(fn=select_image_template,
inputs=[image_selection_slider, image_state, interactive_state], inputs=[image_selection_slider, image_state, interactive_state],
@ -925,7 +972,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
matting_button.click( matting_button.click(
fn=image_matting, fn=image_matting,
inputs=[image_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, image_selection_slider], inputs=[image_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, image_selection_slider],
outputs=[foreground_image_output, export_image_btn] outputs=[foreground_image_output, alpha_image_output,bbox_info, export_image_btn, export_image_mask_btn]
) )

View File

@ -61,6 +61,7 @@ class WanAny2V:
checkpoint_dir, checkpoint_dir,
model_filename = None, model_filename = None,
model_type = None, model_type = None,
model_def = None,
base_model_type = None, base_model_type = None,
text_encoder_filename = None, text_encoder_filename = None,
quantizeTransformer = False, quantizeTransformer = False,
@ -75,7 +76,8 @@ class WanAny2V:
self.dtype = dtype self.dtype = dtype
self.num_train_timesteps = config.num_train_timesteps self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype self.param_dtype = config.param_dtype
self.model_def = model_def
self.image_outputs = model_def.get("image_outputs", False)
self.text_encoder = T5EncoderModel( self.text_encoder = T5EncoderModel(
text_len=config.text_len, text_len=config.text_len,
dtype=config.t5_dtype, dtype=config.t5_dtype,
@ -106,7 +108,7 @@ class WanAny2V:
# config = json.load(f) # config = json.load(f)
# from mmgp import safetensors2 # from mmgp import safetensors2
# sd = safetensors2.torch_load_file(xmodel_filename) # sd = safetensors2.torch_load_file(xmodel_filename)
# model_filename = "c:/temp/fflf/diffusion_pytorch_model-00001-of-00007.safetensors" # model_filename = "c:/temp/flf/diffusion_pytorch_model-00001-of-00007.safetensors"
base_config_file = f"configs/{base_model_type}.json" base_config_file = f"configs/{base_model_type}.json"
forcedConfigPath = base_config_file if len(model_filename) > 1 else None forcedConfigPath = base_config_file if len(model_filename) > 1 else None
# forcedConfigPath = base_config_file = f"configs/flf2v_720p.json" # forcedConfigPath = base_config_file = f"configs/flf2v_720p.json"
@ -208,7 +210,7 @@ class WanAny2V:
if refs is not None: if refs is not None:
length = len(refs) length = len(refs)
mask_pad = torch.zeros_like(mask[:, :length, :, :]) mask_pad = torch.zeros(mask.shape[0], length, *mask.shape[-2:], dtype=mask.dtype, device=mask.device)
mask = torch.cat((mask_pad, mask), dim=1) mask = torch.cat((mask_pad, mask), dim=1)
result_masks.append(mask) result_masks.append(mask)
return result_masks return result_masks
@ -327,20 +329,6 @@ class WanAny2V:
self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # deplicate background mask with double control net since first controlnet image ref modifed by ref self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # deplicate background mask with double control net since first controlnet image ref modifed by ref
return src_video, src_mask, src_ref_images return src_video, src_mask, src_ref_images
def decode_latent(self, zs, ref_images=None, tile_size= 0 ):
if ref_images is None:
ref_images = [None] * len(zs)
# else:
# assert len(zs) == len(ref_images)
trimed_zs = []
for z, refs in zip(zs, ref_images):
if refs is not None:
z = z[:, len(refs):, :, :]
trimed_zs.append(z)
return self.vae.decode(trimed_zs, tile_size= tile_size)
def get_vae_latents(self, ref_images, device, tile_size= 0): def get_vae_latents(self, ref_images, device, tile_size= 0):
ref_vae_latents = [] ref_vae_latents = []
for ref_image in ref_images: for ref_image in ref_images:
@ -366,6 +354,7 @@ class WanAny2V:
height = 720, height = 720,
fit_into_canvas = True, fit_into_canvas = True,
frame_num=81, frame_num=81,
batch_size = 1,
shift=5.0, shift=5.0,
sample_solver='unipc', sample_solver='unipc',
sampling_steps=50, sampling_steps=50,
@ -397,6 +386,7 @@ class WanAny2V:
NAG_alpha = 0.5, NAG_alpha = 0.5,
offloadobj = None, offloadobj = None,
apg_switch = False, apg_switch = False,
speakers_bboxes = None,
**bbargs **bbargs
): ):
@ -554,8 +544,8 @@ class WanAny2V:
overlapped_latents_frames_num = int(1 + (preframes_count-1) // 4) overlapped_latents_frames_num = int(1 + (preframes_count-1) // 4)
if overlapped_latents != None: if overlapped_latents != None:
# disabled because looks worse # disabled because looks worse
if False and overlapped_latents_frames_num > 1: lat_y[:, 1:overlapped_latents_frames_num] = overlapped_latents[:, 1:] if False and overlapped_latents_frames_num > 1: lat_y[:, :, 1:overlapped_latents_frames_num] = overlapped_latents[:, 1:]
extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone() extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone().unsqueeze(0)
y = torch.concat([msk, lat_y]) y = torch.concat([msk, lat_y])
lat_y = None lat_y = None
kwargs.update({'clip_fea': clip_context, 'y': y}) kwargs.update({'clip_fea': clip_context, 'y': y})
@ -586,7 +576,7 @@ class WanAny2V:
overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1 overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1
else: else:
overlapped_latents_frames_num = overlapped_frames_num = 0 overlapped_latents_frames_num = overlapped_frames_num = 0
if len(keep_frames_parsed) == 0 or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = [] if len(keep_frames_parsed) == 0 or self.image_outputs or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = []
injection_denoising_step = int(sampling_steps * (1. - denoising_strength) ) injection_denoising_step = int(sampling_steps * (1. - denoising_strength) )
latent_keep_frames = [] latent_keep_frames = []
if source_latents.shape[1] < lat_frames or len(keep_frames_parsed) > 0: if source_latents.shape[1] < lat_frames or len(keep_frames_parsed) > 0:
@ -609,6 +599,7 @@ class WanAny2V:
input_ref_images = self.get_vae_latents(input_ref_images, self.device) input_ref_images = self.get_vae_latents(input_ref_images, self.device)
input_ref_images_neg = torch.zeros_like(input_ref_images) input_ref_images_neg = torch.zeros_like(input_ref_images)
ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0 ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0
trim_frames = input_ref_images.shape[1]
# Vace # Vace
if vace : if vace :
@ -633,8 +624,8 @@ class WanAny2V:
context_scale = context_scale if context_scale != None else [1.0] * len(z) context_scale = context_scale if context_scale != None else [1.0] * len(z)
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count }) kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count })
if overlapped_latents != None : if overlapped_latents != None :
overlapped_latents_size = overlapped_latents.shape[1] overlapped_latents_size = overlapped_latents.shape[2]
extended_overlapped_latents = z[0][0:16, 0:overlapped_latents_size + ref_images_count].clone() extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0)
target_shape = list(z0[0].shape) target_shape = list(z0[0].shape)
target_shape[0] = int(target_shape[0] / 2) target_shape[0] = int(target_shape[0] / 2)
@ -649,7 +640,7 @@ class WanAny2V:
from wan.multitalk.multitalk import get_target_masks from wan.multitalk.multitalk import get_target_masks
audio_proj = [audio.to(self.dtype) for audio in audio_proj] audio_proj = [audio.to(self.dtype) for audio in audio_proj]
human_no = len(audio_proj[0]) human_no = len(audio_proj[0])
token_ref_target_masks = get_target_masks(human_no, lat_h, lat_w, height, width, face_scale = 0.05, bbox = None).to(self.dtype) if human_no > 1 else None token_ref_target_masks = get_target_masks(human_no, lat_h, lat_w, height, width, face_scale = 0.05, bbox = speakers_bboxes).to(self.dtype) if human_no > 1 else None
if fantasy and audio_proj != None: if fantasy and audio_proj != None:
kwargs.update({ "audio_proj": audio_proj.to(self.dtype), "audio_context_lens": audio_context_lens, }) kwargs.update({ "audio_proj": audio_proj.to(self.dtype), "audio_context_lens": audio_context_lens, })
@ -658,8 +649,8 @@ class WanAny2V:
if self._interrupt: if self._interrupt:
return None return None
expand_shape = [batch_size] + [-1] * len(target_shape)
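# -1 leaves each latent dimension unchanged: only the batch dimension is expanded when per-sample tensors are broadcast with .expand() below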
# Ropes # Ropes
batch_size = 1
if target_camera != None: if target_camera != None:
shape = list(target_shape[1:]) shape = list(target_shape[1:])
shape[0] *= 2 shape[0] *= 2
@ -698,14 +689,14 @@ class WanAny2V:
if sample_scheduler != None: if sample_scheduler != None:
scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g} scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g}
# b, c, lat_f, lat_h, lat_w
latents = torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) latents = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
if apg_switch != 0: if apg_switch != 0:
apg_momentum = -0.75 apg_momentum = -0.75
apg_norm_threshold = 55 apg_norm_threshold = 55
text_momentumbuffer = MomentumBuffer(apg_momentum) text_momentumbuffer = MomentumBuffer(apg_momentum)
audio_momentumbuffer = MomentumBuffer(apg_momentum) audio_momentumbuffer = MomentumBuffer(apg_momentum)
# self.image_outputs = False
# denoising # denoising
for i, t in enumerate(tqdm(timesteps)): for i, t in enumerate(tqdm(timesteps)):
offload.set_step_no_for_lora(self.model, i) offload.set_step_no_for_lora(self.model, i)
@ -715,36 +706,36 @@ class WanAny2V:
if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step: if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step:
sigma = t / 1000 sigma = t / 1000
noise = torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
if inject_from_start: if inject_from_start:
new_latents = latents.clone() new_latents = latents.clone()
new_latents[:, :source_latents.shape[1] ] = noise[:, :source_latents.shape[1] ] * sigma + (1 - sigma) * source_latents new_latents[:,:, :source_latents.shape[1] ] = noise[:, :, :source_latents.shape[1] ] * sigma + (1 - sigma) * source_latents.unsqueeze(0)
for latent_no, keep_latent in enumerate(latent_keep_frames): for latent_no, keep_latent in enumerate(latent_keep_frames):
if not keep_latent: if not keep_latent:
new_latents[:, latent_no:latent_no+1 ] = latents[:, latent_no:latent_no+1] new_latents[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1]
latents = new_latents latents = new_latents
new_latents = None new_latents = None
else: else:
latents = noise * sigma + (1 - sigma) * source_latents latents = noise * sigma + (1 - sigma) * source_latents.unsqueeze(0)
noise = None noise = None
if extended_overlapped_latents != None: if extended_overlapped_latents != None:
latent_noise_factor = t / 1000 latent_noise_factor = t / 1000
latents[:, 0:extended_overlapped_latents.shape[1]] = extended_overlapped_latents * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor
if vace: if vace:
overlap_noise_factor = overlap_noise / 1000 overlap_noise_factor = overlap_noise / 1000
for zz in z: for zz in z:
zz[0:16, ref_images_count:extended_overlapped_latents.shape[1] ] = extended_overlapped_latents[:, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[:, ref_images_count:] ) * overlap_noise_factor zz[0:16, ref_images_count:extended_overlapped_latents.shape[2] ] = extended_overlapped_latents[0, :, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[0, :, ref_images_count:] ) * overlap_noise_factor
if target_camera != None: if target_camera != None:
latent_model_input = torch.cat([latents, source_latents], dim=1) latent_model_input = torch.cat([latents, source_latents.unsqueeze(0).expand(*expand_shape)], dim=2) # !!!!
else: else:
latent_model_input = latents latent_model_input = latents
if phantom: if phantom:
gen_args = { gen_args = {
"x" : ([ torch.cat([latent_model_input[:,:-ref_images_count], input_ref_images], dim=1) ] * 2 + "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 +
[ torch.cat([latent_model_input[:,:-ref_images_count], input_ref_images_neg], dim=1)]), [ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]),
"context": [context, context_null, context_null] , "context": [context, context_null, context_null] ,
} }
elif fantasy: elif fantasy:
@ -832,38 +823,41 @@ class WanAny2V:
if sample_solver == "euler": if sample_solver == "euler":
dt = timesteps[i] if i == len(timesteps)-1 else (timesteps[i] - timesteps[i + 1]) dt = timesteps[i] if i == len(timesteps)-1 else (timesteps[i] - timesteps[i + 1])
dt = dt / self.num_timesteps dt = dt / self.num_timesteps
latents = latents - noise_pred * dt[:, None, None, None] latents = latents - noise_pred * dt[:, None, None, None, None]
else: else:
temp_x0 = sample_scheduler.step( latents = sample_scheduler.step(
noise_pred[:, :target_shape[1]].unsqueeze(0), noise_pred[:, :, :target_shape[1]],
t, t,
latents.unsqueeze(0), latents,
**scheduler_kwargs)[0] **scheduler_kwargs)[0]
latents = temp_x0.squeeze(0)
del temp_x0
if callback is not None: if callback is not None:
callback(i, latents, False) latents_preview = latents
if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames]
if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2)
callback(i, latents_preview[0], False)
latents_preview = None
x0 = [latents] if vace and ref_images_count > 0: latents = latents[:, :, ref_images_count:]
if trim_frames > 0: latents= latents[:, :,:-trim_frames]
if return_latent_slice != None:
latent_slice = latents[:, :, return_latent_slice].clone()
x0 =latents.unbind(dim=0)
if chipmunk: if chipmunk:
self.model.release_chipmunk() # need to add it at every exit when in prod self.model.release_chipmunk() # need to add it at every exit when in prod
if return_latent_slice != None:
latent_slice = latents[:, return_latent_slice].clone()
if vace:
# vace post processing
videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
else:
if phantom and input_ref_images != None:
trim_frames = input_ref_images.shape[1]
if trim_frames > 0: x0 = [x0_[:,:-trim_frames] for x0_ in x0]
videos = self.vae.decode(x0, VAE_tile_size) videos = self.vae.decode(x0, VAE_tile_size)
if self.image_outputs:
videos = torch.cat(videos, dim=1) if len(videos) > 1 else videos[0]
else:
videos = videos[0] # return only first video
if return_latent_slice != None: if return_latent_slice != None:
return { "x" : videos[0], "latent_slice" : latent_slice } return { "x" : videos, "latent_slice" : latent_slice }
return videos[0] return videos
def adapt_vace_model(self): def adapt_vace_model(self):
model = self.model model = self.model

View File

@ -31,6 +31,7 @@ class DTT2V:
rank=0, rank=0,
model_filename = None, model_filename = None,
model_type = None, model_type = None,
model_def = None,
base_model_type = None, base_model_type = None,
save_quantized = False, save_quantized = False,
text_encoder_filename = None, text_encoder_filename = None,
@ -53,6 +54,8 @@ class DTT2V:
checkpoint_path=text_encoder_filename, checkpoint_path=text_encoder_filename,
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn= None) shard_fn= None)
self.model_def = model_def
self.image_outputs = model_def.get("image_outputs", False)
self.vae_stride = config.vae_stride self.vae_stride = config.vae_stride
self.patch_size = config.patch_size self.patch_size = config.patch_size
@ -202,6 +205,7 @@ class DTT2V:
width: int = 832, width: int = 832,
fit_into_canvas = True, fit_into_canvas = True,
frame_num: int = 97, frame_num: int = 97,
batch_size = 1,
sampling_steps: int = 50, sampling_steps: int = 50,
shift: float = 1.0, shift: float = 1.0,
guide_scale: float = 5.0, guide_scale: float = 5.0,
@ -224,6 +228,7 @@ class DTT2V:
generator = torch.Generator(device=self.device) generator = torch.Generator(device=self.device)
generator.manual_seed(seed) generator.manual_seed(seed)
self._guidance_scale = guide_scale self._guidance_scale = guide_scale
if frame_num > 1:
frame_num = max(17, frame_num) # must match causal_block_size for value of 5 frame_num = max(17, frame_num) # must match causal_block_size for value of 5
frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 ) frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 )
@ -297,12 +302,12 @@ class DTT2V:
prefix_video = prefix_video[:, : predix_video_latent_length] prefix_video = prefix_video[:, : predix_video_latent_length]
base_num_frames_iter = latent_length base_num_frames_iter = latent_length
latent_shape = [16, base_num_frames_iter, latent_height, latent_width] latent_shape = [batch_size, 16, base_num_frames_iter, latent_height, latent_width]
latents = self.prepare_latents( latents = self.prepare_latents(
latent_shape, dtype=torch.float32, device=self.device, generator=generator latent_shape, dtype=torch.float32, device=self.device, generator=generator
) )
if prefix_video is not None: if prefix_video is not None:
latents[:, :predix_video_latent_length] = prefix_video.to(torch.float32) latents[:, :, :predix_video_latent_length] = prefix_video.to(torch.float32)
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
base_num_frames_iter, base_num_frames_iter,
init_timesteps, init_timesteps,
@ -340,7 +345,7 @@ class DTT2V:
else: else:
self.model.enable_cache = None self.model.enable_cache = None
from mmgp import offload from mmgp import offload
freqs = get_rotary_pos_embed(latents.shape[1 :], enable_RIFLEx= False) freqs = get_rotary_pos_embed(latents.shape[2 :], enable_RIFLEx= False)
kwrags = { kwrags = {
"freqs" :freqs, "freqs" :freqs,
"fps" : fps_embeds, "fps" : fps_embeds,
@ -358,15 +363,15 @@ class DTT2V:
update_mask_i = step_update_mask[i] update_mask_i = step_update_mask[i]
valid_interval_start, valid_interval_end = valid_interval[i] valid_interval_start, valid_interval_end = valid_interval[i]
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].clone() latent_model_input = latents[:, :, valid_interval_start:valid_interval_end, :, :].clone()
if overlap_noise > 0 and valid_interval_start < predix_video_latent_length: if overlap_noise > 0 and valid_interval_start < predix_video_latent_length:
noise_factor = 0.001 * overlap_noise noise_factor = 0.001 * overlap_noise
timestep_for_noised_condition = overlap_noise timestep_for_noised_condition = overlap_noise
latent_model_input[:, valid_interval_start:predix_video_latent_length] = ( latent_model_input[:, :, valid_interval_start:predix_video_latent_length] = (
latent_model_input[:, valid_interval_start:predix_video_latent_length] latent_model_input[:, :, valid_interval_start:predix_video_latent_length]
* (1.0 - noise_factor) * (1.0 - noise_factor)
+ torch.randn_like( + torch.randn_like(
latent_model_input[:, valid_interval_start:predix_video_latent_length] latent_model_input[:, :, valid_interval_start:predix_video_latent_length]
) )
* noise_factor * noise_factor
) )
@ -417,18 +422,27 @@ class DTT2V:
del noise_pred_cond, noise_pred_uncond del noise_pred_cond, noise_pred_uncond
for idx in range(valid_interval_start, valid_interval_end): for idx in range(valid_interval_start, valid_interval_end):
if update_mask_i[idx].item(): if update_mask_i[idx].item():
latents[:, idx] = sample_schedulers[idx].step( latents[:, :, idx] = sample_schedulers[idx].step(
noise_pred[:, idx - valid_interval_start], noise_pred[:, :, idx - valid_interval_start],
timestep_i[idx], timestep_i[idx],
latents[:, idx], latents[:, :, idx],
return_dict=False, return_dict=False,
generator=generator, generator=generator,
)[0] )[0]
sample_schedulers_counter[idx] += 1 sample_schedulers_counter[idx] += 1
if callback is not None: if callback is not None:
callback(i, latents.squeeze(0), False) latents_preview = latents
if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2)
callback(i, latents_preview[0], False)
latents_preview = None
x0 = latents.unsqueeze(0) x0 =latents.unbind(dim=0)
videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]]
output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w videos = self.vae.decode(x0, VAE_tile_size)
return output_video
if self.image_outputs:
videos = torch.cat(videos, dim=1) if len(videos) > 1 else videos[0]
else:
videos = videos[0] # return only first video
return videos
@ -194,7 +194,9 @@ def pay_attention(
q = q.to(v.dtype) q = q.to(v.dtype)
k = k.to(v.dtype) k = k.to(v.dtype)
batch = len(q)
if len(k) != batch: k = k.expand(batch, -1, -1, -1)
if len(v) != batch: v = v.expand(batch, -1, -1, -1)
if attn == "chipmunk": if attn == "chipmunk":
from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG
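The two `expand` lines added above broadcast a batch-1 context over a larger query batch. A standalone sketch of the same idea, with illustrative shapes:

```python
import torch

q = torch.randn(4, 1024, 12, 64)    # four queries sampled in parallel
k = torch.randn(1, 512, 12, 64)     # context keys encoded once
v = torch.randn(1, 512, 12, 64)

batch = len(q)
if len(k) != batch: k = k.expand(batch, -1, -1, -1)   # view only, no copy
if len(v) != batch: v = v.expand(batch, -1, -1, -1)
assert k.shape[0] == v.shape[0] == batch
```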
@ -33,9 +33,10 @@ def sinusoidal_embedding_1d(dim, position):
def reshape_latent(latent, latent_frames): def reshape_latent(latent, latent_frames):
if latent_frames == latent.shape[0]: return latent.reshape(latent.shape[0], latent_frames, -1, latent.shape[-1] )
return latent
return latent.reshape(latent_frames, -1, latent.shape[-1] ) def restore_latent_shape(latent):
return latent.reshape(latent.shape[0], -1, latent.shape[-1] )
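A round-trip sketch of the two helpers above: the token axis is split into (frames, tokens per frame) so a per-frame modulation can broadcast, then flattened back. Shapes are illustrative, not taken from the model config:

```python
import torch

def reshape_latent(latent, latent_frames):
    return latent.reshape(latent.shape[0], latent_frames, -1, latent.shape[-1])

def restore_latent_shape(latent):
    return latent.reshape(latent.shape[0], -1, latent.shape[-1])

x = torch.randn(2, 21 * 1560, 128)      # [batch, frames * tokens_per_frame, dim]
e = torch.randn(2, 21, 1, 128)          # one modulation vector per frame
x_mod = reshape_latent(x, 21)           # [2, 21, 1560, 128]
x_mod = x_mod * (1 + e)                 # per-frame scale broadcasts over the token axis
x_mod = restore_latent_shape(x_mod)     # back to [2, 21 * 1560, 128]
assert x_mod.shape == x.shape
```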
def identify_k( b: float, d: int, N: int): def identify_k( b: float, d: int, N: int):
@ -493,7 +494,7 @@ class WanAttentionBlock(nn.Module):
x_mod = reshape_latent(x_mod , latent_frames) x_mod = reshape_latent(x_mod , latent_frames)
x_mod *= 1 + e[1] x_mod *= 1 + e[1]
x_mod += e[0] x_mod += e[0]
x_mod = reshape_latent(x_mod , 1) x_mod = restore_latent_shape(x_mod)
if cam_emb != None: if cam_emb != None:
cam_emb = self.cam_encoder(cam_emb) cam_emb = self.cam_encoder(cam_emb)
cam_emb = cam_emb.repeat(1, 2, 1) cam_emb = cam_emb.repeat(1, 2, 1)
@ -510,7 +511,7 @@ class WanAttentionBlock(nn.Module):
x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames) x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
x.addcmul_(y, e[2]) x.addcmul_(y, e[2])
x, y = reshape_latent(x , 1), reshape_latent(y , 1) x, y = restore_latent_shape(x), restore_latent_shape(y)
del y del y
y = self.norm3(x) y = self.norm3(x)
y = y.to(attention_dtype) y = y.to(attention_dtype)
@ -542,7 +543,7 @@ class WanAttentionBlock(nn.Module):
y = reshape_latent(y , latent_frames) y = reshape_latent(y , latent_frames)
y *= 1 + e[4] y *= 1 + e[4]
y += e[3] y += e[3]
y = reshape_latent(y , 1) y = restore_latent_shape(y)
y = y.to(attention_dtype) y = y.to(attention_dtype)
ffn = self.ffn[0] ffn = self.ffn[0]
@ -562,7 +563,7 @@ class WanAttentionBlock(nn.Module):
y = y.to(dtype) y = y.to(dtype)
x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames) x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
x.addcmul_(y, e[5]) x.addcmul_(y, e[5])
x, y = reshape_latent(x , 1), reshape_latent(y , 1) x, y = restore_latent_shape(x), restore_latent_shape(y)
if hints_processed is not None: if hints_processed is not None:
for hint, scale in zip(hints_processed, context_scale): for hint, scale in zip(hints_processed, context_scale):
@ -669,6 +670,8 @@ class VaceWanAttentionBlock(WanAttentionBlock):
hints[0] = None hints[0] = None
if self.block_id == 0: if self.block_id == 0:
c = self.before_proj(c) c = self.before_proj(c)
bz = x.shape[0]
if bz > c.shape[0]: c = c.repeat(bz, 1, 1 )
c += x c += x
c = super().forward(c, **kwargs) c = super().forward(c, **kwargs)
c_skip = self.after_proj(c) c_skip = self.after_proj(c)
@ -707,7 +710,7 @@ class Head(nn.Module):
x = reshape_latent(x , latent_frames) x = reshape_latent(x , latent_frames)
x *= (1 + e[1]) x *= (1 + e[1])
x += e[0] x += e[0]
x = reshape_latent(x , 1) x = restore_latent_shape(x)
x= x.to(self.head.weight.dtype) x= x.to(self.head.weight.dtype)
x = self.head(x) x = self.head(x)
return x return x
@ -1163,10 +1166,14 @@ class WanModel(ModelMixin, ConfigMixin):
last_x_idx = i last_x_idx = i
else: else:
# image source # image source
bz = len(x)
if y is not None: if y is not None:
x = torch.cat([x, y], dim=0) y = y.unsqueeze(0)
if bz > 1: y = y.expand(bz, -1, -1, -1, -1)
x = torch.cat([x, y], dim=1)
# embeddings # embeddings
x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype) # x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
x = self.patch_embedding(x).to(modulation_dtype)
grid_sizes = x.shape[2:] grid_sizes = x.shape[2:]
if chipmunk: if chipmunk:
x = x.unsqueeze(-1) x = x.unsqueeze(-1)
@ -1204,7 +1211,7 @@ class WanModel(ModelMixin, ConfigMixin):
) # b, dim ) # b, dim
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype) e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
if self.inject_sample_info: if self.inject_sample_info and fps!=None:
fps = torch.tensor(fps, dtype=torch.long, device=device) fps = torch.tensor(fps, dtype=torch.long, device=device)
fps_emb = self.fps_embedding(fps).to(e.dtype) fps_emb = self.fps_embedding(fps).to(e.dtype)
@ -1402,7 +1409,7 @@ class WanModel(ModelMixin, ConfigMixin):
x_list[i] = self.unpatchify(x, grid_sizes) x_list[i] = self.unpatchify(x, grid_sizes)
del x del x
return [x[0].float() for x in x_list] return [x.float() for x in x_list]
def unpatchify(self, x, grid_sizes): def unpatchify(self, x, grid_sizes):
r""" r"""
@ -1427,7 +1434,10 @@ class WanModel(ModelMixin, ConfigMixin):
u = torch.einsum('fhwpqrc->cfphqwr', u) u = torch.einsum('fhwpqrc->cfphqwr', u)
u = u.reshape(c, *[i * j for i, j in zip(grid_sizes, self.patch_size)]) u = u.reshape(c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
out.append(u) out.append(u)
return out if len(x) == 1:
return out[0].unsqueeze(0)
else:
return torch.stack(out, 0)
def init_weights(self): def init_weights(self):
r""" r"""
@ -333,7 +333,7 @@ class SingleStreamMutiAttention(SingleStreamAttention):
human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), (self.rope_h1[0], self.rope_h1[1])) human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), (self.rope_h1[0], self.rope_h1[1]))
human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), (self.rope_h2[0], self.rope_h2[1])) human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), (self.rope_h2[0], self.rope_h2[1]))
back = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype).to(human1.device) back = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype, device=human1.device)
max_indices = x_ref_attn_map.argmax(dim=0) max_indices = x_ref_attn_map.argmax(dim=0)
normalized_map = torch.stack([human1, human2, back], dim=1) normalized_map = torch.stack([human1, human2, back], dim=1)
normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N
@ -351,7 +351,7 @@ class SingleStreamMutiAttention(SingleStreamAttention):
if self.qk_norm: if self.qk_norm:
encoder_k = self.add_k_norm(encoder_k) encoder_k = self.add_k_norm(encoder_k)
per_frame = torch.zeros(N_a, dtype=encoder_k.dtype).to(encoder_k.device) per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device)
per_frame[:per_frame.size(0)//2] = (self.rope_h1[0] + self.rope_h1[1]) / 2 per_frame[:per_frame.size(0)//2] = (self.rope_h1[0] + self.rope_h1[1]) / 2
per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2 per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2
encoder_pos = torch.concat([per_frame]*N_t, dim=0) encoder_pos = torch.concat([per_frame]*N_t, dim=0)
@ -272,6 +272,34 @@ def timestep_transform(
new_t = new_t * num_timesteps new_t = new_t * num_timesteps
return new_t return new_t
def parse_speakers_locations(speakers_locations):
bbox = {}
if speakers_locations is None or len(speakers_locations) == 0:
return None, ""
speakers = speakers_locations.split(" ")
if len(speakers) !=2:
error= "Two speakers locations should be defined"
return "", error
for i, speaker in enumerate(speakers):
location = speaker.strip().split(":")
if len(location) not in (2,4):
error = f"Invalid Speaker Location '{location}'. A Speaker Location should be defined in the format Left:Right or usuing a BBox Left:Top:Right:Bottom"
return "", error
try:
good = False
location_float = [ float(val) for val in location]
good = all( 0 <= val <= 100 for val in location_float)
except:
pass
if not good:
error = f"Invalid Speaker Location '{location}'. Each number should be between 0 and 100."
return "", error
if len(location_float) == 2:
location_float = [location_float[0], 0, location_float[1], 100]
bbox[f"human{i}"] = location_float
return bbox, ""
# construct human mask # construct human mask
def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05, bbox = None): def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05, bbox = None):
@ -286,7 +314,9 @@ def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05
assert len(bbox) == HUMAN_NUMBER, f"The number of target bbox should be the same with cond_audio" assert len(bbox) == HUMAN_NUMBER, f"The number of target bbox should be the same with cond_audio"
background_mask = torch.zeros([src_h, src_w]) background_mask = torch.zeros([src_h, src_w])
for _, person_bbox in bbox.items(): for _, person_bbox in bbox.items():
x_min, y_min, x_max, y_max = person_bbox y_min, x_min, y_max, x_max = person_bbox
x_min, y_min, x_max, y_max = max(x_min,5), max(y_min, 5), min(x_max,95), min(y_max,95)
x_min, y_min, x_max, y_max = int(src_h * x_min / 100), int(src_w * y_min / 100), int(src_h * x_max / 100), int(src_w * y_max / 100)
human_mask = torch.zeros([src_h, src_w]) human_mask = torch.zeros([src_h, src_w])
human_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1 human_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
background_mask += human_mask background_mask += human_mask
@ -306,7 +336,7 @@ def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05
human_masks = [human_mask1, human_mask2] human_masks = [human_mask1, human_mask2]
background_mask = torch.where(background_mask > 0, torch.tensor(0), torch.tensor(1)) background_mask = torch.where(background_mask > 0, torch.tensor(0), torch.tensor(1))
human_masks.append(background_mask) human_masks.append(background_mask)
# toto = Image.fromarray(human_masks[2].mul_(255).unsqueeze(-1).repeat(1,1,3).to(torch.uint8).cpu().numpy())
ref_target_masks = torch.stack(human_masks, dim=0) #.to(self.device) ref_target_masks = torch.stack(human_masks, dim=0) #.to(self.device)
# resize and centercrop for ref_target_masks # resize and centercrop for ref_target_masks
# ref_target_masks = resize_and_centercrop(ref_target_masks, (target_h, target_w)) # ref_target_masks = resize_and_centercrop(ref_target_masks, (target_h, target_w))
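The bbox handling above clamps the percentages to [5, 95] and scales them to the source resolution before filling the mask. A self-contained sketch of that conversion (the helper name and example values are illustrative):

```python
import torch

def bbox_to_mask(person_bbox, src_h, src_w):
    # Same unpacking and clamping as in get_target_masks above.
    y_min, x_min, y_max, x_max = person_bbox                         # percentages
    x_min, y_min, x_max, y_max = max(x_min, 5), max(y_min, 5), min(x_max, 95), min(y_max, 95)
    x_min, y_min, x_max, y_max = (int(src_h * x_min / 100), int(src_w * y_min / 100),
                                  int(src_h * x_max / 100), int(src_w * y_max / 100))
    mask = torch.zeros(src_h, src_w)
    mask[x_min:x_max, y_min:y_max] = 1
    return mask

left_half = bbox_to_mask([0, 0, 50, 100], src_h=480, src_w=832)      # speaker on the left half
print(int(left_half.sum()))                                           # 162000 pixels set
```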
@ -128,7 +128,7 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli
_, seq_lens, heads, _ = visual_q.shape _, seq_lens, heads, _ = visual_q.shape
class_num, _ = ref_target_masks.shape class_num, _ = ref_target_masks.shape
x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q.device).to(visual_q.dtype) x_ref_attn_maps = torch.zeros(class_num, seq_lens, dtype=visual_q.dtype, device=visual_q.device)
split_chunk = heads // split_num split_chunk = heads // split_num
@ -5,7 +5,8 @@ import os
import os.path as osp import os.path as osp
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
import torch.nn.functional as F import torch.nn.functional as F
import cv2
import tempfile
import imageio import imageio
import torch import torch
import decord import decord
@ -101,6 +102,29 @@ def get_video_frame(file_name, frame_no):
img = Image.fromarray(frame.numpy().astype(np.uint8)) img = Image.fromarray(frame.numpy().astype(np.uint8))
return img return img
def convert_image_to_video(image):
if image is None:
return None
# Convert PIL/numpy image to OpenCV format if needed
if isinstance(image, np.ndarray):
# Gradio images are typically RGB, OpenCV expects BGR
img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
else:
# Handle PIL Image
img_array = np.array(image)
img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
height, width = img_bgr.shape[:2]
# Create temporary video file (auto-cleaned by Gradio)
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_video:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(temp_video.name, fourcc, 30.0, (width, height))
out.write(img_bgr)
out.release()
return temp_video.name
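Usage sketch for convert_image_to_video as defined above: a single image becomes a one-frame .mp4 that the existing video-guide path can consume. The temporary file is created with delete=False, so the caller (see remove_temp_filenames later in this commit) is responsible for removing it:

```python
from PIL import Image
import os

img = Image.new("RGB", (832, 480), color=(40, 120, 200))   # illustrative solid-color image
tmp_mp4 = convert_image_to_video(img)                       # e.g. /tmp/tmpxxxxxxxx.mp4
print(tmp_mp4)
os.remove(tmp_mp4)                                          # clean up the temporary file
```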
def resize_lanczos(img, h, w): def resize_lanczos(img, h, w):
img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
img = img.resize((w,h), resample=Image.Resampling.LANCZOS) img = img.resize((w,h), resample=Image.Resampling.LANCZOS)
wgp.py
@ -16,7 +16,7 @@ import json
import wan import wan
from wan.utils import notification_sound from wan.utils import notification_sound
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS
from wan.utils.utils import cache_video, convert_tensor_to_image, save_image, get_video_info, get_file_creation_date from wan.utils.utils import cache_video, convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video
from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files
from wan.modules.attention import get_attention_modes, get_supported_attention_modes from wan.modules.attention import get_attention_modes, get_supported_attention_modes
@ -50,7 +50,7 @@ AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10 PROMPT_VARS_MAX = 10
target_mmgp_version = "3.5.1" target_mmgp_version = "3.5.1"
WanGP_version = "6.7" WanGP_version = "7.0"
settings_version = 2.22 settings_version = 2.22
max_source_video_frames = 1000 max_source_video_frames = 1000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@ -171,6 +171,8 @@ def process_prompt_and_add_tasks(state, model_choice):
gr.Warning("Internal state error: Could not retrieve inputs for the model.") gr.Warning("Internal state error: Could not retrieve inputs for the model.")
queue = gen.get("queue", []) queue = gen.get("queue", [])
return get_queue_table(queue) return get_queue_table(queue)
model_def = get_model_def(model_type)
image_outputs = model_def.get("image_outputs", False)
model_type = get_base_model_type(model_type) model_type = get_base_model_type(model_type)
inputs["model_filename"] = model_filename inputs["model_filename"] = model_filename
@ -182,7 +184,7 @@ def process_prompt_and_add_tasks(state, model_choice):
if frames_count > max_source_video_frames: if frames_count > max_source_video_frames:
gr.Info(f"Post processing is not supported on videos longer than {max_source_video_frames} frames. Output Video will be truncated") gr.Info(f"Post processing is not supported on videos longer than {max_source_video_frames} frames. Output Video will be truncated")
# return # return
for k in ["image_start", "image_end", "image_refs", "video_guide", "audio_guide", "audio_guide2", "video_mask"]: for k in ["image_start", "image_end", "image_refs", "video_guide", "audio_guide", "audio_guide2", "video_mask", "image_mask"]:
inputs[k] = None inputs[k] = None
inputs.update(edit_overrides) inputs.update(edit_overrides)
del gen["edit_video_source"], gen["edit_overrides"] del gen["edit_video_source"], gen["edit_overrides"]
@ -193,6 +195,13 @@ def process_prompt_and_add_tasks(state, model_choice):
if len(spatial_upsampling) >0: prompt += ["Spatial Upsampling"] if len(spatial_upsampling) >0: prompt += ["Spatial Upsampling"]
temporal_upsampling = inputs.get("temporal_upsampling","") temporal_upsampling = inputs.get("temporal_upsampling","")
if len(temporal_upsampling) >0: prompt += ["Temporal Upsampling"] if len(temporal_upsampling) >0: prompt += ["Temporal Upsampling"]
if image_outputs and len(temporal_upsampling) > 0:
gr.Info("Temporal Upsampling can not be used with an Image")
return
film_grain_intensity = inputs.get("film_grain_intensity",0)
film_grain_saturation = inputs.get("film_grain_saturation",0.5)
# if film_grain_intensity >0: prompt += [f"Film Grain: intensity={film_grain_intensity}, saturation={film_grain_saturation}"]
if film_grain_intensity >0: prompt += ["Film Grain"]
MMAudio_setting = inputs.get("MMAudio_setting",0) MMAudio_setting = inputs.get("MMAudio_setting",0)
seed = inputs.get("seed",None) seed = inputs.get("seed",None)
repeat_generation= inputs.get("repeat_generation",1) repeat_generation= inputs.get("repeat_generation",1)
@ -201,7 +210,7 @@ def process_prompt_and_add_tasks(state, model_choice):
return return
if MMAudio_setting !=0: prompt += ["MMAudio"] if MMAudio_setting !=0: prompt += ["MMAudio"]
if len(prompt) == 0: if len(prompt) == 0:
gr.Info("You must choose at leat one Post Processing Method") gr.Info("You must choose at least one Post Processing Method")
return return
inputs["prompt"] = ", ".join(prompt) inputs["prompt"] = ", ".join(prompt)
add_video_task(**inputs) add_video_task(**inputs)
@ -247,7 +256,10 @@ def process_prompt_and_add_tasks(state, model_choice):
audio_guide = inputs["audio_guide"] audio_guide = inputs["audio_guide"]
audio_guide2 = inputs["audio_guide2"] audio_guide2 = inputs["audio_guide2"]
video_guide = inputs["video_guide"] video_guide = inputs["video_guide"]
image_guide = inputs["image_guide"]
video_mask = inputs["video_mask"] video_mask = inputs["video_mask"]
image_mask = inputs["image_mask"]
speakers_locations = inputs["speakers_locations"]
video_source = inputs["video_source"] video_source = inputs["video_source"]
frames_positions = inputs["frames_positions"] frames_positions = inputs["frames_positions"]
keep_frames_video_guide= inputs["keep_frames_video_guide"] keep_frames_video_guide= inputs["keep_frames_video_guide"]
@ -269,6 +281,13 @@ def process_prompt_and_add_tasks(state, model_choice):
gr.Info("Mag Cache maximum number of steps is 50") gr.Info("Mag Cache maximum number of steps is 50")
return return
if "B" in audio_prompt_type or "X" in audio_prompt_type:
from wan.multitalk.multitalk import parse_speakers_locations
speakers_bboxes, error = parse_speakers_locations(speakers_locations)
if len(error) > 0:
gr.Info(error)
return
if MMAudio_setting != 0 and server_config.get("mmaudio_enabled", 0) != 0 and video_length <16: #should depend on the architecture if MMAudio_setting != 0 and server_config.get("mmaudio_enabled", 0) != 0 and video_length <16: #should depend on the architecture
gr.Info("MMAudio can generate an Audio track only if the Video is at least 1s long") gr.Info("MMAudio can generate an Audio track only if the Video is at least 1s long")
if "F" in video_prompt_type: if "F" in video_prompt_type:
@ -314,12 +333,16 @@ def process_prompt_and_add_tasks(state, model_choice):
audio_guide2 = None audio_guide2 = None
if model_type in ["vace_multitalk_14B"] and ("B" in audio_prompt_type or "X" in audio_prompt_type): if model_type in ["vace_multitalk_14B"] and ("B" in audio_prompt_type or "X" in audio_prompt_type):
if not "I" in video_prompt_type: if not "I" in video_prompt_type and not not "V" in video_prompt_type:
gr.Info("To get good results with Multitalk and two people speaking, it is recommended to set a Reference Frame that contains the two people one on each side ") gr.Info("To get good results with Multitalk and two people speaking, it is recommended to set a Reference Frame or a Control Video (potentially truncated) that contains the two people one on each side")
if "R" in audio_prompt_type and len(filter_letters(image_prompt_type, "VLG")) > 0 : if len(filter_letters(image_prompt_type, "VL")) > 0 :
if "R" in audio_prompt_type:
gr.Info("Remuxing is not yet supported if there is a video source") gr.Info("Remuxing is not yet supported if there is a video source")
audio_prompt_type= replace("R" ,"") audio_prompt_type= audio_prompt_type.replace("R" ,"")
if "A" in audio_prompt_type:
gr.Info("Creating an Audio track is not yet supported if there is a video source")
return
if model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_audio", "hunyuan_avatar"]: if model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_audio", "hunyuan_avatar"]:
if image_refs == None : if image_refs == None :
@ -342,17 +365,26 @@ def process_prompt_and_add_tasks(state, model_choice):
image_refs = None image_refs = None
if "V" in video_prompt_type: if "V" in video_prompt_type:
if video_guide == None: if video_guide is None and image_guide is None:
if image_outputs:
gr.Info("You must provide a Control Image")
else:
gr.Info("You must provide a Control Video") gr.Info("You must provide a Control Video")
return return
if "A" in video_prompt_type and not "U" in video_prompt_type: if "A" in video_prompt_type and not "U" in video_prompt_type:
if video_mask == None: if video_mask is None and image_mask is None:
if image_outputs:
gr.Info("You must provide a Image Mask")
else:
gr.Info("You must provide a Video Mask") gr.Info("You must provide a Video Mask")
return return
else: else:
video_mask = None video_mask = None
image_mask = None
if not "G" in video_prompt_type: if "G" in video_prompt_type:
gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start a Step no {int(num_inference_steps * (1. - denoising_strength))} ")
else:
denoising_strength = 1.0 denoising_strength = 1.0
_, error = parse_keep_frames_video_guide(keep_frames_video_guide, video_length) _, error = parse_keep_frames_video_guide(keep_frames_video_guide, video_length)
@ -361,7 +393,9 @@ def process_prompt_and_add_tasks(state, model_choice):
return return
else: else:
video_guide = None video_guide = None
image_guide = None
video_mask = None video_mask = None
image_mask = None
keep_frames_video_guide = "" keep_frames_video_guide = ""
denoising_strength = 1.0 denoising_strength = 1.0
@ -416,10 +450,6 @@ def process_prompt_and_add_tasks(state, model_choice):
if "hunyuan_custom_custom_edit" in model_filename: if "hunyuan_custom_custom_edit" in model_filename:
if video_guide == None:
gr.Info("You must provide a Control Video")
return
if len(keep_frames_video_guide) > 0: if len(keep_frames_video_guide) > 0:
gr.Info("Filtering Frames with this model is not supported") gr.Info("Filtering Frames with this model is not supported")
return return
@ -440,7 +470,9 @@ def process_prompt_and_add_tasks(state, model_choice):
"audio_guide": audio_guide, "audio_guide": audio_guide,
"audio_guide2": audio_guide2, "audio_guide2": audio_guide2,
"video_guide": video_guide, "video_guide": video_guide,
"image_guide": image_guide,
"video_mask": video_mask, "video_mask": video_mask,
"image_mask": image_mask,
"video_source": video_source, "video_source": video_source,
"frames_positions": frames_positions, "frames_positions": frames_positions,
"keep_frames_video_source": keep_frames_video_source, "keep_frames_video_source": keep_frames_video_source,
@ -517,15 +549,15 @@ def process_prompt_and_add_tasks(state, model_choice):
return update_queue_data(queue) return update_queue_data(queue)
def get_preview_images(inputs): def get_preview_images(inputs):
inputs_to_query = ["image_start", "image_end", "video_source", "video_guide", "video_mask", "image_refs" ] inputs_to_query = ["image_start", "image_end", "video_source", "video_guide", "image_guide", "video_mask", "image_mask", "image_refs" ]
labels = ["Start Image", "End Image", "Video Source", "Video Guide", "Video Mask", "Image Reference"] labels = ["Start Image", "End Image", "Video Source", "Video Guide", "Image Guide", "Video Mask", "Image Mask", "Image Reference"]
start_image_data = None start_image_data = None
start_image_labels = [] start_image_labels = []
end_image_data = None end_image_data = None
end_image_labels = [] end_image_labels = []
for label, name in zip(labels,inputs_to_query): for label, name in zip(labels,inputs_to_query):
image= inputs.get(name, None) image= inputs.get(name, None)
if image != None: if image is not None:
image= [image] if not isinstance(image, list) else image.copy() image= [image] if not isinstance(image, list) else image.copy()
if start_image_data == None: if start_image_data == None:
start_image_data = image start_image_data = image
@ -645,7 +677,7 @@ def save_queue_action(state):
params_copy = task.get('params', {}).copy() params_copy = task.get('params', {}).copy()
task_id_s = task.get('id', f"task_{task_index}") task_id_s = task.get('id', f"task_{task_index}")
image_keys = ["image_start", "image_end", "image_refs"] image_keys = ["image_start", "image_end", "image_refs", "image_guide", "image_mask"]
video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"] video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"]
for key in image_keys: for key in image_keys:
@ -821,7 +853,7 @@ def load_queue_action(filepath, state, evt:gr.EventData):
max_id_in_file = max(max_id_in_file, task_id_loaded) max_id_in_file = max(max_id_in_file, task_id_loaded)
params['state'] = state params['state'] = state
image_keys = ["image_start", "image_end", "image_refs"] image_keys = ["image_start", "image_end", "image_refs", "image_guide", "image_mask"]
video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"] video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"]
loaded_pil_images = {} loaded_pil_images = {}
@ -1041,7 +1073,7 @@ def autosave_queue():
params_copy = task.get('params', {}).copy() params_copy = task.get('params', {}).copy()
task_id_s = task.get('id', f"task_{task_index}") task_id_s = task.get('id', f"task_{task_index}")
image_keys = ["image_start", "image_end", "image_refs"] image_keys = ["image_start", "image_end", "image_refs", "image_guide", "image_mask"]
video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"] video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"]
for key in image_keys: for key in image_keys:
@ -1929,12 +1961,6 @@ def get_default_settings(model_type):
i2v = test_class_i2v(model_type) i2v = test_class_i2v(model_type)
defaults_filename = get_settings_file_name(model_type) defaults_filename = get_settings_file_name(model_type)
if not Path(defaults_filename).is_file(): if not Path(defaults_filename).is_file():
model_def = get_model_def(model_type)
if model_def != None:
ui_defaults = model_def["settings"]
if len(ui_defaults.get("prompt","")) == 0:
ui_defaults["prompt"]= get_default_prompt(i2v)
else:
ui_defaults = { ui_defaults = {
"prompt": get_default_prompt(i2v), "prompt": get_default_prompt(i2v),
"resolution": "1280x720" if "720" in model_type else "832x480", "resolution": "1280x720" if "720" in model_type else "832x480",
@ -2034,6 +2060,14 @@ def get_default_settings(model_type):
}) })
model_def = get_model_def(model_type)
if model_def != None:
ui_defaults_update = model_def["settings"]
ui_defaults.update(ui_defaults_update)
if len(ui_defaults.get("prompt","")) == 0:
ui_defaults["prompt"]= get_default_prompt(i2v)
with open(defaults_filename, "w", encoding="utf-8") as f: with open(defaults_filename, "w", encoding="utf-8") as f:
json.dump(ui_defaults, f, indent=4) json.dump(ui_defaults, f, indent=4)
else: else:
@ -2490,6 +2524,7 @@ def load_wan_model(model_filename, model_type, base_model_type, model_def, quant
checkpoint_dir="ckpts", checkpoint_dir="ckpts",
model_filename=model_filename, model_filename=model_filename,
model_type = model_type, model_type = model_type,
model_def = model_def,
base_model_type=base_model_type, base_model_type=base_model_type,
text_encoder_filename= get_wan_text_encoder_filename(text_encoder_quantization), text_encoder_filename= get_wan_text_encoder_filename(text_encoder_quantization),
quantizeTransformer = quantizeTransformer, quantizeTransformer = quantizeTransformer,
@ -2598,7 +2633,7 @@ def load_models(model_type):
save_quantized = False save_quantized = False
print("Need to provide a non quantized model to create a quantized model to be saved") print("Need to provide a non quantized model to create a quantized model to be saved")
if save_quantized and len(modules) > 0: if save_quantized and len(modules) > 0:
print(f"Unable to create a finetune quantized model as some modules are declared in the finetune definition. If your finetune includes already the module weights you can remove the 'modules' entry and try again. If not you will need also to change temporarly the model 'architecture' to an architecture that wont require the modules part ('{model_types_no_module[0] if len(model_types_no_module)>0 else ''}' ?) to quantize and then add back the original 'modules' and 'architecture' entries.") print(f"Unable to create a finetune quantized model as some modules are declared in the finetune definition. If your finetune includes already the module weights you can remove the 'modules' entry and try again. If not you will need also to change temporarly the model 'architecture' to an architecture that wont require the modules part ({modules}) to quantize and then add back the original 'modules' and 'architecture' entries.")
save_quantized = False save_quantized = False
quantizeTransformer = not save_quantized and model_def !=None and transformer_quantization in ("int8", "fp8") and model_def.get("auto_quantize", False) and not "quanto" in model_filename quantizeTransformer = not save_quantized and model_def !=None and transformer_quantization in ("int8", "fp8") and model_def.get("auto_quantize", False) and not "quanto" in model_filename
if quantizeTransformer and len(modules) > 0: if quantizeTransformer and len(modules) > 0:
@ -2931,8 +2966,10 @@ def refresh_gallery(state): #, msg
prompt = task["prompt"] prompt = task["prompt"]
params = task["params"] params = task["params"]
model_type = params["model_type"] model_type = params["model_type"]
model_type = get_base_model_type(model_type) base_model_type = get_base_model_type(model_type)
onemorewindow_visible = test_any_sliding_window(model_type) model_def = get_model_def(model_type)
is_image = model_def.get("image_outputs", False)
onemorewindow_visible = test_any_sliding_window(base_model_type) and not is_image
enhanced = False enhanced = False
if prompt.startswith("!enhanced!\n"): if prompt.startswith("!enhanced!\n"):
enhanced = True enhanced = True
@ -3047,7 +3084,7 @@ def select_video(state, input_file_list, event_data: gr.EventData):
pp_values= [] pp_values= []
pp_labels = [] pp_labels = []
extension = os.path.splitext(file_name)[-1] extension = os.path.splitext(file_name)[-1]
if not extension in [".mp4"]: if not has_video_file_extension(file_name):
img = Image.open(file_name) img = Image.open(file_name)
width, height = img.size width, height = img.size
configs = None configs = None
@ -3064,6 +3101,8 @@ def select_video(state, input_file_list, event_data: gr.EventData):
misc_labels += ["Model"] misc_labels += ["Model"]
video_temporal_upsampling = configs.get("temporal_upsampling", "") video_temporal_upsampling = configs.get("temporal_upsampling", "")
video_spatial_upsampling = configs.get("spatial_upsampling", "") video_spatial_upsampling = configs.get("spatial_upsampling", "")
video_film_grain_intensity = configs.get("film_grain_intensity", 0)
video_film_grain_saturation = configs.get("film_grain_saturation", 0.5)
video_MMAudio_setting = configs.get("MMAudio_setting", 0) video_MMAudio_setting = configs.get("MMAudio_setting", 0)
video_MMAudio_prompt = configs.get("MMAudio_prompt", "") video_MMAudio_prompt = configs.get("MMAudio_prompt", "")
video_MMAudio_neg_prompt = configs.get("MMAudio_neg_prompt", "") video_MMAudio_neg_prompt = configs.get("MMAudio_neg_prompt", "")
@ -3074,6 +3113,9 @@ def select_video(state, input_file_list, event_data: gr.EventData):
if len(video_temporal_upsampling) > 0: if len(video_temporal_upsampling) > 0:
pp_values += [ video_temporal_upsampling ] pp_values += [ video_temporal_upsampling ]
pp_labels += [ "Upsampling" ] pp_labels += [ "Upsampling" ]
if video_film_grain_intensity > 0:
pp_values += [ f"Intensity={video_film_grain_intensity}, Saturation={video_film_grain_saturation}" ]
pp_labels += [ "Film Grain" ]
if video_MMAudio_setting != 0: if video_MMAudio_setting != 0:
pp_values += [ f'Prompt="{video_MMAudio_prompt}", Neg Prompt="{video_MMAudio_neg_prompt}", Seed={video_MMAudio_seed}' ] pp_values += [ f'Prompt="{video_MMAudio_prompt}", Neg Prompt="{video_MMAudio_neg_prompt}", Seed={video_MMAudio_seed}' ]
pp_labels += [ "MMAudio" ] pp_labels += [ "MMAudio" ]
@ -3206,7 +3248,7 @@ def select_video(state, input_file_list, event_data: gr.EventData):
else: else:
html = get_default_video_info() html = get_default_video_info()
visible= len(file_list) > 0 visible= len(file_list) > 0
return choice, html, gr.update(visible=visible and not is_image) , gr.update(visible=visible and is_image), gr.update(visible=visible and is_image) return choice, html, gr.update(visible=visible and not is_image) , gr.update(visible=visible and is_image), gr.update(visible=visible and not is_image)
def expand_slist(slist, num_inference_steps ): def expand_slist(slist, num_inference_steps ):
new_slist= [] new_slist= []
inc = len(slist) / num_inference_steps inc = len(slist) / num_inference_steps
@ -3674,6 +3716,8 @@ def edit_video(
seed, seed,
temporal_upsampling, temporal_upsampling,
spatial_upsampling, spatial_upsampling,
film_grain_intensity,
film_grain_saturation,
MMAudio_setting, MMAudio_setting,
MMAudio_prompt, MMAudio_prompt,
MMAudio_neg_prompt, MMAudio_neg_prompt,
@ -3694,6 +3738,7 @@ def edit_video(
if configs == None: configs = { "type" : get_model_record("Post Processing") } if configs == None: configs = { "type" : get_model_record("Post Processing") }
has_already_audio = False has_already_audio = False
audio_tracks = []
if MMAudio_setting == 0: if MMAudio_setting == 0:
audio_tracks = extract_audio_tracks(video_source) audio_tracks = extract_audio_tracks(video_source)
has_already_audio = len(audio_tracks) > 0 has_already_audio = len(audio_tracks) > 0
@ -3711,8 +3756,8 @@ def edit_video(
frames_count = min(frames_count, 1000) frames_count = min(frames_count, 1000)
sample = None sample = None
if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0: if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0 or film_grain_intensity > 0:
send_cmd("progress", [0, get_latest_status(state,"Upsampling")]) send_cmd("progress", [0, get_latest_status(state,"Upsampling" if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0 else "Adding Film Grain" )])
sample = get_resampled_video(video_source, 0, max_source_video_frames, fps) sample = get_resampled_video(video_source, 0, max_source_video_frames, fps)
sample = sample.float().div_(127.5).sub_(1.).permute(-1,0,1,2) sample = sample.float().div_(127.5).sub_(1.).permute(-1,0,1,2)
frames_count = sample.shape[1] frames_count = sample.shape[1]
@ -3728,6 +3773,12 @@ def edit_video(
sample = perform_spatial_upsampling(sample, spatial_upsampling ) sample = perform_spatial_upsampling(sample, spatial_upsampling )
configs["spatial_upsampling"] = spatial_upsampling configs["spatial_upsampling"] = spatial_upsampling
if film_grain_intensity > 0:
from postprocessing.film_grain import add_film_grain
sample = add_film_grain(sample, film_grain_intensity, film_grain_saturation)
configs["film_grain_intensity"] = film_grain_intensity
configs["film_grain_saturation"] = film_grain_saturation
any_mmaudio = MMAudio_setting != 0 and server_config.get("mmaudio_enabled", 0) != 0 and frames_count >=output_fps any_mmaudio = MMAudio_setting != 0 and server_config.get("mmaudio_enabled", 0) != 0 and frames_count >=output_fps
if any_mmaudio: download_mmaudio() if any_mmaudio: download_mmaudio()
@ -3834,16 +3885,19 @@ def generate_video(
image_refs, image_refs,
frames_positions, frames_positions,
video_guide, video_guide,
image_guide,
keep_frames_video_guide, keep_frames_video_guide,
denoising_strength, denoising_strength,
video_guide_outpainting, video_guide_outpainting,
video_mask, video_mask,
image_mask,
control_net_weight, control_net_weight,
control_net_weight2, control_net_weight2,
mask_expand, mask_expand,
audio_guide, audio_guide,
audio_guide2, audio_guide2,
audio_prompt_type, audio_prompt_type,
speakers_locations,
sliding_window_size, sliding_window_size,
sliding_window_overlap, sliding_window_overlap,
sliding_window_overlap_noise, sliding_window_overlap_noise,
@ -3851,6 +3905,8 @@ def generate_video(
remove_background_images_ref, remove_background_images_ref,
temporal_upsampling, temporal_upsampling,
spatial_upsampling, spatial_upsampling,
film_grain_intensity,
film_grain_saturation,
MMAudio_setting, MMAudio_setting,
MMAudio_prompt, MMAudio_prompt,
MMAudio_neg_prompt, MMAudio_neg_prompt,
@ -3871,11 +3927,17 @@ def generate_video(
model_filename, model_filename,
mode, mode,
): ):
def remove_temp_filenames(temp_filenames_list):
for temp_filename in temp_filenames_list:
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
global wan_model, offloadobj, reload_needed, save_path global wan_model, offloadobj, reload_needed, save_path
gen = get_gen_info(state) gen = get_gen_info(state)
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
if mode == "edit": if mode == "edit":
edit_video(send_cmd, state, video_source, seed, temporal_upsampling, spatial_upsampling, MMAudio_setting, MMAudio_prompt, MMAudio_neg_prompt, repeat_generation) edit_video(send_cmd, state, video_source, seed, temporal_upsampling, spatial_upsampling, film_grain_intensity, film_grain_saturation, MMAudio_setting, MMAudio_prompt, MMAudio_neg_prompt, repeat_generation)
return return
with lock: with lock:
file_list = gen["file_list"] file_list = gen["file_list"]
@ -3884,6 +3946,23 @@ def generate_video(
model_def = get_model_def(model_type) model_def = get_model_def(model_type)
is_image = model_def.get("image_outputs", False) is_image = model_def.get("image_outputs", False)
if is_image:
batch_size = video_length
video_length = 1
else:
batch_size = 1
temp_filenames_list = []
if image_guide is not None and isinstance(image_guide, Image.Image):
video_guide = convert_image_to_video(image_guide)
temp_filenames_list.append(video_guide)
image_guide = None
if image_mask is not None and isinstance(image_mask, Image.Image):
video_mask = convert_image_to_video(image_mask)
temp_filenames_list.append(video_mask)
image_mask = None
fit_canvas = server_config.get("fit_canvas", 0) fit_canvas = server_config.get("fit_canvas", 0)
@ -3926,7 +4005,6 @@ def generate_video(
trans = get_transformer_model(wan_model) trans = get_transformer_model(wan_model)
audio_sampling_rate = 16000 audio_sampling_rate = 16000
temp_filename = None
base_model_type = get_base_model_type(model_type) base_model_type = get_base_model_type(model_type)
prompts = prompt.split("\n") prompts = prompt.split("\n")
@ -4012,6 +4090,11 @@ def generate_video(
multitalk = base_model_type in ["multitalk", "vace_multitalk_14B"] multitalk = base_model_type in ["multitalk", "vace_multitalk_14B"]
flux_dev_kontext = base_model_type in ["flux_dev_kontext"] flux_dev_kontext = base_model_type in ["flux_dev_kontext"]
if "B" in audio_prompt_type or "X" in audio_prompt_type:
from wan.multitalk.multitalk import parse_speakers_locations
speakers_bboxes, error = parse_speakers_locations(speakers_locations)
else:
speakers_bboxes = None
if "L" in image_prompt_type: if "L" in image_prompt_type:
if len(file_list)>0: if len(file_list)>0:
video_source = file_list[-1] video_source = file_list[-1]
@ -4268,7 +4351,7 @@ def generate_video(
window_start_frame = guide_start_frame - (reuse_frames if window_no > 1 else source_video_overlap_frames_count) window_start_frame = guide_start_frame - (reuse_frames if window_no > 1 else source_video_overlap_frames_count)
if reuse_frames > 0: if reuse_frames > 0:
return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) ) return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) )
refresh_preview = {} refresh_preview = {"image_guide" : None, "image_mask" : None}
if fantasy: if fantasy:
window_latent_start_frame = (window_start_frame ) // latent_size window_latent_start_frame = (window_start_frame ) // latent_size
window_latent_size= (current_video_length - 1) // latent_size + 1 window_latent_size= (current_video_length - 1) // latent_size + 1
@ -4426,7 +4509,8 @@ def generate_video(
input_video= pre_video_guide if diffusion_forcing or ltxv or hunyuan_custom_edit else source_video, input_video= pre_video_guide if diffusion_forcing or ltxv or hunyuan_custom_edit else source_video,
denoising_strength=denoising_strength, denoising_strength=denoising_strength,
target_camera= target_camera, target_camera= target_camera,
frame_num=current_video_length if is_image else (current_video_length // latent_size)* latent_size + 1, frame_num= (current_video_length // latent_size)* latent_size + 1,
batch_size = batch_size,
height = height, height = height,
width = width, width = width,
fit_into_canvas = fit_canvas == 1, fit_into_canvas = fit_canvas == 1,
@ -4469,11 +4553,13 @@ def generate_video(
NAG_scale = NAG_scale, NAG_scale = NAG_scale,
NAG_tau = NAG_tau, NAG_tau = NAG_tau,
NAG_alpha = NAG_alpha, NAG_alpha = NAG_alpha,
speakers_bboxes =speakers_bboxes,
offloadobj = offloadobj, offloadobj = offloadobj,
) )
except Exception as e: except Exception as e:
if temp_filename!= None and os.path.isfile(temp_filename): if len(control_audio_tracks) > 0:
os.remove(temp_filename) cleanup_temp_audio_files(control_audio_tracks)
remove_temp_filenames(temp_filenames_list)
offloadobj.unload_all() offloadobj.unload_all()
offload.unload_loras_from_model(trans) offload.unload_loras_from_model(trans)
# if compile: # if compile:
@ -4569,7 +4655,9 @@ def generate_video(
if len(spatial_upsampling) > 0: if len(spatial_upsampling) > 0:
sample = perform_spatial_upsampling(sample, spatial_upsampling ) sample = perform_spatial_upsampling(sample, spatial_upsampling )
if film_grain_intensity> 0:
from postprocessing.film_grain import add_film_grain
sample = add_film_grain(sample, film_grain_intensity, film_grain_saturation)
if sliding_window : if sliding_window :
if frames_already_processed == None: if frames_already_processed == None:
frames_already_processed = sample frames_already_processed = sample
@ -4675,8 +4763,8 @@ def generate_video(
offload.unload_loras_from_model(trans) offload.unload_loras_from_model(trans)
if len(control_audio_tracks) > 0: if len(control_audio_tracks) > 0:
cleanup_temp_audio_files(control_audio_tracks) cleanup_temp_audio_files(control_audio_tracks)
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename) remove_temp_filenames(temp_filenames_list)
def prepare_generate_video(state): def prepare_generate_video(state):
@ -5529,7 +5617,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
if "lset_name" in inputs: if "lset_name" in inputs:
inputs.pop("lset_name") inputs.pop("lset_name")
unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask", "audio_guide", "audio_guide2"] unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "image_guide", "video_source", "video_mask", "image_mask", "audio_guide", "audio_guide2"]
for k in unsaved_params: for k in unsaved_params:
inputs.pop(k) inputs.pop(k)
if model_filename == None: model_filename = state["model_filename"] if model_filename == None: model_filename = state["model_filename"]
@ -5629,28 +5717,36 @@ def video_to_source_video(state, input_file_list, choice):
gr.Info("Selected Video was copied to Source Video input") gr.Info("Selected Video was copied to Source Video input")
return file_list[choice] return file_list[choice]
def image_to_ref_image(state, input_file_list, choice, target, target_name): def image_to_ref_image_add(state, input_file_list, choice, target, target_name):
file_list, file_settings_list = get_file_list(state, input_file_list) file_list, file_settings_list = get_file_list(state, input_file_list)
if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list): return gr.update() if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list): return gr.update()
gr.Info(f"Selected Image was copied to {target_name}") gr.Info(f"Selected Image was added to {target_name}")
if target == None: if target == None:
target =[] target =[]
target.append( file_list[choice]) target.append( file_list[choice])
return target return target
def image_to_ref_image_set(state, input_file_list, choice, target, target_name):
file_list, file_settings_list = get_file_list(state, input_file_list)
if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list): return gr.update()
gr.Info(f"Selected Image was copied to {target_name}")
return file_list[choice]
def apply_post_processing(state, input_file_list, choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation):
def apply_post_processing(state, input_file_list, choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation):
gen = get_gen_info(state) gen = get_gen_info(state)
file_list, file_settings_list = get_file_list(state, input_file_list) file_list, file_settings_list = get_file_list(state, input_file_list)
if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list) : if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list) :
return gr.update(), gr.update() return gr.update(), gr.update(), gr.update()
if not file_list[choice].endswith(".mp4"): if not file_list[choice].endswith(".mp4"):
gr.Info("Post processing is only available with Videos") gr.Info("Post processing is only available with Videos")
return gr.update(), gr.update() return gr.update(), gr.update(), gr.update()
overrides = { overrides = {
"temporal_upsampling":PP_temporal_upsampling, "temporal_upsampling":PP_temporal_upsampling,
"spatial_upsampling":PP_spatial_upsampling, "spatial_upsampling":PP_spatial_upsampling,
"film_grain_intensity": PP_film_grain_intensity,
"film_grain_saturation": PP_film_grain_saturation,
"MMAudio_setting" : PP_MMAudio_setting, "MMAudio_setting" : PP_MMAudio_setting,
"MMAudio_prompt" : PP_MMAudio_prompt, "MMAudio_prompt" : PP_MMAudio_prompt,
"MMAudio_neg_prompt": PP_MMAudio_neg_prompt, "MMAudio_neg_prompt": PP_MMAudio_neg_prompt,
@ -5682,6 +5778,14 @@ def eject_video_from_gallery(state, input_file_list, choice):
choice = min(choice, len(file_list)) choice = min(choice, len(file_list))
return gr.Gallery(value = file_list, selected_index= choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0) return gr.Gallery(value = file_list, selected_index= choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0)
def has_video_file_extension(filename):
extension = os.path.splitext(filename)[-1]
return extension in [".mp4"]
def has_image_file_extension(filename):
extension = os.path.splitext(filename)[-1]
return extension in [".jpeg", ".jpg", ".png", ".bmp", ".tiff"]
def add_videos_to_gallery(state, input_file_list, choice, files_to_load): def add_videos_to_gallery(state, input_file_list, choice, files_to_load):
gen = get_gen_info(state) gen = get_gen_info(state)
if files_to_load == None: if files_to_load == None:
@ -5693,10 +5797,15 @@ def add_videos_to_gallery(state, input_file_list, choice, files_to_load):
for file_path in files_to_load: for file_path in files_to_load:
file_settings, _ = get_settings_from_file(state, file_path, False, False, False) file_settings, _ = get_settings_from_file(state, file_path, False, False, False)
if file_settings == None: if file_settings == None:
try:
fps, width, height, frames_count = get_video_info(file_path)
except:
fps = 0 fps = 0
try:
if has_video_file_extension(file_path):
fps, width, height, frames_count = get_video_info(file_path)
elif has_image_file_extension(file_path):
width, height = Image.open(file_path).size
fps = 1
except:
pass
if fps == 0: if fps == 0:
invalid_files_count += 1 invalid_files_count += 1
continue continue
@ -5878,15 +5987,18 @@ def save_inputs(
image_refs, image_refs,
frames_positions, frames_positions,
video_guide, video_guide,
image_guide,
keep_frames_video_guide, keep_frames_video_guide,
denoising_strength, denoising_strength,
video_mask, video_mask,
image_mask,
control_net_weight, control_net_weight,
control_net_weight2, control_net_weight2,
mask_expand, mask_expand,
audio_guide, audio_guide,
audio_guide2, audio_guide2,
audio_prompt_type, audio_prompt_type,
speakers_locations,
sliding_window_size, sliding_window_size,
sliding_window_overlap, sliding_window_overlap,
sliding_window_overlap_noise, sliding_window_overlap_noise,
@ -5894,6 +6006,8 @@ def save_inputs(
remove_background_images_ref, remove_background_images_ref,
temporal_upsampling, temporal_upsampling,
spatial_upsampling, spatial_upsampling,
film_grain_intensity,
film_grain_saturation,
MMAudio_setting, MMAudio_setting,
MMAudio_prompt, MMAudio_prompt,
MMAudio_neg_prompt, MMAudio_neg_prompt,
@ -6097,7 +6211,7 @@ def refresh_audio_prompt_type_remux(state, audio_prompt_type, remux):
def refresh_audio_prompt_type_sources(state, audio_prompt_type, audio_prompt_type_sources): def refresh_audio_prompt_type_sources(state, audio_prompt_type, audio_prompt_type_sources):
audio_prompt_type = del_in_sequence(audio_prompt_type, "XCPAB") audio_prompt_type = del_in_sequence(audio_prompt_type, "XCPAB")
audio_prompt_type = add_to_sequence(audio_prompt_type, audio_prompt_type_sources) audio_prompt_type = add_to_sequence(audio_prompt_type, audio_prompt_type_sources)
return audio_prompt_type, gr.update(visible = "A" in audio_prompt_type), gr.update(visible = "B" in audio_prompt_type) return audio_prompt_type, gr.update(visible = "A" in audio_prompt_type), gr.update(visible = "B" in audio_prompt_type), gr.update(visible = ("B" in audio_prompt_type or "X" in audio_prompt_type))
def refresh_image_prompt_type(state, image_prompt_type): def refresh_image_prompt_type(state, image_prompt_type):
any_video_source = len(filter_letters(image_prompt_type, "VLG"))>0 any_video_source = len(filter_letters(image_prompt_type, "VLG"))>0
@@ -6110,19 +6224,26 @@ def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_
     vace= test_vace_module(state["model_type"])
     return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "K" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace )
 
-def refresh_video_prompt_type_video_mask(video_prompt_type, video_prompt_type_video_mask):
+def refresh_video_prompt_type_video_mask(state, video_prompt_type, video_prompt_type_video_mask):
     video_prompt_type = del_in_sequence(video_prompt_type, "XYZWNA")
     video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_mask)
     visible= "A" in video_prompt_type
-    return video_prompt_type, gr.update(visible= visible), gr.update(visible= visible )
+    model_type = state["model_type"]
+    model_def = get_model_def(model_type)
+    image_outputs = model_def.get("image_outputs", False)
+    return video_prompt_type, gr.update(visible= visible and not image_outputs), gr.update(visible= visible and image_outputs), gr.update(visible= visible )
 
 def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide):
     video_prompt_type = del_in_sequence(video_prompt_type, "PDSLCMGUV")
     video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
     visible = "V" in video_prompt_type
     mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type
+    model_type = state["model_type"]
+    model_def = get_model_def(model_type)
+    image_outputs = model_def.get("image_outputs", False)
     vace= test_vace_module(state["model_type"])
-    return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type), gr.update(visible= mask_visible), gr.update(visible= mask_visible)
+    return video_prompt_type, gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = visible and not image_outputs), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type), gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible)
 
 # def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide):
 #     video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0]
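
Both refresh helpers follow the same Gradio pattern: the handler returns one `gr.update(...)` per component listed in the event's `outputs=[...]`, in the same order, which is why adding the image_guide / image_mask widgets also means growing these return tuples. A minimal sketch of the pattern, with assumed component names rather than the repo's:

```python
import gradio as gr

def refresh_guide_visibility(mode: str, image_outputs: bool):
    # One gr.update per wired output component, in order.
    show_guide = "V" in mode
    return (
        gr.update(visible=show_guide and not image_outputs),  # video widget
        gr.update(visible=show_guide and image_outputs),      # image widget
    )

# Wiring sketch: dropdown.input(fn=refresh_guide_visibility,
#                               inputs=[mode_box, image_outputs_box],
#                               outputs=[video_widget, image_widget])
```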
@@ -6223,6 +6344,8 @@ def get_resolution_choices(current_resolution_choice):
     if resolution_choices == None:
         resolution_choices=[
             # 1080p
+            ("1920x1088 (16:9, 1080p)", "1920x1088"),
+            ("1088x1920 (9:16, 1080p)", "1088x1920"),
             ("1920x832 (21:9, 1080p)", "1920x832"),
             ("832x1920 (9:21, 1080p)", "832x1920"),
             # 720p
@@ -6390,7 +6513,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
         if vace:
             image_prompt_type_value= ui_defaults.get("image_prompt_type","")
             image_prompt_type_value = "" if image_prompt_type_value == "S" else image_prompt_type_value
-            image_prompt_type = gr.Radio( [("New Video", ""),("Continue Video File", "V"),("Continue Last Video", "L")], value =image_prompt_type_value, label="Source Video", show_label= False, visible= True , scale= 3)
+            image_prompt_type = gr.Radio( [("New Video", ""),("Continue Video File", "V"),("Continue Last Video", "L")], value =image_prompt_type_value, label="Source Video", show_label= False, visible= not image_outputs , scale= 3)
 
             image_start = gr.Gallery(visible = False)
             image_end = gr.Gallery(visible = False)
@ -6480,12 +6603,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_prompt_type_value= ui_defaults.get("video_prompt_type","") video_prompt_type_value= ui_defaults.get("video_prompt_type","")
video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False) video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
any_control_video = True any_control_video = True
any_control_image = image_outputs
with gr.Row(): with gr.Row():
if t2v: if t2v:
video_prompt_type_video_guide = gr.Dropdown( video_prompt_type_video_guide = gr.Dropdown(
choices=[ choices=[
("Use Text Prompt Only", ""), ("Use Text Prompt Only", ""),
("Video to Video guided by Text Prompt", "GUV"), ("Image to Image guided by Text Prompt" if image_outputs else "Video to Video guided by Text Prompt", "GUV"),
], ],
value=filter_letters(video_prompt_type_value, "GUV"), value=filter_letters(video_prompt_type_value, "GUV"),
label="Video to Video", scale = 2, show_label= False, visible= True label="Video to Video", scale = 2, show_label= False, visible= True
@@ -6493,8 +6617,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
             elif vace:
                 video_prompt_type_video_guide = gr.Dropdown(
                     choices=[
-                        ("No Control Video", ""),
-                        ("Keep Control Video Unchanged", "UV"),
+                        ("No Control Image" if image_outputs else "No Control Video", ""),
+                        ("Keep Control Image Unchanged" if image_outputs else "Keep Control Video Unchanged", "UV"),
                         ("Transfer Human Motion", "PV"),
                         ("Transfer Depth", "DV"),
                         ("Transfer Shapes", "SV"),
@@ -6510,19 +6634,20 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                         ("Transfer Shapes & Flow", "SLV"),
                     ],
                     value=filter_letters(video_prompt_type_value, "PDSLCMGUV"),
-                    label="Control Video Process", scale = 2, visible= True, show_label= True,
+                    label="Control Image Process" if image_outputs else "Control Video Process", scale = 2, visible= True, show_label= True,
                 )
             elif hunyuan_video_custom_edit:
                 video_prompt_type_video_guide = gr.Dropdown(
                     choices=[
-                        ("Inpaint Control Video", "MV"),
+                        ("Inpaint Control Image" if image_outputs else "Inpaint Control Video", "MV"),
                         ("Transfer Human Motion", "PMV"),
                     ],
                     value=filter_letters(video_prompt_type_value, "PDSLCMUV"),
-                    label="Video to Video", scale = 3, visible= True, show_label= True,
+                    label="Image to Image" if image_outputs else "Video to Video", scale = 3, visible= True, show_label= True,
                 )
             else:
                 any_control_video = False
+                any_control_image = False
                 video_prompt_type_video_guide = gr.Dropdown(visible= False)
 
             # video_prompt_video_guide_trigger = gr.Text(visible=False, value="")
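
These dropdowns encode their selection as single-letter flags (e.g. "PV", "DV", "MV") that are merged into `video_prompt_type` through `del_in_sequence` / `add_to_sequence` and read back with `filter_letters`. The repo's implementations are not shown in this hunk; a sketch of the assumed semantics:

```python
# Assumed behaviour of the letter-flag helpers; the actual wgp.py versions may differ.
def del_in_sequence(seq: str, letters: str) -> str:
    """Remove every flag in `letters` from the flag string `seq`."""
    return "".join(c for c in seq if c not in letters)

def add_to_sequence(seq: str, letters: str) -> str:
    """Append the flags in `letters` that `seq` does not already contain."""
    return seq + "".join(c for c in letters if c not in seq)

def filter_letters(seq: str, letters: str) -> str:
    """Keep only the flags of `seq` that belong to `letters`."""
    return "".join(c for c in seq if c in letters)
```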
@@ -6578,16 +6703,17 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                 visible = False,
                 label="Start / Reference Images", scale = 2
             )
-            video_guide = gr.Video(label= "Control Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None),)
+            image_guide = gr.Image(label= "Control Image", type ="pil", visible= image_outputs and "V" in video_prompt_type_value, value= ui_defaults.get("image_guide", None))
+            video_guide = gr.Video(label= "Control Video", visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None))
             denoising_strength = gr.Slider(0, 1, value= ui_defaults.get("denoising_strength" ,0.5), step=0.01, label="Denoising Strength (the Lower the Closer to the Control Video)", visible = "G" in video_prompt_type_value, show_reset_button= False)
-            keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= "V" in video_prompt_type_value, scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last
+            keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= (not image_outputs) and "V" in video_prompt_type_value, scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last
 
             with gr.Column(visible= ("V" in video_prompt_type_value or "K" in video_prompt_type_value or "F" in video_prompt_type_value) and vace) as video_guide_outpainting_col:
                 video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#")
                 video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False)
                 with gr.Group():
-                    video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Background or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
+                    video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
                     with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row:
                         video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value
                         video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")]
@@ -6595,8 +6721,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                         video_guide_outpainting_bottom = gr.Slider(0, 100, value= video_guide_outpainting_list[1], step=5, label="Bottom %", show_reset_button= False)
                         video_guide_outpainting_left = gr.Slider(0, 100, value= video_guide_outpainting_list[2], step=5, label="Left %", show_reset_button= False)
                         video_guide_outpainting_right = gr.Slider(0, 100, value= video_guide_outpainting_list[3], step=5, label="Right %", show_reset_button= False)
+            any_image_mask = image_outputs and vace
+            image_mask = gr.Image(label= "Image Mask Area (for Inpainting, white = Control Area, black = Unchanged)", type ="pil", visible= image_outputs and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("image_mask", None))
-            video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("video_mask", None))
+            video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= (not image_outputs) and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("video_mask", None))
             mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value )
 
         any_reference_image = vace or phantom or hunyuan_video_custom or hunyuan_video_avatar
@@ -6630,7 +6757,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                         ("Two speakers, Speakers Audio sources are assumed to be played in Parallel", "PAB"),
                     ],
                     value= filter_letters(audio_prompt_type_value, "XCPAB"),
-                    label="Voices: if there are multiple People the first is assumed to be to the Left and the second one to the Right", scale = 3, visible = multitalk
+                    label="Voices", scale = 3, visible = multitalk
                 )
             else:
                 audio_prompt_type_sources = gr.Dropdown( choices= [""], value = "", visible=False)
@@ -6638,6 +6765,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
         with gr.Row(visible = any_audio_voices_support) as audio_guide_row:
             audio_guide = gr.Audio(value= ui_defaults.get("audio_guide", None), type="filepath", label="Voice to follow", show_download_button= True, visible= any_audio_voices_support and "A" in audio_prompt_type_value )
             audio_guide2 = gr.Audio(value= ui_defaults.get("audio_guide2", None), type="filepath", label="Voice to follow #2", show_download_button= True, visible= any_audio_voices_support and "B" in audio_prompt_type_value )
+        with gr.Row(visible = any_audio_voices_support and ("B" in audio_prompt_type_value or "X" in audio_prompt_type_value) ) as speakers_locations_row:
+            speakers_locations = gr.Text( ui_defaults.get("speakers_locations", "0:45 55:100"), label="Speakers Locations separated by a Space. Each Location = Left:Right or a BBox Left:Top:Right:Bottom", visible= True)
 
         advanced_prompt = advanced_ui
         prompt_vars=[]
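
The new `speakers_locations` field uses the format its label describes: space-separated entries, each either "left:right" or a full "left:top:right:bottom" bounding box expressed as percentages. An illustrative parser for that format, not the commit's code:

```python
# Sketch only: parse the speakers_locations text field into per-speaker regions.
def parse_speakers_locations(text: str):
    locations = []
    for token in text.split():
        values = [int(v) for v in token.split(":")]
        if len(values) == 2:        # horizontal band, in percent
            locations.append({"left": values[0], "right": values[1]})
        elif len(values) == 4:      # full bounding box, in percent
            left, top, right, bottom = values
            locations.append({"left": left, "top": top, "right": right, "bottom": bottom})
        else:
            raise ValueError(f"Unexpected location token: {token}")
    return locations

# parse_speakers_locations("0:45 55:100") -> two horizontal bands, one per speaker
```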
@@ -6694,7 +6823,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                 )
             with gr.Row():
                 if image_outputs:
-                    video_length = gr.Slider(1, 16, value=ui_defaults.get("video_length", 1), step=1, label="Number of Images to Generate", visible = flux_dev_kontext)
+                    video_length = gr.Slider(1, 16, value=ui_defaults.get("video_length", 1), step=1, label="Number of Images to Generate", visible = True)
                 elif recammaster:
                     video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s), locked", interactive= False, visible = True)
                 else:
@@ -6702,7 +6831,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                     video_length = gr.Slider(min_frames, 737 if test_any_sliding_window(base_model_type) else 337, value=ui_defaults.get(
                         "video_length", 81 if get_model_family(base_model_type)=="wan" else 97),
-                        step=frames_step, label=f"Number of frames ({fps} = 1s)", interactive= True)
+                        step=frames_step, label=f"Number of frames ({fps} = 1s)", visible = True, interactive= True)
 
             with gr.Row(visible = not lock_inference_steps) as inference_steps_row:
                 num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps", visible = True)
@@ -6790,12 +6919,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                     )
                     skip_steps_start_step_perc = gr.Slider(0, 100, value=ui_defaults.get("skip_steps_start_step_perc",0), step=1, label="Skip Steps starting moment in % of generation")
 
-                with gr.Tab("Upsampling"):
+                with gr.Tab("Post Processing"):
                     with gr.Column():
                         gr.Markdown("<B>Upsampling - postprocessing that may improve fluidity and the size of the video</B>")
 
-                        def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , element_class= None, max_height= None):
+                        def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grain_intensity, film_grain_saturation, element_class= None, max_height= None):
                            temporal_upsampling = gr.Dropdown(
                                 choices=[
                                     ("Disabled", ""),
@@ -6803,7 +6932,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                                     ("Rife x4 frames/s", "rife4"),
                                 ],
                                 value=temporal_upsampling,
-                                visible=not image_outputs,
+                                visible=True,
                                 scale = 1,
                                 label="Temporal Upsampling",
                                 elem_classes= element_class
@ -6822,8 +6951,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
elem_classes= element_class elem_classes= element_class
# max_height = max_height # max_height = max_height
) )
return temporal_upsampling, spatial_upsampling
temporal_upsampling, spatial_upsampling = gen_upsampling_dropdowns(ui_defaults.get("temporal_upsampling", ""), ui_defaults.get("spatial_upsampling", "")) with gr.Row():
film_grain_intensity = gr.Slider(0, 1, value=film_grain_intensity, step=0.01, label="Film Grain Intensity (0 = disabled)")
film_grain_saturation = gr.Slider(0.0, 1, value=film_grain_saturation, step=0.01, label="Film Grain Saturation")
return temporal_upsampling, spatial_upsampling, film_grain_intensity, film_grain_saturation
temporal_upsampling, spatial_upsampling, film_grain_intensity, film_grain_saturation = gen_upsampling_dropdowns(ui_defaults.get("temporal_upsampling", ""), ui_defaults.get("spatial_upsampling", ""), ui_defaults.get("film_grain_intensity", 0), ui_defaults.get("film_grain_saturation", 0.5))
with gr.Tab("MMAudio", visible = server_config.get("mmaudio_enabled", 0) != 0 and not any_audio_track(base_model_type) and not image_outputs) as mmaudio_tab: with gr.Tab("MMAudio", visible = server_config.get("mmaudio_enabled", 0) != 0 and not any_audio_track(base_model_type) and not image_outputs) as mmaudio_tab:
with gr.Column(): with gr.Column():
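
The two new sliders only expose parameters; the grain itself is applied later in the post-processing pipeline, which is outside this hunk. A rough numpy sketch of what intensity (noise amplitude) and saturation (how colored the grain is) typically control:

```python
# Not the commit's implementation: a minimal film-grain sketch for one frame.
import numpy as np

def add_film_grain(frame: np.ndarray, intensity: float, saturation: float, rng=None) -> np.ndarray:
    """frame: HxWx3 uint8 image; intensity and saturation in [0, 1]."""
    if intensity <= 0:
        return frame
    rng = np.random.default_rng() if rng is None else rng
    h, w, _ = frame.shape
    mono = rng.standard_normal((h, w, 1))    # shared luminance grain
    color = rng.standard_normal((h, w, 3))   # independent per-channel grain
    grain = (1.0 - saturation) * mono + saturation * color
    noisy = frame.astype(np.float32) + grain * intensity * 255.0 * 0.12  # 0.12: arbitrary strength scale
    return np.clip(noisy, 0, 255).astype(np.uint8)
```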
@@ -7024,18 +7158,21 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                     with gr.Row(**default_visibility) as image_buttons_row:
                         video_info_to_start_image_btn = gr.Button("To Start Image", size ="sm", visible = any_start_image )
                         video_info_to_end_image_btn = gr.Button("To End Image", size ="sm", visible = any_end_image)
+                        video_info_to_image_guide_btn = gr.Button("To Control Image", size ="sm", visible = any_control_image )
+                        video_info_to_image_mask_btn = gr.Button("To Mask Image", size ="sm", visible = any_image_mask)
                        video_info_to_reference_image_btn = gr.Button("To Reference Image", size ="sm", visible = any_reference_image)
                         video_info_eject_image_btn = gr.Button("Eject Image", size ="sm")
                 with gr.Tab("Post Processing", id= "post_processing", visible = True) as video_postprocessing_tab:
                     with gr.Group(elem_classes= "postprocess"):
                         with gr.Column():
-                            PP_temporal_upsampling, PP_spatial_upsampling = gen_upsampling_dropdowns("", "", element_class ="postprocess")
-                        with gr.Column() as PP_MMAudio_col:
+                            PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation = gen_upsampling_dropdowns("", "", 0, 0.5, element_class ="postprocess")
+                        with gr.Column(visible = server_config.get("mmaudio_enabled", 0) == 1) as PP_MMAudio_col:
                             PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, _ = gen_mmaudio_dropdowns( 0, "" , "", None, element_class ="postprocess" )
                             PP_MMAudio_seed = gr.Slider(-1, 999999999, value=-1, step=1, label="Seed (-1 for random)")
                         PP_repeat_generation = gr.Slider(1, 25.0, value=1, step=1, label="Number of Sample Videos to Generate")
+                        with gr.Row():
                             video_info_postprocessing_btn = gr.Button("Apply Postprocessing", size ="sm", visible=True)
+                            video_info_eject_video2_btn = gr.Button("Eject Video", size ="sm", visible=True)
                 with gr.Tab("Add Videos / Images", id= "video_add"):
                     files_to_load = gr.Files(label= "Files to Load in Gallery", height=120)
                     with gr.Row():
@@ -7092,8 +7229,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                 video_prompt_type_video_guide, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux_row,
                 video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right,
                 video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row,
-                video_buttons_row, image_buttons_row, video_postprocessing_tab, video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn,
-                NAG_col] # presets_column,
+                video_buttons_row, image_buttons_row, video_postprocessing_tab, video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn,
+                NAG_col, speakers_locations_row] # presets_column,
 
     if update_form:
         locals_dict = locals()
         gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs
@@ -7104,12 +7241,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
         last_choice = gr.Number(value =-1, interactive= False, visible= False)
 
         audio_prompt_type_remux.change(fn=refresh_audio_prompt_type_remux, inputs=[state, audio_prompt_type, audio_prompt_type_remux], outputs=[audio_prompt_type])
-        audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2])
+        audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row])
         image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] )
         # video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
         video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, frames_positions, video_guide_outpainting_col])
-        video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
-        video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [video_prompt_type, video_prompt_type_video_mask], outputs = [video_prompt_type, video_mask, mask_expand])
+        video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand])
+        video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask], outputs = [video_prompt_type, video_mask, image_mask, mask_expand])
         multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=multi_prompts_gen_type, outputs=[prompt, wizard_prompt])
         video_guide_outpainting_top.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_top, gr.State(0)], outputs = [video_guide_outpainting], trigger_mode="multiple" )
         video_guide_outpainting_bottom.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_bottom,gr.State(1)], outputs = [video_guide_outpainting], trigger_mode="multiple" )
@@ -7119,7 +7256,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
         show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name]).then(
             fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
         queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container])
-        output.select(select_video, [state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab] )
+        output.select(select_video, [state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab], trigger_mode="multiple")
         preview_trigger.change(refresh_preview, inputs= [state], outputs= [preview])
 
         def refresh_status_async(state, progress=gr.Progress()):
@@ -7175,13 +7312,15 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
         ).then( fn=use_video_settings, inputs =[state, output, last_choice] , outputs= [model_choice, refresh_form_trigger])
 
         video_info_add_videos_btn.click(fn=add_videos_to_gallery, inputs =[state, output, last_choice, files_to_load], outputs = [output, files_to_load, video_info_tabs] )
-        gr.on(triggers=[video_info_eject_video_btn.click, video_info_eject_image_btn.click], fn=eject_video_from_gallery, inputs =[state, output, last_choice], outputs = [output, video_info, video_buttons_row] )
+        gr.on(triggers=[video_info_eject_video_btn.click, video_info_eject_video2_btn.click, video_info_eject_image_btn.click], fn=eject_video_from_gallery, inputs =[state, output, last_choice], outputs = [output, video_info, video_buttons_row] )
         video_info_to_control_video_btn.click(fn=video_to_control_video, inputs =[state, output, last_choice], outputs = [video_guide] )
         video_info_to_video_source_btn.click(fn=video_to_source_video, inputs =[state, output, last_choice], outputs = [video_source] )
-        video_info_to_start_image_btn.click(fn=image_to_ref_image, inputs =[state, output, last_choice, image_start, gr.State("Start Image")], outputs = [image_start] )
-        video_info_to_end_image_btn.click(fn=image_to_ref_image, inputs =[state, output, last_choice, image_end, gr.State("End Image")], outputs = [image_end] )
-        video_info_to_reference_image_btn.click(fn=image_to_ref_image, inputs =[state, output, last_choice, image_refs, gr.State("Ref Image")], outputs = [image_refs] )
-        video_info_postprocessing_btn.click(fn=apply_post_processing, inputs =[state, output, last_choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation], outputs = [mode, generate_trigger, add_to_queue_trigger ] )
+        video_info_to_start_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_start, gr.State("Start Image")], outputs = [image_start] )
+        video_info_to_end_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_end, gr.State("End Image")], outputs = [image_end] )
+        video_info_to_image_guide_btn.click(fn=image_to_ref_image_set, inputs =[state, output, last_choice, image_guide, gr.State("Control Image")], outputs = [image_guide] )
+        video_info_to_image_mask_btn.click(fn=image_to_ref_image_set, inputs =[state, output, last_choice, image_mask, gr.State("Image Mask")], outputs = [image_mask] )
+        video_info_to_reference_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_refs, gr.State("Ref Image")], outputs = [image_refs] )
+        video_info_postprocessing_btn.click(fn=apply_post_processing, inputs =[state, output, last_choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation], outputs = [mode, generate_trigger, add_to_queue_trigger ] )
 
         save_lset_btn.click(validate_save_lset, inputs=[state, lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
         delete_lset_btn.click(validate_delete_lset, inputs=[state, lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
         confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
@@ -7405,7 +7544,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
         )
 
     return ( state, loras_choices, lset_name, state,
-             video_guide, video_mask, image_refs, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col
+             video_guide, image_guide, video_mask, image_mask, image_refs, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col
            )
@@ -8220,12 +8359,12 @@ def create_ui():
             header = gr.Markdown(generate_header(transformer_type, compile, attention_mode), visible= True)
             with gr.Row():
                 ( state, loras_choices, lset_name, state,
-                  video_guide, video_mask, image_refs, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col
+                  video_guide, image_guide, video_mask, image_mask, image_refs, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col
                 ) = generate_video_tab(model_choice=model_choice, header=header, main = main)
             with gr.Tab("Guides", id="info") as info_tab:
                 generate_info_tab()
             with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator:
-                matanyone_app.display(main_tabs, tab_state, model_choice, video_guide, video_mask, image_refs)
+                matanyone_app.display(main_tabs, tab_state, model_choice, video_guide, image_guide, video_mask, image_mask, image_refs)
             if not args.lock_config:
                 with gr.Tab("Downloads", id="downloads") as downloads_tab:
                     generate_download_tab(lset_name, loras_choices, state)