mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
Flux Kontext and more
This commit is contained in:
parent
37f41804a6
commit
64c59c15d9
20
README.md
20
README.md
@ -20,6 +20,26 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
|
||||
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
|
||||
|
||||
## 🔥 Latest Updates
|
||||
### July 15 2025: WanGP v7.0 is an AI Powered Photoshop
|
||||
This release turns the Wan models into Image Generators. This goes way more than allowing to generate a video made of single frame :
|
||||
- Multiple Images generated at the same time so that you can choose the one you like best.It is Highly VRAM optimized so that you can generate for instance 4 720p Images at the same time with less than 10 GB
|
||||
- With the *image2image* the original text2video WanGP becomes an image upsampler / restorer
|
||||
- *Vace image2image* comes out of the box with image outpainting, person / object replacement, ...
|
||||
- You can use in one click a newly Image generated as Start Image or Reference Image for a Video generation
|
||||
|
||||
And to complete the full suite of AI Image Generators, Ladies and Gentlemen please welcome for the first time in WanGP : **Flux Kontext**.\
|
||||
As a reminder Flux Kontext is an image editor : give it an image and a prompt and it will do the change for you.\
|
||||
This highly optimized version of Flux Kontext will make you feel that you have been cheated all this time as WanGP Flux Kontext requires only 8 GB of VRAM to generate 4 images at the same time with no need for quantization.
|
||||
|
||||
WanGP v7 comes with *Image2image* vanilla and *Vace FusinoniX*. However you can build your own finetune where you will combine a text2video or Vace model with any combination of Loras.
|
||||
|
||||
Also in the news:
|
||||
- You can now enter the *Bbox* for each speaker in *Multitalk* to precisely locate who is speaking. And to save some headaches the *Image Mask generator* will give you the *Bbox* coordinates of an area you have selected.
|
||||
- *Film Grain* post processing to add a vintage look at your video
|
||||
- *First Last Frame to Video* model should work much better now as I have discovered rencently its implementation was not complete
|
||||
- More power for the finetuners, you can now embed Loras directly in the finetune definition. You can also override the default models (titles, visibility, ...) with your own finetunes. Check the doc that has been updated.
|
||||
|
||||
|
||||
### July 10 2025: WanGP v6.7, is NAG a game changer ? you tell me
|
||||
Maybe you knew that already but most *Loras accelerators* we use today (Causvid, FusioniX) don't use *Guidance* at all (that it is *CFG* is set to 1). This helps to get much faster generations but the downside is that *Negative Prompts* are completely ignored (including the default ones set by the models). **NAG** (https://github.com/ChenDarYen/Normalized-Attention-Guidance) aims to solve that by injecting the *Negative Prompt* during the *attention* processing phase.
|
||||
|
||||
|
||||
@ -2,18 +2,15 @@
|
||||
"model": {
|
||||
"name": "Flux Dev Kontext 12B",
|
||||
"architecture": "flux_dev_kontext",
|
||||
"description": "FLUX.1 Kontext is a 12 billion parameter rectified flow transformer capable of editing images based on instructions stored in the Prompt. Please be aware that the output resolution is modified by Flux Kontext and may not be what you requested.",
|
||||
"description": "FLUX.1 Kontext is a 12 billion parameter rectified flow transformer capable of editing images based on instructions stored in the Prompt. Please be aware that Flux Kontext is picky on the resolution of the input image the output dimensions may not match the dimensions of the input image.",
|
||||
"URLs": [
|
||||
"c:/temp/kontext/flux1_kontext_dev_bf16.safetensors",
|
||||
"c:/temp/kontext/flux1_kontext_dev_quanto_bf16_int8.safetensors"
|
||||
],
|
||||
"URLs2": [
|
||||
"https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors",
|
||||
"https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors"
|
||||
]
|
||||
},
|
||||
"prompt": "add a hat",
|
||||
"resolution": "1280x720",
|
||||
"video_length": "1"
|
||||
"video_length": 1
|
||||
}
|
||||
|
||||
|
||||
13
defaults/t2i.json
Normal file
13
defaults/t2i.json
Normal file
@ -0,0 +1,13 @@
|
||||
{
|
||||
"model": {
|
||||
"name": "Wan2.1 text2image 14B",
|
||||
"architecture": "t2v",
|
||||
"description": "The original Wan Text 2 Video model configured to generate an image instead of a video.",
|
||||
"image_outputs": true,
|
||||
"URLs": "t2v"
|
||||
},
|
||||
"video_length": 1,
|
||||
"resolution": "1280x720"
|
||||
}
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
"seed": -1,
|
||||
"num_inference_steps": 10,
|
||||
"guidance_scale": 1,
|
||||
"flow_shift": 5,
|
||||
"flow_shift": 2,
|
||||
"embedded_guidance_scale": 6,
|
||||
"repeat_generation": 1,
|
||||
"multi_images_gen_type": 0,
|
||||
|
||||
16
defaults/vace_14B_fusionix_t2i.json
Normal file
16
defaults/vace_14B_fusionix_t2i.json
Normal file
@ -0,0 +1,16 @@
|
||||
{
|
||||
"model": {
|
||||
"name": "Vace FusioniX image2image 14B",
|
||||
"architecture": "vace_14B",
|
||||
"modules": [
|
||||
"vace_14B"
|
||||
],
|
||||
"image_outputs": true,
|
||||
"description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
|
||||
"URLs": "t2v_fusionix"
|
||||
},
|
||||
"resolution": "1280x720",
|
||||
"guidance_scale": 1,
|
||||
"num_inference_steps": 10,
|
||||
"video_length": 1
|
||||
}
|
||||
@ -2,22 +2,30 @@
|
||||
|
||||
A Finetuned model is model that shares the same architecture of one specific model but has derived weights from this model. Some finetuned models have been created by combining multiple finetuned models.
|
||||
|
||||
As there are potentially an infinite number of finetunes, specific finetuned models are not known by default by WanGP, however you can create a finetuned model definition that will tell WanGP about the existence of this finetuned model and WanGP will do as usual all the work for you: autodownload the model and build the user interface.
|
||||
As there are potentially an infinite number of finetunes, specific finetuned models are not known by default by WanGP. However you can create a finetuned model definition that will tell WanGP about the existence of this finetuned model and WanGP will do as usual all the work for you: autodownload the model and build the user interface.
|
||||
|
||||
WanGP finetune system can be also used to tweak default models : for instance you can add on top of an existing model some loras that will be always applied transparently.
|
||||
|
||||
Finetune models definitions are light json files that can be easily shared. You can find some of them on the WanGP *discord* server https://discord.gg/g7efUW9jGV
|
||||
|
||||
All the finetunes definitions files should be stored in the *finetunes/* subfolder.
|
||||
|
||||
Finetuned models have been tested so far with Wan2.1 text2video, Wan2.1 image2video, Hunyuan Video text2video. There isn't currently any support for LTX Video finetunes.
|
||||
|
||||
## Create a new Finetune Model Definition
|
||||
All the finetune models definitions are json files stored in the **finetunes** sub folder. All the corresponding finetune model weights will be stored in the *ckpts* subfolder and will sit next to the base models.
|
||||
|
||||
WanGP comes with a few prebuilt finetune models that you can use as starting points and to get an idea of the structure of the definition file.
|
||||
|
||||
## Create a new Finetune Model Definition
|
||||
All the finetune models definitions are json files stored in the **finetunes/** sub folder. All the corresponding finetune model weights when they are downloaded will be stored in the *ckpts/* subfolder and will sit next to the base models.
|
||||
|
||||
All the models used by WanGP are also described using the finetunes json format and can be found in the **defaults/** subfolder. Please don’t modify any file in the **defaults/** folder.
|
||||
|
||||
However you can use these files as starting points for new definition files and to get an idea of the structure of a definition file. If you want to change how a base model is handled (title, default settings, path to model weights, …) you may override any property of the default finetunes definition file by creating a new file in the finetunes folder with the same name. Everything will happen as if the two models will be merged property by property with a higher priority given to the finetunes model definition.
|
||||
|
||||
A definition is built from a *settings file* that can contains all the default parameters for a video generation. On top of this file a subtree named **model** contains all the information regarding the finetune (URLs to download model, corresponding base model id, ...).
|
||||
|
||||
You can obtain a settings file in several ways:
|
||||
- In the subfolder **settings**, get the json file that corresponds to the base model of your finetune (see the next section for the list of ids of base models)
|
||||
- From the user interface, go to the base model and click **export settings**
|
||||
- From the user interface, select the base model for which you want to create a finetune and click **export settings**
|
||||
|
||||
Here are steps:
|
||||
1) Create a *settings file*
|
||||
@ -26,39 +34,54 @@ Here are steps:
|
||||
4) Restart WanGP
|
||||
|
||||
## Architecture Models Ids
|
||||
A finetune is derived from a base model and will inherit all the user interface and corresponding model capabilities, here are Architecture Ids:
|
||||
- *t2v*: Wan 2.1 Video text 2
|
||||
- *i2v*: Wan 2.1 Video image 2 480p
|
||||
- *i2v_720p*: Wan 2.1 Video image 2 720p
|
||||
A finetune is derived from a base model and will inherit all the user interface and corresponding model capabilities, here are some Architecture Ids:
|
||||
- *t2v*: Wan 2.1 Video text 2 video
|
||||
- *i2v*: Wan 2.1 Video image 2 video 480p and 720p
|
||||
- *vace_14B*: Wan 2.1 Vace 14B
|
||||
- *hunyuan*: Hunyuan Video text 2 video
|
||||
- *hunyuan_i2v*: Hunyuan Video image 2 video
|
||||
|
||||
Any file name in the defaults subfolder (without the json extension) corresponds to an architecture id.
|
||||
|
||||
Please note that weights of some architectures correspond to a combination of weight of a one architecture which are completed by the weights of one more or modules.
|
||||
|
||||
A module is a set a weights that are insufficient to be model by itself but that can be added to an existing model to extend its capabilities.
|
||||
|
||||
For instance if one adds a module *vace_14B* on top of a model with architecture *t2v* one gets get a model with the *vace_14B* architecture. Here *vace_14B* stands for both an architecture name and a module name. The module system allows you to reuse shared weights between models.
|
||||
|
||||
|
||||
## The Model Subtree
|
||||
- *name* : name of the finetune used to select
|
||||
- *architecture* : architecture Id of the base model of the finetune (see previous section)
|
||||
- *description*: description of the finetune that will appear at the top
|
||||
- *URLs*: URLs of all the finetune versions (quantized / non quantized). WanGP will pick the version that is the closest to the user preferences. You will need to follow a naming convention to help WanGP identify the content of each version (see next section). Right now WanGP supports only 8 bits quantized model that have been quantized using **quanto**. WanGP offers a command switch to build easily such a quantized model (see below). *URLs* can contain also paths to local file to allow testing.
|
||||
- *modules*: this a list of modules to be combined with the models referenced by the URLs. A module is a model extension that is merged with a model to expand its capabilities. So far the only module supported is Vace 14B (its id is *vace_14B*). For instance the full Vace model is the fusion of a Wan text 2 video and the Vace module.
|
||||
- *modules*: this a list of modules to be combined with the models referenced by the URLs. A module is a model extension that is merged with a model to expand its capabilities. Supported models so far are : *vace_14B* and *multitalk*. For instance the full Vace model is the fusion of a Wan text 2 video and the Vace module.
|
||||
- *preload_URLs* : URLs of files to download no matter what (used to load quantization maps for instance)
|
||||
-*loras* : URLs of Loras that will applied before any other Lora specified by the user. These loras will be quite often Loras accelerator. For instance if you specified here the FusioniX Lora you will be able to reduce the number of generation steps to -*loras_multipliers* : a list of float numbers that defines the weight of each Lora mentioned above.
|
||||
- *auto_quantize*: if set to True and no quantized model URL is provided, WanGP will perform on the fly quantization if the user expects a quantized model
|
||||
-*visible* : by default assumed to be true. If set to false the model will no longer be visible. This can be useful if you create a finetune to override a default model and hide it.
|
||||
-*image_outputs* : turn any model that generates a video into a model that generates images. In fact it will adapt the user interface for image generation and ask the model to generate a video with a single frame.
|
||||
|
||||
In order to favor reusability the properties of *URLs*, *modules*, *loras* and *preload_URLs* can contain instead of a list of URLs a single text which corresponds to the id of a finetune or default model to reuse.
|
||||
|
||||
For example let’s say you have defined a *t2v_fusionix.json* file which contains the URLs to download the finetune. In the *vace_fusionix.json* you can write « URLs » : « fusionix » to reuse automatically the URLS already defined in the correspond file.
|
||||
|
||||
Example of **model** subtree
|
||||
```
|
||||
"model":
|
||||
{
|
||||
"name": "Wan text2video FusioniX 14B",
|
||||
"architecture" : "t2v",
|
||||
"description": "A powerful merged text-to-video model based on the original WAN 2.1 T2V model, enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail. multiple open-source models and LoRAs to boost temporal quality, expressiveness, and motion realism.",
|
||||
"URLs": [
|
||||
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors",
|
||||
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors",
|
||||
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors"
|
||||
],
|
||||
"model":
|
||||
{
|
||||
"name": "Wan text2video FusioniX 14B",
|
||||
"architecture" : "t2v",
|
||||
"description": "A powerful merged text-to-video model based on the original WAN 2.1 T2V model, enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail. multiple open-source models and LoRAs to boost temporal quality, expressiveness, and motion realism.",
|
||||
"URLs": [
|
||||
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors",
|
||||
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors",
|
||||
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors"
|
||||
],
|
||||
"preload_URLs": [
|
||||
],
|
||||
"auto_quantize": true
|
||||
},
|
||||
"auto_quantize": true
|
||||
},
|
||||
```
|
||||
|
||||
## Finetune Model Naming Convention
|
||||
|
||||
@ -6,18 +6,19 @@ Loras (Low-Rank Adaptations) allow you to customize video generation models by a
|
||||
|
||||
Loras are organized in different folders based on the model they're designed for:
|
||||
|
||||
### Text-to-Video Models
|
||||
### Wan Text-to-Video Models
|
||||
- `loras/` - General t2v loras
|
||||
- `loras/1.3B/` - Loras specifically for 1.3B models
|
||||
- `loras/14B/` - Loras specifically for 14B models
|
||||
|
||||
### Image-to-Video Models
|
||||
### Wan Image-to-Video Models
|
||||
- `loras_i2v/` - Image-to-video loras
|
||||
|
||||
### Other Models
|
||||
- `loras_hunyuan/` - Hunyuan Video t2v loras
|
||||
- `loras_hunyuan_i2v/` - Hunyuan Video i2v loras
|
||||
- `loras_ltxv/` - LTX Video loras
|
||||
- `loras_flux/` - Flux loras
|
||||
|
||||
## Custom Lora Directory
|
||||
|
||||
@ -64,7 +65,7 @@ For dynamic effects over generation steps, use comma-separated values:
|
||||
|
||||
## Lora Presets
|
||||
|
||||
Presets are combinations of loras with predefined multipliers and prompts.
|
||||
Lora Presets are combinations of loras with predefined multipliers and prompts.
|
||||
|
||||
### Creating Presets
|
||||
1. Configure your loras and multipliers
|
||||
@ -95,16 +96,36 @@ WanGP supports multiple lora formats:
|
||||
- **Replicate** format
|
||||
- **Standard PyTorch** (.pt, .pth)
|
||||
|
||||
## Safe-Forcing lightx2v Lora (Video Generation Accelerator)
|
||||
|
||||
Safeforcing Lora has been created by Kijai from the Safe-Forcing lightx2v distilled Wan model and can generate videos with only 2 steps and offers also a 2x speed improvement since it doesnt require classifier free guidance. It works on both t2v and i2v models
|
||||
## Loras Accelerators
|
||||
Most Loras are used to apply a specific style or to alter the content of the output of the generated video.
|
||||
However some Loras have been designed to tranform a model into a distilled model which requires fewer steps to generate a video.
|
||||
|
||||
You will find most *Loras Accelerators* here:
|
||||
https://huggingface.co/DeepBeepMeep/Wan2.1/tree/main/loras_accelerators
|
||||
|
||||
### Setup Instructions
|
||||
1. Download the Lora:
|
||||
```
|
||||
https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors
|
||||
```
|
||||
2. Place in your `loras/` directory
|
||||
1. Download the Lora
|
||||
2. Place it in your `loras/` directory if it is a t2v lora or in the `loras_i2v/` directory if it isa i2v lora
|
||||
|
||||
## FusioniX (or FusionX) Lora
|
||||
If you need just one Lora accelerator use this one. It is a combination of multiple Loras acelerators (including Causvid below) and style loras. It will not only accelerate the video generation but it will also improve the quality. There are two versions of this lora whether you use it for t2v or i2v
|
||||
|
||||
### Usage
|
||||
1. Select a Wan t2v model (e.g., Wan 2.1 text2video 13B or Vace 13B)
|
||||
2. Enable Advanced Mode
|
||||
3. In Advanced Generation Tab:
|
||||
- Set Guidance Scale = 1
|
||||
- Set Shift Scale = 2
|
||||
4. In Advanced Lora Tab:
|
||||
- Select CausVid Lora
|
||||
- Set multiplier to 1
|
||||
5. Set generation steps from 8-10
|
||||
6. Generate!
|
||||
|
||||
## Safe-Forcing lightx2v Lora (Video Generation Accelerator)
|
||||
Safeforcing Lora has been created by Kijai from the Safe-Forcing lightx2v distilled Wan model and can generate videos with only 2 steps and offers also a 2x speed improvement since it doesnt require classifier free guidance. It works on both t2v and i2v models
|
||||
You will find it under the name of *Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors*
|
||||
|
||||
### Usage
|
||||
1. Select a Wan t2v or i2v model (e.g., Wan 2.1 text2video 13B or Vace 13B)
|
||||
@ -118,17 +139,10 @@ Safeforcing Lora has been created by Kijai from the Safe-Forcing lightx2v distil
|
||||
5. Set generation steps to 2-8
|
||||
6. Generate!
|
||||
|
||||
|
||||
## CausVid Lora (Video Generation Accelerator)
|
||||
|
||||
CausVid is a distilled Wan model that generates videos in 4-12 steps with 2x speed improvement.
|
||||
|
||||
### Setup Instructions
|
||||
1. Download the CausVid Lora:
|
||||
```
|
||||
https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors
|
||||
```
|
||||
2. Place in your `loras/` directory
|
||||
|
||||
### Usage
|
||||
1. Select a Wan t2v model (e.g., Wan 2.1 text2video 13B or Vace 13B)
|
||||
2. Enable Advanced Mode
|
||||
@ -149,25 +163,10 @@ CausVid is a distilled Wan model that generates videos in 4-12 steps with 2x spe
|
||||
*Note: Lower steps = lower quality (especially motion)*
|
||||
|
||||
|
||||
|
||||
## AccVid Lora (Video Generation Accelerator)
|
||||
|
||||
AccVid is a distilled Wan model that generates videos with a 2x speed improvement since classifier free guidance is no longer needed (that is cfg = 1).
|
||||
|
||||
### Setup Instructions
|
||||
1. Download the AccVid Lora:
|
||||
|
||||
- for t2v models:
|
||||
```
|
||||
https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors
|
||||
```
|
||||
|
||||
- for i2v models:
|
||||
```
|
||||
https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_AccVid_I2V_480P_14B_lora_rank32_fp16.safetensors
|
||||
```
|
||||
|
||||
2. Place in your `loras/` directory or `loras_i2v/` directory
|
||||
|
||||
### Usage
|
||||
1. Select a Wan t2v model (e.g., Wan 2.1 text2video 13B or Vace 13B) or Wan i2v model
|
||||
@ -268,6 +267,7 @@ In the video, a man is presented. The man is in a city and looks at his watch.
|
||||
--lora-dir-hunyuan path # Path to Hunyuan t2v loras
|
||||
--lora-dir-hunyuan-i2v path # Path to Hunyuan i2v loras
|
||||
--lora-dir-ltxv path # Path to LTX Video loras
|
||||
--lora-dir-flux path # Path to Flux loras
|
||||
--lora-preset preset # Load preset on startup
|
||||
--check-loras # Filter incompatible loras
|
||||
```
|
||||
@ -2,6 +2,8 @@
|
||||
|
||||
WanGP supports multiple video generation models, each optimized for different use cases and hardware configurations.
|
||||
|
||||
Most models can combined with Loras Accelerators (check the Lora guide) to accelerate the generation of a video x2 or x3 with little quality loss
|
||||
|
||||
|
||||
## Wan 2.1 Text2Video Models
|
||||
Please note that that the term *Text2Video* refers to the underlying Wan architecture but as it has been greatly improved overtime many derived Text2Video models can now generate videos using images.
|
||||
@ -65,6 +67,12 @@ Please note that that the term *Text2Video* refers to the underlying Wan archite
|
||||
|
||||
## Wan 2.1 Specialized Models
|
||||
|
||||
#### Multitalk
|
||||
- **Type**: Multi Talking head animation
|
||||
- **Input**: Voice track + image
|
||||
- **Works on**: People
|
||||
- **Use case**: Lip-sync and voice-driven animation for up to two people
|
||||
|
||||
#### FantasySpeaking
|
||||
- **Type**: Talking head animation
|
||||
- **Input**: Voice track + image
|
||||
@ -82,7 +90,7 @@ Please note that that the term *Text2Video* refers to the underlying Wan archite
|
||||
- **Requirements**: 81+ frame input videos, 15+ denoising steps
|
||||
- **Use case**: View same scene from different angles
|
||||
|
||||
#### Sky Reels v2
|
||||
#### Sky Reels v2 Diffusion
|
||||
- **Type**: Diffusion Forcing model
|
||||
- **Specialty**: "Infinite length" videos
|
||||
- **Features**: High quality continuous generation
|
||||
@ -107,22 +115,6 @@ Please note that that the term *Text2Video* refers to the underlying Wan archite
|
||||
|
||||
<BR>
|
||||
|
||||
## Wan Special Loras
|
||||
### Safe-Forcing lightx2v Lora
|
||||
- **Type**: Distilled model (Lora implementation)
|
||||
- **Speed**: 4-8 steps generation, 2x faster (no classifier free guidance)
|
||||
- **Compatible**: Works with t2v and i2v Wan 14B models
|
||||
- **Setup**: Requires Safe-Forcing lightx2v Lora (see [LORAS.md](LORAS.md))
|
||||
|
||||
|
||||
### Causvid Lora
|
||||
- **Type**: Distilled model (Lora implementation)
|
||||
- **Speed**: 4-12 steps generation, 2x faster (no classifier free guidance)
|
||||
- **Compatible**: Works with Wan 14B models
|
||||
- **Setup**: Requires CausVid Lora (see [LORAS.md](LORAS.md))
|
||||
|
||||
|
||||
<BR>
|
||||
|
||||
## Hunyuan Video Models
|
||||
|
||||
|
||||
0
finetunes/put your finetunes here.txt
Normal file
0
finetunes/put your finetunes here.txt
Normal file
@ -65,7 +65,7 @@ class model_factory:
|
||||
fit_into_canvas = None,
|
||||
callback = None,
|
||||
loras_slists = None,
|
||||
frame_num = 1,
|
||||
batch_size = 1,
|
||||
**bbargs
|
||||
):
|
||||
|
||||
@ -89,7 +89,7 @@ class model_factory:
|
||||
img_cond=image_ref,
|
||||
target_width=width,
|
||||
target_height=height,
|
||||
bs=frame_num,
|
||||
bs=batch_size,
|
||||
seed=seed,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
21
postprocessing/film_grain.py
Normal file
21
postprocessing/film_grain.py
Normal file
@ -0,0 +1,21 @@
|
||||
# Thanks to https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/film_grain.py
|
||||
import torch
|
||||
|
||||
def add_film_grain(images: torch.Tensor, grain_intensity: float = 0, saturation: float = 0.5):
|
||||
device = images.device
|
||||
|
||||
images = images.permute(1, 2 ,3 ,0)
|
||||
images.add_(1.).div_(2.)
|
||||
grain = torch.randn_like(images, device=device)
|
||||
grain[:, :, :, 0] *= 2
|
||||
grain[:, :, :, 2] *= 3
|
||||
grain = grain * saturation + grain[:, :, :, 1].unsqueeze(3).repeat(
|
||||
1, 1, 1, 3
|
||||
) * (1 - saturation)
|
||||
|
||||
# Blend the grain with the image
|
||||
noised_images = images + grain_intensity * grain
|
||||
noised_images.clamp_(0, 1)
|
||||
noised_images.sub_(.5).mul_(2.)
|
||||
noised_images = noised_images.permute(3, 0, 1 ,2)
|
||||
return noised_images
|
||||
@ -65,6 +65,7 @@ def get_frames_from_image(image_input, image_state):
|
||||
Return
|
||||
[[0:nearest_frame], [nearest_frame:], nearest_frame]
|
||||
"""
|
||||
load_sam()
|
||||
|
||||
user_name = time.time()
|
||||
frames = [image_input] * 2 # hardcode: mimic a video with 2 frames
|
||||
@ -89,7 +90,7 @@ def get_frames_from_image(image_input, image_state):
|
||||
gr.update(visible=True), gr.update(visible=True), \
|
||||
gr.update(visible=True), gr.update(visible=True),\
|
||||
gr.update(visible=True), gr.update(visible=True), \
|
||||
gr.update(visible=True), gr.update(visible=False), \
|
||||
gr.update(visible=True), gr.update(value="", visible=True), gr.update(visible=False), \
|
||||
gr.update(visible=False), gr.update(visible=True), \
|
||||
gr.update(visible=True)
|
||||
|
||||
@ -103,6 +104,8 @@ def get_frames_from_video(video_input, video_state):
|
||||
[[0:nearest_frame], [nearest_frame:], nearest_frame]
|
||||
"""
|
||||
|
||||
load_sam()
|
||||
|
||||
while model == None:
|
||||
time.sleep(1)
|
||||
|
||||
@ -273,6 +276,20 @@ def save_video(frames, output_path, fps):
|
||||
|
||||
return output_path
|
||||
|
||||
def mask_to_xyxy_box(mask):
|
||||
rows, cols = np.where(mask == 255)
|
||||
xmin = min(cols)
|
||||
xmax = max(cols) + 1
|
||||
ymin = min(rows)
|
||||
ymax = max(rows) + 1
|
||||
xmin = max(xmin, 0)
|
||||
ymin = max(ymin, 0)
|
||||
xmax = min(xmax, mask.shape[1])
|
||||
ymax = min(ymax, mask.shape[0])
|
||||
box = [xmin, ymin, xmax, ymax]
|
||||
box = [int(x) for x in box]
|
||||
return box
|
||||
|
||||
# image matting
|
||||
def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, refine_iter):
|
||||
matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
|
||||
@ -320,9 +337,17 @@ def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_si
|
||||
foreground = output_frames
|
||||
|
||||
foreground_output = Image.fromarray(foreground[-1])
|
||||
alpha_output = Image.fromarray(alpha[-1][:,:,0])
|
||||
|
||||
return foreground_output, gr.update(visible=True)
|
||||
alpha_output = alpha[-1][:,:,0]
|
||||
frame_temp = alpha_output.copy()
|
||||
alpha_output[frame_temp > 127] = 0
|
||||
alpha_output[frame_temp <= 127] = 255
|
||||
bbox_info = mask_to_xyxy_box(alpha_output)
|
||||
h = alpha_output.shape[0]
|
||||
w = alpha_output.shape[1]
|
||||
bbox_info = [str(int(bbox_info[0]/ w * 100 )), str(int(bbox_info[1]/ h * 100 )), str(int(bbox_info[2]/ w * 100 )), str(int(bbox_info[3]/ h * 100 )) ]
|
||||
bbox_info = ":".join(bbox_info)
|
||||
alpha_output = Image.fromarray(alpha_output)
|
||||
return foreground_output, alpha_output, bbox_info, gr.update(visible=True), gr.update(visible=True)
|
||||
|
||||
# video matting
|
||||
def video_matting(video_state, end_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size):
|
||||
@ -469,6 +494,13 @@ def restart():
|
||||
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
|
||||
gr.update(visible=False), gr.update(visible=False, choices=[], value=[]), "", gr.update(visible=False)
|
||||
|
||||
def load_sam():
|
||||
global model_loaded
|
||||
global model
|
||||
global matanyone_model
|
||||
model.samcontroler.sam_controler.model.to(arg_device)
|
||||
matanyone_model.to(arg_device)
|
||||
|
||||
def load_unload_models(selected):
|
||||
global model_loaded
|
||||
global model
|
||||
@ -476,8 +508,7 @@ def load_unload_models(selected):
|
||||
if selected:
|
||||
# print("Matanyone Tab Selected")
|
||||
if model_loaded:
|
||||
model.samcontroler.sam_controler.model.to(arg_device)
|
||||
matanyone_model.to(arg_device)
|
||||
load_sam()
|
||||
else:
|
||||
# args, defined in track_anything.py
|
||||
sam_checkpoint_url_dict = {
|
||||
@ -522,12 +553,16 @@ def export_to_vace_video_input(foreground_video_output):
|
||||
|
||||
def export_image(image_refs, image_output):
|
||||
gr.Info("Masked Image transferred to Current Video")
|
||||
# return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
|
||||
if image_refs == None:
|
||||
image_refs =[]
|
||||
image_refs.append( image_output)
|
||||
return image_refs
|
||||
|
||||
def export_image_mask(image_input, image_mask):
|
||||
gr.Info("Input Image & Mask transferred to Current Video")
|
||||
return Image.fromarray(image_input), image_mask
|
||||
|
||||
|
||||
def export_to_current_video_engine(model_type, foreground_video_output, alpha_video_output):
|
||||
gr.Info("Original Video and Full Mask have been transferred")
|
||||
# return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
|
||||
@ -543,7 +578,7 @@ def teleport_to_video_tab(tab_state):
|
||||
return gr.Tabs(selected="video_gen")
|
||||
|
||||
|
||||
def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, vace_image_refs):
|
||||
def display(tabs, tab_state, model_choice, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs):
|
||||
# my_tab.select(fn=load_unload_models, inputs=[], outputs=[])
|
||||
|
||||
media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/"
|
||||
@ -677,7 +712,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
|
||||
foreground_output_button = gr.Button(value="Black & White Video Output", visible=False, elem_classes="new_button")
|
||||
with gr.Column(scale=2):
|
||||
alpha_video_output = gr.Video(label="B & W Mask Video Output", visible=False, elem_classes="video")
|
||||
alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
|
||||
export_image_mask_btn = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
|
||||
with gr.Row():
|
||||
with gr.Row(visible= False):
|
||||
export_to_vace_video_14B_btn = gr.Button("Export to current Video Input Video For Inpainting", visible= False)
|
||||
@ -696,7 +731,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
|
||||
],
|
||||
outputs=[video_state, video_info, template_frame,
|
||||
image_selection_slider, end_selection_slider, track_pause_number_slider, point_prompt, matting_type, clear_button_click, add_mask_button, matting_button, template_frame,
|
||||
foreground_video_output, alpha_video_output, foreground_output_button, alpha_output_button, mask_dropdown, step2_title]
|
||||
foreground_video_output, alpha_video_output, foreground_output_button, export_image_mask_btn, mask_dropdown, step2_title]
|
||||
)
|
||||
|
||||
# second step: select images from slider
|
||||
@ -755,7 +790,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
|
||||
foreground_video_output, alpha_video_output,
|
||||
template_frame,
|
||||
image_selection_slider, end_selection_slider, track_pause_number_slider,point_prompt, export_to_vace_video_14B_btn, export_to_current_video_engine_btn, matting_type, clear_button_click,
|
||||
add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title
|
||||
add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, export_image_mask_btn, mask_dropdown, video_info, step2_title
|
||||
],
|
||||
queue=False,
|
||||
show_progress=False)
|
||||
@ -770,7 +805,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
|
||||
foreground_video_output, alpha_video_output,
|
||||
template_frame,
|
||||
image_selection_slider , end_selection_slider, track_pause_number_slider,point_prompt, export_to_vace_video_14B_btn, export_to_current_video_engine_btn, matting_type, clear_button_click,
|
||||
add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title
|
||||
add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, export_image_mask_btn, mask_dropdown, video_info, step2_title
|
||||
],
|
||||
queue=False,
|
||||
show_progress=False)
|
||||
@ -872,15 +907,19 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
|
||||
# output image
|
||||
with gr.Row(equal_height=True):
|
||||
foreground_image_output = gr.Image(type="pil", label="Foreground Output", visible=False, elem_classes="image")
|
||||
alpha_image_output = gr.Image(type="pil", label="Mask", visible=False, elem_classes="image")
|
||||
with gr.Row(equal_height=True):
|
||||
bbox_info = gr.Text(label ="Mask BBox Info (Left:Top:Right:Bottom)", interactive= False)
|
||||
with gr.Row():
|
||||
with gr.Row():
|
||||
export_image_btn = gr.Button(value="Add to current Reference Images", visible=False, elem_classes="new_button")
|
||||
with gr.Column(scale=2, visible= False):
|
||||
alpha_image_output = gr.Image(type="pil", label="Alpha Output", visible=False, elem_classes="image")
|
||||
alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
|
||||
# with gr.Row():
|
||||
export_image_btn = gr.Button(value="Add to current Reference Images", visible=False, elem_classes="new_button")
|
||||
# with gr.Column(scale=2, visible= True):
|
||||
export_image_mask_btn = gr.Button(value="Set to Control Image & Mask", visible=False, elem_classes="new_button")
|
||||
|
||||
export_image_btn.click( fn=export_image, inputs= [vace_image_refs, foreground_image_output], outputs= [vace_image_refs]).then( #video_prompt_video_guide_trigger,
|
||||
fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])
|
||||
export_image_mask_btn.click( fn=export_image_mask, inputs= [image_input, alpha_image_output], outputs= [vace_image_input, vace_image_mask]).then( #video_prompt_video_guide_trigger,
|
||||
fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])
|
||||
|
||||
# first step: get the image information
|
||||
extract_frames_button.click(
|
||||
@ -890,9 +929,17 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
|
||||
],
|
||||
outputs=[image_state, image_info, template_frame,
|
||||
image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, add_mask_button, matting_button, template_frame,
|
||||
foreground_image_output, alpha_image_output, export_image_btn, alpha_output_button, mask_dropdown, step2_title]
|
||||
foreground_image_output, alpha_image_output, bbox_info, export_image_btn, export_image_mask_btn, mask_dropdown, step2_title]
|
||||
)
|
||||
|
||||
# points clear
|
||||
clear_button_click.click(
|
||||
fn = clear_click,
|
||||
inputs = [image_state, click_state,],
|
||||
outputs = [template_frame,click_state],
|
||||
)
|
||||
|
||||
|
||||
# second step: select images from slider
|
||||
image_selection_slider.release(fn=select_image_template,
|
||||
inputs=[image_selection_slider, image_state, interactive_state],
|
||||
@ -925,7 +972,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
|
||||
matting_button.click(
|
||||
fn=image_matting,
|
||||
inputs=[image_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, image_selection_slider],
|
||||
outputs=[foreground_image_output, export_image_btn]
|
||||
outputs=[foreground_image_output, alpha_image_output,bbox_info, export_image_btn, export_image_mask_btn]
|
||||
)
|
||||
|
||||
|
||||
|
||||
106
wan/any2video.py
106
wan/any2video.py
@ -61,6 +61,7 @@ class WanAny2V:
|
||||
checkpoint_dir,
|
||||
model_filename = None,
|
||||
model_type = None,
|
||||
model_def = None,
|
||||
base_model_type = None,
|
||||
text_encoder_filename = None,
|
||||
quantizeTransformer = False,
|
||||
@ -75,7 +76,8 @@ class WanAny2V:
|
||||
self.dtype = dtype
|
||||
self.num_train_timesteps = config.num_train_timesteps
|
||||
self.param_dtype = config.param_dtype
|
||||
|
||||
self.model_def = model_def
|
||||
self.image_outputs = model_def.get("image_outputs", False)
|
||||
self.text_encoder = T5EncoderModel(
|
||||
text_len=config.text_len,
|
||||
dtype=config.t5_dtype,
|
||||
@ -106,7 +108,7 @@ class WanAny2V:
|
||||
# config = json.load(f)
|
||||
# from mmgp import safetensors2
|
||||
# sd = safetensors2.torch_load_file(xmodel_filename)
|
||||
# model_filename = "c:/temp/fflf/diffusion_pytorch_model-00001-of-00007.safetensors"
|
||||
# model_filename = "c:/temp/flf/diffusion_pytorch_model-00001-of-00007.safetensors"
|
||||
base_config_file = f"configs/{base_model_type}.json"
|
||||
forcedConfigPath = base_config_file if len(model_filename) > 1 else None
|
||||
# forcedConfigPath = base_config_file = f"configs/flf2v_720p.json"
|
||||
@ -208,7 +210,7 @@ class WanAny2V:
|
||||
|
||||
if refs is not None:
|
||||
length = len(refs)
|
||||
mask_pad = torch.zeros_like(mask[:, :length, :, :])
|
||||
mask_pad = torch.zeros(mask.shape[0], length, *mask.shape[-2:], dtype=mask.dtype, device=mask.device)
|
||||
mask = torch.cat((mask_pad, mask), dim=1)
|
||||
result_masks.append(mask)
|
||||
return result_masks
|
||||
@ -327,20 +329,6 @@ class WanAny2V:
|
||||
self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # deplicate background mask with double control net since first controlnet image ref modifed by ref
|
||||
return src_video, src_mask, src_ref_images
|
||||
|
||||
def decode_latent(self, zs, ref_images=None, tile_size= 0 ):
|
||||
if ref_images is None:
|
||||
ref_images = [None] * len(zs)
|
||||
# else:
|
||||
# assert len(zs) == len(ref_images)
|
||||
|
||||
trimed_zs = []
|
||||
for z, refs in zip(zs, ref_images):
|
||||
if refs is not None:
|
||||
z = z[:, len(refs):, :, :]
|
||||
trimed_zs.append(z)
|
||||
|
||||
return self.vae.decode(trimed_zs, tile_size= tile_size)
|
||||
|
||||
def get_vae_latents(self, ref_images, device, tile_size= 0):
|
||||
ref_vae_latents = []
|
||||
for ref_image in ref_images:
|
||||
@ -366,6 +354,7 @@ class WanAny2V:
|
||||
height = 720,
|
||||
fit_into_canvas = True,
|
||||
frame_num=81,
|
||||
batch_size = 1,
|
||||
shift=5.0,
|
||||
sample_solver='unipc',
|
||||
sampling_steps=50,
|
||||
@ -397,6 +386,7 @@ class WanAny2V:
|
||||
NAG_alpha = 0.5,
|
||||
offloadobj = None,
|
||||
apg_switch = False,
|
||||
speakers_bboxes = None,
|
||||
**bbargs
|
||||
):
|
||||
|
||||
@ -554,8 +544,8 @@ class WanAny2V:
|
||||
overlapped_latents_frames_num = int(1 + (preframes_count-1) // 4)
|
||||
if overlapped_latents != None:
|
||||
# disabled because looks worse
|
||||
if False and overlapped_latents_frames_num > 1: lat_y[:, 1:overlapped_latents_frames_num] = overlapped_latents[:, 1:]
|
||||
extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone()
|
||||
if False and overlapped_latents_frames_num > 1: lat_y[:, :, 1:overlapped_latents_frames_num] = overlapped_latents[:, 1:]
|
||||
extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone().unsqueeze(0)
|
||||
y = torch.concat([msk, lat_y])
|
||||
lat_y = None
|
||||
kwargs.update({'clip_fea': clip_context, 'y': y})
|
||||
@ -586,7 +576,7 @@ class WanAny2V:
|
||||
overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1
|
||||
else:
|
||||
overlapped_latents_frames_num = overlapped_frames_num = 0
|
||||
if len(keep_frames_parsed) == 0 or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = []
|
||||
if len(keep_frames_parsed) == 0 or self.image_outputs or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = []
|
||||
injection_denoising_step = int(sampling_steps * (1. - denoising_strength) )
|
||||
latent_keep_frames = []
|
||||
if source_latents.shape[1] < lat_frames or len(keep_frames_parsed) > 0:
|
||||
@ -609,6 +599,7 @@ class WanAny2V:
|
||||
input_ref_images = self.get_vae_latents(input_ref_images, self.device)
|
||||
input_ref_images_neg = torch.zeros_like(input_ref_images)
|
||||
ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0
|
||||
trim_frames = input_ref_images.shape[1]
|
||||
|
||||
# Vace
|
||||
if vace :
|
||||
@ -633,8 +624,8 @@ class WanAny2V:
|
||||
context_scale = context_scale if context_scale != None else [1.0] * len(z)
|
||||
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count })
|
||||
if overlapped_latents != None :
|
||||
overlapped_latents_size = overlapped_latents.shape[1]
|
||||
extended_overlapped_latents = z[0][0:16, 0:overlapped_latents_size + ref_images_count].clone()
|
||||
overlapped_latents_size = overlapped_latents.shape[2]
|
||||
extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0)
|
||||
|
||||
target_shape = list(z0[0].shape)
|
||||
target_shape[0] = int(target_shape[0] / 2)
|
||||
@ -649,7 +640,7 @@ class WanAny2V:
|
||||
from wan.multitalk.multitalk import get_target_masks
|
||||
audio_proj = [audio.to(self.dtype) for audio in audio_proj]
|
||||
human_no = len(audio_proj[0])
|
||||
token_ref_target_masks = get_target_masks(human_no, lat_h, lat_w, height, width, face_scale = 0.05, bbox = None).to(self.dtype) if human_no > 1 else None
|
||||
token_ref_target_masks = get_target_masks(human_no, lat_h, lat_w, height, width, face_scale = 0.05, bbox = speakers_bboxes).to(self.dtype) if human_no > 1 else None
|
||||
|
||||
if fantasy and audio_proj != None:
|
||||
kwargs.update({ "audio_proj": audio_proj.to(self.dtype), "audio_context_lens": audio_context_lens, })
|
||||
@ -658,8 +649,8 @@ class WanAny2V:
|
||||
if self._interrupt:
|
||||
return None
|
||||
|
||||
expand_shape = [batch_size] + [-1] * len(target_shape)
|
||||
# Ropes
|
||||
batch_size = 1
|
||||
if target_camera != None:
|
||||
shape = list(target_shape[1:])
|
||||
shape[0] *= 2
|
||||
@ -698,14 +689,14 @@ class WanAny2V:
|
||||
|
||||
if sample_scheduler != None:
|
||||
scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g}
|
||||
|
||||
latents = torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
|
||||
# b, c, lat_f, lat_h, lat_w
|
||||
latents = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
|
||||
if apg_switch != 0:
|
||||
apg_momentum = -0.75
|
||||
apg_norm_threshold = 55
|
||||
text_momentumbuffer = MomentumBuffer(apg_momentum)
|
||||
audio_momentumbuffer = MomentumBuffer(apg_momentum)
|
||||
|
||||
# self.image_outputs = False
|
||||
# denoising
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
offload.set_step_no_for_lora(self.model, i)
|
||||
@ -715,36 +706,36 @@ class WanAny2V:
|
||||
|
||||
if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step:
|
||||
sigma = t / 1000
|
||||
noise = torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
|
||||
noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
|
||||
if inject_from_start:
|
||||
new_latents = latents.clone()
|
||||
new_latents[:, :source_latents.shape[1] ] = noise[:, :source_latents.shape[1] ] * sigma + (1 - sigma) * source_latents
|
||||
new_latents[:,:, :source_latents.shape[1] ] = noise[:, :, :source_latents.shape[1] ] * sigma + (1 - sigma) * source_latents.unsqueeze(0)
|
||||
for latent_no, keep_latent in enumerate(latent_keep_frames):
|
||||
if not keep_latent:
|
||||
new_latents[:, latent_no:latent_no+1 ] = latents[:, latent_no:latent_no+1]
|
||||
new_latents[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1]
|
||||
latents = new_latents
|
||||
new_latents = None
|
||||
else:
|
||||
latents = noise * sigma + (1 - sigma) * source_latents
|
||||
latents = noise * sigma + (1 - sigma) * source_latents.unsqueeze(0)
|
||||
noise = None
|
||||
|
||||
if extended_overlapped_latents != None:
|
||||
latent_noise_factor = t / 1000
|
||||
latents[:, 0:extended_overlapped_latents.shape[1]] = extended_overlapped_latents * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor
|
||||
latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor
|
||||
if vace:
|
||||
overlap_noise_factor = overlap_noise / 1000
|
||||
for zz in z:
|
||||
zz[0:16, ref_images_count:extended_overlapped_latents.shape[1] ] = extended_overlapped_latents[:, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[:, ref_images_count:] ) * overlap_noise_factor
|
||||
zz[0:16, ref_images_count:extended_overlapped_latents.shape[2] ] = extended_overlapped_latents[0, :, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[0, :, ref_images_count:] ) * overlap_noise_factor
|
||||
|
||||
if target_camera != None:
|
||||
latent_model_input = torch.cat([latents, source_latents], dim=1)
|
||||
latent_model_input = torch.cat([latents, source_latents.unsqueeze(0).expand(*expand_shape)], dim=2) # !!!!
|
||||
else:
|
||||
latent_model_input = latents
|
||||
|
||||
if phantom:
|
||||
gen_args = {
|
||||
"x" : ([ torch.cat([latent_model_input[:,:-ref_images_count], input_ref_images], dim=1) ] * 2 +
|
||||
[ torch.cat([latent_model_input[:,:-ref_images_count], input_ref_images_neg], dim=1)]),
|
||||
"x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 +
|
||||
[ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]),
|
||||
"context": [context, context_null, context_null] ,
|
||||
}
|
||||
elif fantasy:
|
||||
@ -832,38 +823,41 @@ class WanAny2V:
|
||||
if sample_solver == "euler":
|
||||
dt = timesteps[i] if i == len(timesteps)-1 else (timesteps[i] - timesteps[i + 1])
|
||||
dt = dt / self.num_timesteps
|
||||
latents = latents - noise_pred * dt[:, None, None, None]
|
||||
latents = latents - noise_pred * dt[:, None, None, None, None]
|
||||
else:
|
||||
temp_x0 = sample_scheduler.step(
|
||||
noise_pred[:, :target_shape[1]].unsqueeze(0),
|
||||
latents = sample_scheduler.step(
|
||||
noise_pred[:, :, :target_shape[1]],
|
||||
t,
|
||||
latents.unsqueeze(0),
|
||||
latents,
|
||||
**scheduler_kwargs)[0]
|
||||
latents = temp_x0.squeeze(0)
|
||||
del temp_x0
|
||||
|
||||
if callback is not None:
|
||||
callback(i, latents, False)
|
||||
latents_preview = latents
|
||||
if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
|
||||
if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames]
|
||||
if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2)
|
||||
callback(i, latents_preview[0], False)
|
||||
latents_preview = None
|
||||
|
||||
x0 = [latents]
|
||||
if vace and ref_images_count > 0: latents = latents[:, :, ref_images_count:]
|
||||
if trim_frames > 0: latents= latents[:, :,:-trim_frames]
|
||||
if return_latent_slice != None:
|
||||
latent_slice = latents[:, :, return_latent_slice].clone()
|
||||
|
||||
x0 =latents.unbind(dim=0)
|
||||
|
||||
if chipmunk:
|
||||
self.model.release_chipmunk() # need to add it at every exit when in prod
|
||||
|
||||
if return_latent_slice != None:
|
||||
latent_slice = latents[:, return_latent_slice].clone()
|
||||
if vace:
|
||||
# vace post processing
|
||||
videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
|
||||
else:
|
||||
if phantom and input_ref_images != None:
|
||||
trim_frames = input_ref_images.shape[1]
|
||||
if trim_frames > 0: x0 = [x0_[:,:-trim_frames] for x0_ in x0]
|
||||
videos = self.vae.decode(x0, VAE_tile_size)
|
||||
videos = self.vae.decode(x0, VAE_tile_size)
|
||||
|
||||
if self.image_outputs:
|
||||
videos = torch.cat(videos, dim=1) if len(videos) > 1 else videos[0]
|
||||
else:
|
||||
videos = videos[0] # return only first video
|
||||
if return_latent_slice != None:
|
||||
return { "x" : videos[0], "latent_slice" : latent_slice }
|
||||
return videos[0]
|
||||
return { "x" : videos, "latent_slice" : latent_slice }
|
||||
return videos
|
||||
|
||||
def adapt_vace_model(self):
|
||||
model = self.model
|
||||
|
||||
@ -31,6 +31,7 @@ class DTT2V:
|
||||
rank=0,
|
||||
model_filename = None,
|
||||
model_type = None,
|
||||
model_def = None,
|
||||
base_model_type = None,
|
||||
save_quantized = False,
|
||||
text_encoder_filename = None,
|
||||
@ -53,6 +54,8 @@ class DTT2V:
|
||||
checkpoint_path=text_encoder_filename,
|
||||
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
||||
shard_fn= None)
|
||||
self.model_def = model_def
|
||||
self.image_outputs = model_def.get("image_outputs", False)
|
||||
|
||||
self.vae_stride = config.vae_stride
|
||||
self.patch_size = config.patch_size
|
||||
@ -202,6 +205,7 @@ class DTT2V:
|
||||
width: int = 832,
|
||||
fit_into_canvas = True,
|
||||
frame_num: int = 97,
|
||||
batch_size = 1,
|
||||
sampling_steps: int = 50,
|
||||
shift: float = 1.0,
|
||||
guide_scale: float = 5.0,
|
||||
@ -224,8 +228,9 @@ class DTT2V:
|
||||
generator = torch.Generator(device=self.device)
|
||||
generator.manual_seed(seed)
|
||||
self._guidance_scale = guide_scale
|
||||
frame_num = max(17, frame_num) # must match causal_block_size for value of 5
|
||||
frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 )
|
||||
if frame_num > 1:
|
||||
frame_num = max(17, frame_num) # must match causal_block_size for value of 5
|
||||
frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 )
|
||||
|
||||
if ar_step == 0:
|
||||
causal_block_size = 1
|
||||
@ -297,12 +302,12 @@ class DTT2V:
|
||||
prefix_video = prefix_video[:, : predix_video_latent_length]
|
||||
|
||||
base_num_frames_iter = latent_length
|
||||
latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
|
||||
latent_shape = [batch_size, 16, base_num_frames_iter, latent_height, latent_width]
|
||||
latents = self.prepare_latents(
|
||||
latent_shape, dtype=torch.float32, device=self.device, generator=generator
|
||||
)
|
||||
if prefix_video is not None:
|
||||
latents[:, :predix_video_latent_length] = prefix_video.to(torch.float32)
|
||||
latents[:, :, :predix_video_latent_length] = prefix_video.to(torch.float32)
|
||||
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
|
||||
base_num_frames_iter,
|
||||
init_timesteps,
|
||||
@ -340,7 +345,7 @@ class DTT2V:
|
||||
else:
|
||||
self.model.enable_cache = None
|
||||
from mmgp import offload
|
||||
freqs = get_rotary_pos_embed(latents.shape[1 :], enable_RIFLEx= False)
|
||||
freqs = get_rotary_pos_embed(latents.shape[2 :], enable_RIFLEx= False)
|
||||
kwrags = {
|
||||
"freqs" :freqs,
|
||||
"fps" : fps_embeds,
|
||||
@ -358,15 +363,15 @@ class DTT2V:
|
||||
update_mask_i = step_update_mask[i]
|
||||
valid_interval_start, valid_interval_end = valid_interval[i]
|
||||
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
|
||||
latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].clone()
|
||||
latent_model_input = latents[:, :, valid_interval_start:valid_interval_end, :, :].clone()
|
||||
if overlap_noise > 0 and valid_interval_start < predix_video_latent_length:
|
||||
noise_factor = 0.001 * overlap_noise
|
||||
timestep_for_noised_condition = overlap_noise
|
||||
latent_model_input[:, valid_interval_start:predix_video_latent_length] = (
|
||||
latent_model_input[:, valid_interval_start:predix_video_latent_length]
|
||||
latent_model_input[:, :, valid_interval_start:predix_video_latent_length] = (
|
||||
latent_model_input[:, :, valid_interval_start:predix_video_latent_length]
|
||||
* (1.0 - noise_factor)
|
||||
+ torch.randn_like(
|
||||
latent_model_input[:, valid_interval_start:predix_video_latent_length]
|
||||
latent_model_input[:, :, valid_interval_start:predix_video_latent_length]
|
||||
)
|
||||
* noise_factor
|
||||
)
|
||||
@ -417,18 +422,27 @@ class DTT2V:
|
||||
del noise_pred_cond, noise_pred_uncond
|
||||
for idx in range(valid_interval_start, valid_interval_end):
|
||||
if update_mask_i[idx].item():
|
||||
latents[:, idx] = sample_schedulers[idx].step(
|
||||
noise_pred[:, idx - valid_interval_start],
|
||||
latents[:, :, idx] = sample_schedulers[idx].step(
|
||||
noise_pred[:, :, idx - valid_interval_start],
|
||||
timestep_i[idx],
|
||||
latents[:, idx],
|
||||
latents[:, :, idx],
|
||||
return_dict=False,
|
||||
generator=generator,
|
||||
)[0]
|
||||
sample_schedulers_counter[idx] += 1
|
||||
if callback is not None:
|
||||
callback(i, latents.squeeze(0), False)
|
||||
latents_preview = latents
|
||||
if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2)
|
||||
callback(i, latents_preview[0], False)
|
||||
latents_preview = None
|
||||
|
||||
x0 = latents.unsqueeze(0)
|
||||
videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]]
|
||||
output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
|
||||
return output_video
|
||||
x0 =latents.unbind(dim=0)
|
||||
|
||||
videos = self.vae.decode(x0, VAE_tile_size)
|
||||
|
||||
if self.image_outputs:
|
||||
videos = torch.cat(videos, dim=1) if len(videos) > 1 else videos[0]
|
||||
else:
|
||||
videos = videos[0] # return only first video
|
||||
|
||||
return videos
|
||||
|
||||
@ -185,7 +185,7 @@ def pay_attention(
|
||||
q,k,v = qkv_list
|
||||
qkv_list.clear()
|
||||
out_dtype = q.dtype
|
||||
if q.dtype == torch.bfloat16 and not bfloat16_supported:
|
||||
if q.dtype == torch.bfloat16 and not bfloat16_supported:
|
||||
q = q.to(torch.float16)
|
||||
k = k.to(torch.float16)
|
||||
v = v.to(torch.float16)
|
||||
@ -194,7 +194,9 @@ def pay_attention(
|
||||
|
||||
q = q.to(v.dtype)
|
||||
k = k.to(v.dtype)
|
||||
|
||||
batch = len(q)
|
||||
if len(k) != batch: k = k.expand(batch, -1, -1, -1)
|
||||
if len(v) != batch: v = v.expand(batch, -1, -1, -1)
|
||||
if attn == "chipmunk":
|
||||
from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
|
||||
from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG
|
||||
|
||||
@ -33,9 +33,10 @@ def sinusoidal_embedding_1d(dim, position):
|
||||
|
||||
|
||||
def reshape_latent(latent, latent_frames):
|
||||
if latent_frames == latent.shape[0]:
|
||||
return latent
|
||||
return latent.reshape(latent_frames, -1, latent.shape[-1] )
|
||||
return latent.reshape(latent.shape[0], latent_frames, -1, latent.shape[-1] )
|
||||
|
||||
def restore_latent_shape(latent):
|
||||
return latent.reshape(latent.shape[0], -1, latent.shape[-1] )
|
||||
|
||||
|
||||
def identify_k( b: float, d: int, N: int):
|
||||
@ -493,7 +494,7 @@ class WanAttentionBlock(nn.Module):
|
||||
x_mod = reshape_latent(x_mod , latent_frames)
|
||||
x_mod *= 1 + e[1]
|
||||
x_mod += e[0]
|
||||
x_mod = reshape_latent(x_mod , 1)
|
||||
x_mod = restore_latent_shape(x_mod)
|
||||
if cam_emb != None:
|
||||
cam_emb = self.cam_encoder(cam_emb)
|
||||
cam_emb = cam_emb.repeat(1, 2, 1)
|
||||
@ -510,7 +511,7 @@ class WanAttentionBlock(nn.Module):
|
||||
|
||||
x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
|
||||
x.addcmul_(y, e[2])
|
||||
x, y = reshape_latent(x , 1), reshape_latent(y , 1)
|
||||
x, y = restore_latent_shape(x), restore_latent_shape(y)
|
||||
del y
|
||||
y = self.norm3(x)
|
||||
y = y.to(attention_dtype)
|
||||
@ -542,7 +543,7 @@ class WanAttentionBlock(nn.Module):
|
||||
y = reshape_latent(y , latent_frames)
|
||||
y *= 1 + e[4]
|
||||
y += e[3]
|
||||
y = reshape_latent(y , 1)
|
||||
y = restore_latent_shape(y)
|
||||
y = y.to(attention_dtype)
|
||||
|
||||
ffn = self.ffn[0]
|
||||
@ -562,7 +563,7 @@ class WanAttentionBlock(nn.Module):
|
||||
y = y.to(dtype)
|
||||
x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
|
||||
x.addcmul_(y, e[5])
|
||||
x, y = reshape_latent(x , 1), reshape_latent(y , 1)
|
||||
x, y = restore_latent_shape(x), restore_latent_shape(y)
|
||||
|
||||
if hints_processed is not None:
|
||||
for hint, scale in zip(hints_processed, context_scale):
|
||||
@ -669,6 +670,8 @@ class VaceWanAttentionBlock(WanAttentionBlock):
|
||||
hints[0] = None
|
||||
if self.block_id == 0:
|
||||
c = self.before_proj(c)
|
||||
bz = x.shape[0]
|
||||
if bz > c.shape[0]: c = c.repeat(bz, 1, 1 )
|
||||
c += x
|
||||
c = super().forward(c, **kwargs)
|
||||
c_skip = self.after_proj(c)
|
||||
@ -707,7 +710,7 @@ class Head(nn.Module):
|
||||
x = reshape_latent(x , latent_frames)
|
||||
x *= (1 + e[1])
|
||||
x += e[0]
|
||||
x = reshape_latent(x , 1)
|
||||
x = restore_latent_shape(x)
|
||||
x= x.to(self.head.weight.dtype)
|
||||
x = self.head(x)
|
||||
return x
|
||||
@ -1163,10 +1166,14 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
last_x_idx = i
|
||||
else:
|
||||
# image source
|
||||
bz = len(x)
|
||||
if y is not None:
|
||||
x = torch.cat([x, y], dim=0)
|
||||
y = y.unsqueeze(0)
|
||||
if bz > 1: y = y.expand(bz, -1, -1, -1, -1)
|
||||
x = torch.cat([x, y], dim=1)
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
|
||||
# x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
|
||||
x = self.patch_embedding(x).to(modulation_dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
if chipmunk:
|
||||
x = x.unsqueeze(-1)
|
||||
@ -1204,7 +1211,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
) # b, dim
|
||||
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
|
||||
|
||||
if self.inject_sample_info:
|
||||
if self.inject_sample_info and fps!=None:
|
||||
fps = torch.tensor(fps, dtype=torch.long, device=device)
|
||||
|
||||
fps_emb = self.fps_embedding(fps).to(e.dtype)
|
||||
@ -1402,7 +1409,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
x_list[i] = self.unpatchify(x, grid_sizes)
|
||||
del x
|
||||
|
||||
return [x[0].float() for x in x_list]
|
||||
return [x.float() for x in x_list]
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
r"""
|
||||
@ -1427,7 +1434,10 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
u = torch.einsum('fhwpqrc->cfphqwr', u)
|
||||
u = u.reshape(c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
|
||||
out.append(u)
|
||||
return out
|
||||
if len(x) == 1:
|
||||
return out[0].unsqueeze(0)
|
||||
else:
|
||||
return torch.stack(out, 0)
|
||||
|
||||
def init_weights(self):
|
||||
r"""
|
||||
|
||||
@ -333,7 +333,7 @@ class SingleStreamMutiAttention(SingleStreamAttention):
|
||||
|
||||
human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), (self.rope_h1[0], self.rope_h1[1]))
|
||||
human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), (self.rope_h2[0], self.rope_h2[1]))
|
||||
back = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype).to(human1.device)
|
||||
back = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype, device=human1.device)
|
||||
max_indices = x_ref_attn_map.argmax(dim=0)
|
||||
normalized_map = torch.stack([human1, human2, back], dim=1)
|
||||
normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N
|
||||
@ -351,7 +351,7 @@ class SingleStreamMutiAttention(SingleStreamAttention):
|
||||
if self.qk_norm:
|
||||
encoder_k = self.add_k_norm(encoder_k)
|
||||
|
||||
per_frame = torch.zeros(N_a, dtype=encoder_k.dtype).to(encoder_k.device)
|
||||
per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device)
|
||||
per_frame[:per_frame.size(0)//2] = (self.rope_h1[0] + self.rope_h1[1]) / 2
|
||||
per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2
|
||||
encoder_pos = torch.concat([per_frame]*N_t, dim=0)
|
||||
|
||||
@ -272,6 +272,34 @@ def timestep_transform(
|
||||
new_t = new_t * num_timesteps
|
||||
return new_t
|
||||
|
||||
def parse_speakers_locations(speakers_locations):
|
||||
bbox = {}
|
||||
if speakers_locations is None or len(speakers_locations) == 0:
|
||||
return None, ""
|
||||
speakers = speakers_locations.split(" ")
|
||||
if len(speakers) !=2:
|
||||
error= "Two speakers locations should be defined"
|
||||
return "", error
|
||||
|
||||
for i, speaker in enumerate(speakers):
|
||||
location = speaker.strip().split(":")
|
||||
if len(location) not in (2,4):
|
||||
error = f"Invalid Speaker Location '{location}'. A Speaker Location should be defined in the format Left:Right or usuing a BBox Left:Top:Right:Bottom"
|
||||
return "", error
|
||||
try:
|
||||
good = False
|
||||
location_float = [ float(val) for val in location]
|
||||
good = all( 0 <= val <= 100 for val in location_float)
|
||||
except:
|
||||
pass
|
||||
if not good:
|
||||
error = f"Invalid Speaker Location '{location}'. Each number should be between 0 and 100."
|
||||
return "", error
|
||||
if len(location_float) == 2:
|
||||
location_float = [location_float[0], 0, location_float[1], 100]
|
||||
bbox[f"human{i}"] = location_float
|
||||
return bbox, ""
|
||||
|
||||
|
||||
# construct human mask
|
||||
def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05, bbox = None):
|
||||
@ -286,7 +314,9 @@ def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05
|
||||
assert len(bbox) == HUMAN_NUMBER, f"The number of target bbox should be the same with cond_audio"
|
||||
background_mask = torch.zeros([src_h, src_w])
|
||||
for _, person_bbox in bbox.items():
|
||||
x_min, y_min, x_max, y_max = person_bbox
|
||||
y_min, x_min, y_max, x_max = person_bbox
|
||||
x_min, y_min, x_max, y_max = max(x_min,5), max(y_min, 5), min(x_max,95), min(y_max,95)
|
||||
x_min, y_min, x_max, y_max = int(src_h * x_min / 100), int(src_w * y_min / 100), int(src_h * x_max / 100), int(src_w * y_max / 100)
|
||||
human_mask = torch.zeros([src_h, src_w])
|
||||
human_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
|
||||
background_mask += human_mask
|
||||
@ -306,7 +336,7 @@ def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05
|
||||
human_masks = [human_mask1, human_mask2]
|
||||
background_mask = torch.where(background_mask > 0, torch.tensor(0), torch.tensor(1))
|
||||
human_masks.append(background_mask)
|
||||
|
||||
# toto = Image.fromarray(human_masks[2].mul_(255).unsqueeze(-1).repeat(1,1,3).to(torch.uint8).cpu().numpy())
|
||||
ref_target_masks = torch.stack(human_masks, dim=0) #.to(self.device)
|
||||
# resize and centercrop for ref_target_masks
|
||||
# ref_target_masks = resize_and_centercrop(ref_target_masks, (target_h, target_w))
|
||||
|
||||
@ -128,7 +128,7 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli
|
||||
|
||||
_, seq_lens, heads, _ = visual_q.shape
|
||||
class_num, _ = ref_target_masks.shape
|
||||
x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q.device).to(visual_q.dtype)
|
||||
x_ref_attn_maps = torch.zeros(class_num, seq_lens, dtype=visual_q.device, device=visual_q.dtype)
|
||||
|
||||
split_chunk = heads // split_num
|
||||
|
||||
|
||||
@ -5,7 +5,8 @@ import os
|
||||
import os.path as osp
|
||||
import torchvision.transforms.functional as TF
|
||||
import torch.nn.functional as F
|
||||
|
||||
import cv2
|
||||
import tempfile
|
||||
import imageio
|
||||
import torch
|
||||
import decord
|
||||
@ -101,6 +102,29 @@ def get_video_frame(file_name, frame_no):
|
||||
img = Image.fromarray(frame.numpy().astype(np.uint8))
|
||||
return img
|
||||
|
||||
def convert_image_to_video(image):
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
# Convert PIL/numpy image to OpenCV format if needed
|
||||
if isinstance(image, np.ndarray):
|
||||
# Gradio images are typically RGB, OpenCV expects BGR
|
||||
img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
||||
else:
|
||||
# Handle PIL Image
|
||||
img_array = np.array(image)
|
||||
img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
|
||||
|
||||
height, width = img_bgr.shape[:2]
|
||||
|
||||
# Create temporary video file (auto-cleaned by Gradio)
|
||||
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_video:
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
out = cv2.VideoWriter(temp_video.name, fourcc, 30.0, (width, height))
|
||||
out.write(img_bgr)
|
||||
out.release()
|
||||
return temp_video.name
|
||||
|
||||
def resize_lanczos(img, h, w):
|
||||
img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
|
||||
img = img.resize((w,h), resample=Image.Resampling.LANCZOS)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user