Merge remote-tracking branch 'upstream/main'
2
.gitignore
vendored
@ -14,7 +14,7 @@
*.pth
*.ckpt
*.safetensors
*.json
#*.json
# *.txt
*.backup
*.pkl
53
LICENSE.txt
@ -1,17 +1,46 @@
FREE for Non Commercial USE

WanGP NON-COMMERCIAL EVALUATION LICENSE 1.0

You are free to:

- Share — copy and redistribute the material in any medium or format
- Adapt — remix, transform, and build upon the material

The licensor cannot revoke these freedoms as long as you follow the license terms.

Definitions

1.1 “Software” means the source code, binaries, libraries, utilities and UI released under this license.
1.2 “Output” means images, videos or other media produced by running the Software.
1.3 “Commercial Use” means:
a) selling, sublicensing, renting, leasing, or otherwise distributing the Software, in whole or in part, for a fee or other consideration; or
b) offering the Software (or any derivative) as part of a paid product or hosted service; or
c) using the Software (or any derivative) to provide cloud-based or backend services, where end users access or pay for those services.

Under the following terms:

- Attribution — You must give appropriate credit, provide a link to the license, and indicate if changes were made. You may do so in any reasonable manner, but not in any way that suggests the licensor endorses you or your use.
- NonCommercial — You may not use the material for commercial purposes.

License Grant

Subject to Section 3:
a) You are granted a worldwide, non-exclusive, royalty-free, revocable license to use, reproduce, modify and distribute the Software for non-commercial purposes only.
b) You are granted a worldwide, non-exclusive, royalty-free, irrevocable license to use, reproduce, modify and distribute the Output for any purpose, including commercial sale, provided that any commercial distribution of the Output includes a clear notice that the Output was produced (in whole or in part) using WanGP, along with a hyperlink to the WanGP application’s About tab or repository.

- No additional restrictions — You may not apply legal terms or technological measures that legally restrict others from doing anything the license permits.

Notices:

Restrictions

3.1 You MAY NOT distribute, sublicense or otherwise make available the Software (or any derivative) for Commercial Use.
3.2 You MAY sell, license or otherwise commercially exploit the Output without restriction.
3.3 If you wish to use the Software for Commercial Use, you must obtain a separate commercial license from the Licensor.

- You do not have to comply with the license for elements of the material in the public domain or where your use is permitted by an applicable exception or limitation.

Third-Party Components

4.1 The Software includes components licensed under various open-source licenses (e.g., Apache 2.0, MIT, BSD).
4.2 You must comply with all applicable terms of those third-party licenses, including preservation of copyright notices, inclusion of required license texts, and patent-grant provisions.
4.3 You can find the full text of each third-party license via the “About” tab in the WanGP application, which provides links to their original GitHub repositories.

Attribution

5.1 You must give appropriate credit by including:
• a copy of this license (or a link to it), and
• a notice that your use is based on “WanGP”.
5.2 You may do so in any reasonable manner, but not in any way that suggests the Licensor endorses you or your use.

Disclaimer of Warranty & Liability

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE.

Commercial Licensing

The Licensor may offer commercial licenses for the Software, which grant rights to use the Software for Commercial Use. Please contact [deepbeepmeep@yahoo.com] for terms and pricing.

Effective Date & Previous Versions

8.1 This license is effective as of the date the LICENSE file is updated in the WanGP repository.
8.2 Any copies of the Software obtained under prior license terms before this Effective Date remain governed by those prior terms; such granted rights are irrevocable.
8.3 Use of the Software after the release of any subsequent version by the Licensor is subject to the terms of the then-current license, unless a separate agreement is in place.

Acceptable Use / Moral Clause

9.1 You MAY NOT use the Software or the Output to facilitate or produce content that is illegal, harmful, violent, harassing, defamatory, fraudulent, or otherwise violates applicable laws or fundamental human rights.
9.2 You MAY NOT deploy the Software or Output in contexts that promote hate speech, extremist ideology, human rights abuses, or other actions that could foreseeably cause significant harm to individuals or groups.
9.3 The Licensor reserves the right to terminate the rights granted under this license if a licensee materially breaches this Acceptable Use clause.

END OF LICENSE

No warranties are given. The license may not give you all of the permissions necessary for your intended use. For example, other rights such as publicity, privacy, or moral rights may limit how you use the material.
55
README.md
@ -20,6 +20,59 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep

## 🔥 Latest Updates

### August 12 2025: WanGP v7.7777 - Lucky Day(s)

This is your lucky day! Thanks to new configuration options that let you store generated videos and images in lossless compressed formats, you will find that they look twice as good without doing anything!

Just kidding, they will only be marginally better, but at least this opens the way to professional editing.

Supported formats:
- Video: x264, x264 lossless, x265
- Images: jpeg, png, webp, webp lossless

Generation settings are stored inside each of the above regardless of the format (that was the hard part); a minimal illustration of this kind of metadata embedding appears after the links below.

You can also now choose different output directories for images and videos.

Unexpected luck: fixed Lightning 8 steps for Qwen and Lightning 4 steps for Wan 2.2; you now just need a 1x multiplier, no weird numbers.

*Update 7.777: oops, got a crash with FastWan? Luck comes and goes; try a new update, maybe you will have better luck this time.*

*Update 7.7777: sometimes good luck seems to last forever. For instance, what if Qwen Lightning 4 steps could also work with WanGP?*
- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors (Qwen Lightning 4 steps)
- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors (new improved version of Qwen Lightning 8 steps)
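As a concrete illustration of the metadata point above, here is a minimal sketch of how generation settings can be embedded in and read back from a PNG using Pillow text chunks. This is not WanGP's actual implementation; the `wangp_settings` key and the `settings` dictionary are made up for the example.

```python
# Illustrative only: embed generation settings in a PNG text chunk, then read
# them back. Assumes Pillow is installed; the "wangp_settings" key is hypothetical.
import json
from PIL import Image
from PIL.PngImagePlugin import PngInfo

settings = {"prompt": "draw a hat", "steps": 8, "resolution": "1328x1328"}

img = Image.new("RGB", (64, 64))          # stand-in for a generated image
meta = PngInfo()
meta.add_text("wangp_settings", json.dumps(settings))
img.save("output.png", pnginfo=meta)      # lossless format, metadata preserved

reloaded = Image.open("output.png")
print(json.loads(reloaded.text["wangp_settings"]))
```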

### August 10 2025: WanGP v7.76 - Faster than the VAE ...

We have a funny one today: FastWan 2.2 5B, the fastest video generator, takes only 20s to generate 121 frames at 720p. The snag is that the VAE is twice as slow...

Thanks to Kijai for extracting the Lora that is used to build the corresponding finetune.

*WanGP 7.76: fixed the mess I made with the i2v models (the loras path was wrong for Wan 2.2 and Clip was broken).*

### August 9 2025: WanGP v7.74 - Qwen Rebirth part 2

Added support for the Qwen Lightning lora for 8-step generation (https://huggingface.co/lightx2v/Qwen-Image-Lightning/blob/main/Qwen-Image-Lightning-8steps-V1.0.safetensors). The lora is not normalized, so use a multiplier around 0.1.

Mag Cache support for all the Wan 2.2 models. Don't forget to set guidance to 1 and 8 denoising steps: your gen will be 7x faster!

### August 8 2025: WanGP v7.73 - Qwen Rebirth

Ever wondered what impact not using guidance has on a model that expects it? Just look at Qwen Image in WanGP 7.71, whose outputs were erratic. Somehow I had convinced myself that Qwen was a distilled model. In fact Qwen was dying for a negative prompt, and in WanGP 7.72 there is at last one for it.

As Qwen is not so picky after all, I have also added a quantized text encoder, which reduces the RAM requirements of Qwen by 10 GB (the quantized text encoder previously produced garbage).

Unfortunately the Sage bug is still there for older GPU architectures. Added an Sdpa fallback for these architectures.

*7.73 update: still a Sage / Sage2 bug for GPUs before the RTX 40xx. I have added a detection mechanism that forces Sdpa attention in that case.*

### August 6 2025: WanGP v7.71 - Picky, picky

This release comes with two new models:
- Qwen Image: a commercial-grade image generator capable of injecting full sentences into the generated image while still offering incredible visuals
- Wan 2.2 TextImage to Video 5B: the last Wan 2.2 model needed to complete your Wan 2.2 collection (loras for it can be stored in "\loras\5B")

There is a catch though: they are very picky if you want good generations. First, they both need lots of steps (50?) to show what they have to offer. Then, for Qwen Image I had to hardcode the supported resolutions, because if you try anything else you will get garbage. Likewise, Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p.

*7.71 update: added VAE Tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, for low VRAM during a whole gen.*

### August 4 2025: WanGP v7.6 - Remuxed

With this new version you won't have any excuse if there is no sound in your video.

@ -180,7 +233,7 @@ git clone https://github.com/deepbeepmeep/Wan2GP.git
cd Wan2GP
conda create -n wan2gp python=3.10.9
conda activate wan2gp
pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124
pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
pip install -r requirements.txt
```
BIN image assets (before sizes only): 1.7 MiB | 516 KiB | 871 KiB | assets/logo.png (55 KiB) | 294 KiB | 1.5 MiB | 628 KiB | 208 KiB
15
configs/i2v_2_2_multitalk.json
Normal file
@ -0,0 +1,15 @@
{
    "_class_name": "WanModel",
    "_diffusers_version": "0.33.0",
    "dim": 5120,
    "eps": 1e-06,
    "ffn_dim": 13824,
    "freq_dim": 256,
    "in_dim": 36,
    "model_type": "i2v2_2",
    "num_heads": 40,
    "num_layers": 40,
    "out_dim": 16,
    "text_len": 512,
    "multitalk_output_dim": 768
}
18
configs/qwen_image_20B.json
Normal file
@ -0,0 +1,18 @@
{
    "_class_name": "QwenImageTransformer2DModel",
    "_diffusers_version": "0.34.0.dev0",
    "attention_head_dim": 128,
    "axes_dims_rope": [
        16,
        56,
        56
    ],
    "guidance_embeds": false,
    "in_channels": 64,
    "joint_attention_dim": 3584,
    "num_attention_heads": 24,
    "num_layers": 60,
    "out_channels": 16,
    "patch_size": 2,
    "pooled_projection_dim": 768
}
14
configs/ti2v_2_2.json
Normal file
@ -0,0 +1,14 @@
{
    "_class_name": "WanModel",
    "_diffusers_version": "0.33.0",
    "dim": 3072,
    "eps": 1e-06,
    "ffn_dim": 14336,
    "freq_dim": 256,
    "in_dim": 48,
    "model_type": "ti2v2_2",
    "num_heads": 24,
    "num_layers": 30,
    "out_dim": 48,
    "text_len": 512
}
@ -5,8 +5,7 @@
    "architecture" : "fantasy",
    "modules": ["fantasy"],
    "description": "The Fantasy Talking model corresponds to the original Wan image 2 video model combined with the Fantasy Speaking module to process an audio input.",
    "URLs": "i2v_720p",
    "teacache_coefficients" : [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
    "URLs": "i2v_720p"
},
"resolution": "1280x720"
}
18
defaults/i2v_2_2_multitalk.json
Normal file
@ -0,0 +1,18 @@
{
    "model":
    {
        "name": "Wan2.2 Multitalk 14B",
        "architecture" : "i2v_2_2_multitalk",
        "description": "The Multitalk module of Wan 2.1 has been combined with the Wan 2.2 image 2 video model. It lets up to two people have a conversation.",
        "modules": ["multitalk"],
        "URLs": "i2v_2_2",
        "URLs2": "i2v_2_2",
        "group": "wan2_2",
        "visible": false
    },
    "switch_threshold" : 900,
    "guidance_scale" : 3.5,
    "guidance2_scale" : 3.5,
    "flow_shift" : 5

}
@ -13,7 +13,7 @@
        "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv-097-ic-lora-depth-control-diffusers.safetensors",
        "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv-097-ic-lora-canny-control-diffusers.safetensors"
    ],
    "LTXV_config": "ltx_video/configs/ltxv-13b-0.9.8-dev.yaml"
    "LTXV_config": "models/ltx_video/configs/ltxv-13b-0.9.8-dev.yaml"
},
"num_inference_steps": 30
}
@ -9,7 +9,7 @@
        "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.8_13B_distilled_quanto_bf16_int8.safetensors"
    ],
    "preload_URLs" : "ltxv_13B",
    "LTXV_config": "ltx_video/configs/ltxv-13b-0.9.8-distilled.yaml"
    "LTXV_config": "models/ltx_video/configs/ltxv-13b-0.9.8-distilled.yaml"
},
"num_inference_steps": 6
}
21
defaults/qwen_image_20B.json
Normal file
@ -0,0 +1,21 @@
{
    "model": {
        "name": "Qwen Image 20B",
        "architecture": "qwen_image_20B",
        "description": "Qwen Image is a generative model that produces very high quality images. It is one of the few models capable of rendering very long texts inside the image.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_20B_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_20B_quanto_bf16_int8.safetensors"
        ],
        "resolutions": [ ["1328x1328 (1:1)", "1328x1328"],
                         ["1664x928 (16:9)", "1664x928"],
                         ["928x1664 (9:16)", "928x1664"],
                         ["1472x1140 (4:3)", "1472x1140"],
                         ["1140x1472 (3:4)", "1140x1472"]],
        "attention": {"<89" : "sdpa"},
        "image_outputs": true
    },
    "prompt": "draw a hat",
    "resolution": "1280x720",
    "batch_size": 1
}
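The "resolutions" list above reflects the README note that Qwen Image only produces good results at a fixed set of resolutions. Below is a minimal sketch, not code from the repository, of how a requested resolution such as the default "1280x720" could be snapped to the closest supported entry; the helper name and the scoring rule are assumptions.

```python
# Illustrative only: snap a requested resolution to the closest entry of the
# "resolutions" list above, matching aspect ratio first and pixel count second.
SUPPORTED = ["1328x1328", "1664x928", "928x1664", "1472x1140", "1140x1472"]

def closest_supported(resolution: str) -> str:
    w, h = map(int, resolution.split("x"))
    def score(candidate: str):
        cw, ch = map(int, candidate.split("x"))
        return (abs(cw / ch - w / h), abs(cw * ch - w * h))
    return min(SUPPORTED, key=score)

print(closest_supported("1280x720"))   # -> "1664x928", the closest 16:9 option
```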
17
defaults/ti2v_2_2.json
Normal file
@ -0,0 +1,17 @@
{
    "model": {
        "name": "Wan2.2 TextImage2video 5B",
        "architecture": "ti2v_2_2",
        "description": "Wan 2.2 Text 2 Video model 5B",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_5B_mbf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_5B_quanto_mbf16_int8.safetensors"
        ],
        "group": "wan2_2"
    },
    "video_length": 121,
    "guidance_scale": 5,
    "flow_shift": 5,
    "num_inference_steps": 50,
    "resolution": "1280x720"
}
||||
15
defaults/ti2v_2_2_fastwan.json
Normal file
@ -0,0 +1,15 @@
{
    "model": {
        "name": "Wan2.2 FastWan TextImage2video 5B",
        "architecture": "ti2v_2_2",
        "description": "FastWan2.2-TI2V-5B-Full-Diffusers is built upon Wan-AI/Wan2.2-TI2V-5B-Diffusers. It supports efficient 3-step inference and produces high-quality videos at 121×704×1280 resolution.",
        "URLs": "ti2v_2_2",
        "loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"],
        "group": "wan2_2"
    },
    "video_length": 121,
    "guidance_scale": 1,
    "flow_shift": 3,
    "num_inference_steps": 3,
    "resolution": "1280x720"
}
||||
@ -8,9 +8,9 @@ This guide covers installation for different GPU generations and operating syste
|
||||
- Conda or Python venv
|
||||
- Compatible GPU (RTX 10XX or newer recommended)
|
||||
|
||||
## Installation for RTX 10XX to RTX 40XX (Stable)
|
||||
## Installation for RTX 10XX to RTX 50XX (Stable)
|
||||
|
||||
This installation uses PyTorch 2.6.0 which is well-tested and stable.
|
||||
This installation uses PyTorch 2.7.0 which is well-tested and stable.
|
||||
|
||||
### Step 1: Download and Setup Environment
|
||||
|
||||
@ -27,8 +27,8 @@ conda activate wan2gp
|
||||
### Step 2: Install PyTorch
|
||||
|
||||
```shell
|
||||
# Install PyTorch 2.6.0 with CUDA 12.4
|
||||
pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124
|
||||
# Install PyTorch 2.7.0 with CUDA 12.8
|
||||
pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
|
||||
```
|
||||
|
||||
### Step 3: Install Dependencies
|
||||
@ -40,7 +40,7 @@ pip install -r requirements.txt
|
||||
|
||||
### Step 4: Optional Performance Optimizations
|
||||
|
||||
#### Sage Attention (30% faster)
|
||||
#### Sage Attention (30% faster; not compatible with RTX 50xx, do not install it there)
|
||||
|
||||
```shell
|
||||
# Windows only: Install Triton
|
||||
@ -58,6 +58,7 @@ pip install triton-windows
|
||||
pip install https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu126torch2.6.0-cp310-cp310-win_amd64.whl
|
||||
|
||||
# Linux (manual compilation required)
|
||||
python -m pip install "setuptools<=75.8.2" --force-reinstall
|
||||
git clone https://github.com/thu-ml/SageAttention
|
||||
cd SageAttention
|
||||
pip install -e .
|
||||
@ -70,60 +71,6 @@ pip install -e .
|
||||
pip install flash-attn==2.7.2.post1
|
||||
```
|
||||
|
||||
## Installation for RTX 50XX (Beta)
|
||||
|
||||
RTX 50XX GPUs require PyTorch 2.7.0 (beta). This version may be less stable.
|
||||
|
||||
⚠️ **Important:** Use Python 3.10 for compatibility with pip wheels.
|
||||
|
||||
### Step 1: Setup Environment
|
||||
|
||||
```shell
|
||||
# Clone and setup (same as above)
|
||||
git clone https://github.com/deepbeepmeep/Wan2GP.git
|
||||
cd Wan2GP
|
||||
conda create -n wan2gp python=3.10.9
|
||||
conda activate wan2gp
|
||||
```
|
||||
|
||||
### Step 2: Install PyTorch Beta
|
||||
|
||||
```shell
|
||||
# Install PyTorch 2.7.0 with CUDA 12.8
|
||||
pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
|
||||
```
|
||||
|
||||
### Step 3: Install Dependencies
|
||||
|
||||
```shell
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Step 4: Optional Optimizations for RTX 50XX
|
||||
|
||||
#### Sage Attention
|
||||
|
||||
```shell
|
||||
# Windows
|
||||
pip install triton-windows
|
||||
pip install sageattention==1.0.6
|
||||
|
||||
# Linux
|
||||
pip install sageattention==1.0.6
|
||||
```
|
||||
|
||||
#### Sage 2 Attention
|
||||
|
||||
```shell
|
||||
# Windows
|
||||
pip install triton-windows
|
||||
pip install https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu128torch2.7.0-cp310-cp310-win_amd64.whl
|
||||
|
||||
# Linux (manual compilation)
|
||||
git clone https://github.com/thu-ml/SageAttention
|
||||
cd SageAttention
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## Attention Modes
|
||||
|
||||
@ -134,6 +81,12 @@ WanGP supports several attention implementations:
|
||||
- **Sage2**: 40% speed boost
|
||||
- **Flash**: Good performance, may be complex to install on Windows
|
||||
|
||||
### Attention GPU Compatibility

- RTX 10XX, 20XX: SDPA
- RTX 30XX, 40XX: SDPA, Flash Attention, Xformers, Sage, Sage2
- RTX 50XX: SDPA, Flash Attention, Xformers, Sage2 (a quick capability check is sketched below)
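Here is a minimal sketch (an assumption, not WanGP's detection code) of how the compute capability reported by PyTorch could be used to pick a default attention backend consistent with the table above.

```python
# Heuristic sketch: choose an attention backend from the GPU's CUDA compute
# capability. The thresholds mirror the compatibility table above.
import torch

def suggest_attention() -> str:
    if not torch.cuda.is_available():
        return "sdpa"
    major, _ = torch.cuda.get_device_capability(0)
    if major < 8:        # RTX 10xx / 20xx: stick to PyTorch's built-in attention
        return "sdpa"
    return "sage2"       # RTX 30xx / 40xx / 50xx: Sage2 gives the biggest speedup

print(suggest_attention())
```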
## Performance Profiles
|
||||
|
||||
Choose a profile based on your hardware:
|
||||
@ -161,10 +114,5 @@ If Sage attention doesn't work:
|
||||
- Use Profile 4 for lower VRAM usage
|
||||
- Consider using 1.3B models instead of 14B models
|
||||
|
||||
### GPU Compatibility
|
||||
|
||||
- RTX 10XX, 20XX: Supported with SDPA attention
|
||||
- RTX 30XX, 40XX: Full feature support
|
||||
- RTX 50XX: Beta support with PyTorch 2.7.0
|
||||
|
||||
For more troubleshooting, see [TROUBLESHOOTING.md](TROUBLESHOOTING.md)
|
||||
@ -65,11 +65,16 @@ For dynamic effects over generation steps, use comma-separated values:

With models like Wan 2.2 that internally use two diffusion models (*High Noise* / *Low Noise*), you can specify which Loras you want applied for a specific phase by separating the phases with a ";".

For instance, if you want to disable a lora for phase *High Noise* and enablesit only for phase *Low Noise*:
For instance, if you want to disable a lora for phase *High Noise* and enable it only for phase *Low Noise*:
```
0;1
```

Also with Wan 2.2, if you have two loras and you want the first one applied only during the High Noise phase and the second one during the Low Noise phase:
```
1;0 0;1
```

As usual, you can use any float as a multiplier, and a multiplier can vary throughout one phase for one Lora:
```
0.9,0.8;1.2,1.1,1
```

A minimal parsing sketch for this multiplier syntax follows.
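The sketch below is not the actual WanGP parser; the function name and the linear spreading of comma-separated values over steps are assumptions (the spreading mirrors the `expand_list_over_steps` helper visible elsewhere in this diff).

```python
# Minimal sketch: expand a Lora multiplier string into per-phase, per-step
# values. Space separates Loras, ";" separates High Noise / Low Noise phases,
# "," separates values spread linearly over the denoising steps of a phase.
def parse_multipliers(spec: str, steps_per_phase: int):
    loras = []
    for lora_spec in spec.split():                      # one entry per Lora
        phases = []
        for phase_spec in lora_spec.split(";"):         # one entry per phase
            values = [float(v) for v in phase_spec.split(",")]
            # stretch the listed values over the steps of this phase
            phases.append([values[int(i * len(values) / steps_per_phase)]
                           for i in range(steps_per_phase)])
        loras.append(phases)
    return loras

print(parse_multipliers("1;0 0;1", steps_per_phase=4))
print(parse_multipliers("0.9,0.8;1.2,1.1,1", steps_per_phase=6))
```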
@ -1,13 +0,0 @@
|
||||
try:
|
||||
from ._version import (
|
||||
version as __version__, # type: ignore
|
||||
version_tuple,
|
||||
)
|
||||
except ImportError:
|
||||
__version__ = "unknown (no version information available)"
|
||||
version_tuple = (0, 0, "unknown", "noinfo")
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
PACKAGE = __package__.replace("_", "-")
|
||||
PACKAGE_ROOT = Path(__file__).parent
|
||||
682
i2v_inference.py
@ -1,682 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import json
|
||||
import torch
|
||||
import traceback
|
||||
import gc
|
||||
import random
|
||||
|
||||
# These imports rely on your existing code structure
|
||||
# They must match the location of your WAN code, etc.
|
||||
import wan
|
||||
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
|
||||
from wan.modules.attention import get_attention_modes
|
||||
from wan.utils.utils import cache_video
|
||||
from mmgp import offload, safetensors2, profile_type
|
||||
|
||||
try:
|
||||
import triton
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
DATA_DIR = "ckpts"
|
||||
|
||||
# --------------------------------------------------
|
||||
# HELPER FUNCTIONS
|
||||
# --------------------------------------------------
|
||||
|
||||
def sanitize_file_name(file_name):
|
||||
"""Clean up file name from special chars."""
|
||||
return (
|
||||
file_name.replace("/", "")
|
||||
.replace("\\", "")
|
||||
.replace(":", "")
|
||||
.replace("|", "")
|
||||
.replace("?", "")
|
||||
.replace("<", "")
|
||||
.replace(">", "")
|
||||
.replace('"', "")
|
||||
)
|
||||
|
||||
def extract_preset(lset_name, lora_dir, loras):
|
||||
"""
|
||||
Load a .lset JSON that lists the LoRA files to apply, plus multipliers
|
||||
and possibly a suggested prompt prefix.
|
||||
"""
|
||||
lset_name = sanitize_file_name(lset_name)
|
||||
if not lset_name.endswith(".lset"):
|
||||
lset_name_filename = os.path.join(lora_dir, lset_name + ".lset")
|
||||
else:
|
||||
lset_name_filename = os.path.join(lora_dir, lset_name)
|
||||
|
||||
if not os.path.isfile(lset_name_filename):
|
||||
raise ValueError(f"Preset '{lset_name}' not found in {lora_dir}")
|
||||
|
||||
with open(lset_name_filename, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
lset = json.loads(text)
|
||||
|
||||
loras_choices_files = lset["loras"]
|
||||
loras_choices = []
|
||||
missing_loras = []
|
||||
for lora_file in loras_choices_files:
|
||||
# Build absolute path and see if it is in loras
|
||||
full_lora_path = os.path.join(lora_dir, lora_file)
|
||||
if full_lora_path in loras:
|
||||
idx = loras.index(full_lora_path)
|
||||
loras_choices.append(str(idx))
|
||||
else:
|
||||
missing_loras.append(lora_file)
|
||||
|
||||
if len(missing_loras) > 0:
|
||||
missing_list = ", ".join(missing_loras)
|
||||
raise ValueError(f"Missing LoRA files for preset: {missing_list}")
|
||||
|
||||
loras_mult_choices = lset["loras_mult"]
|
||||
prompt_prefix = lset.get("prompt", "")
|
||||
full_prompt = lset.get("full_prompt", False)
|
||||
return loras_choices, loras_mult_choices, prompt_prefix, full_prompt
|
||||
|
||||
def get_attention_mode(args_attention, installed_modes):
|
||||
"""
|
||||
Decide which attention mode to use: either the user choice or auto fallback.
|
||||
"""
|
||||
if args_attention == "auto":
|
||||
for candidate in ["sage2", "sage", "sdpa"]:
|
||||
if candidate in installed_modes:
|
||||
return candidate
|
||||
return "sdpa" # last fallback
|
||||
elif args_attention in installed_modes:
|
||||
return args_attention
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Requested attention mode '{args_attention}' not installed. "
|
||||
f"Installed modes: {installed_modes}"
|
||||
)
|
||||
|
||||
def load_i2v_model(model_filename, text_encoder_filename, is_720p):
|
||||
"""
|
||||
Load the i2v model with a specific size config and text encoder.
|
||||
"""
|
||||
if is_720p:
|
||||
print("Loading 14B-720p i2v model ...")
|
||||
cfg = WAN_CONFIGS['i2v-14B']
|
||||
wan_model = wan.WanI2V(
|
||||
config=cfg,
|
||||
checkpoint_dir=DATA_DIR,
|
||||
model_filename=model_filename,
|
||||
text_encoder_filename=text_encoder_filename
|
||||
)
|
||||
else:
|
||||
print("Loading 14B-480p i2v model ...")
|
||||
cfg = WAN_CONFIGS['i2v-14B']
|
||||
wan_model = wan.WanI2V(
|
||||
config=cfg,
|
||||
checkpoint_dir=DATA_DIR,
|
||||
model_filename=model_filename,
|
||||
text_encoder_filename=text_encoder_filename
|
||||
)
|
||||
# Pipe structure
|
||||
pipe = {
|
||||
"transformer": wan_model.model,
|
||||
"text_encoder": wan_model.text_encoder.model,
|
||||
"text_encoder_2": wan_model.clip.model,
|
||||
"vae": wan_model.vae.model
|
||||
}
|
||||
return wan_model, pipe
|
||||
|
||||
def setup_loras(pipe, lora_dir, lora_preset, num_inference_steps):
|
||||
"""
|
||||
Load loras from a directory, optionally apply a preset.
|
||||
"""
|
||||
from pathlib import Path
|
||||
import glob
|
||||
|
||||
if not lora_dir or not Path(lora_dir).is_dir():
|
||||
print("No valid --lora-dir provided or directory doesn't exist, skipping LoRA setup.")
|
||||
return [], [], [], "", "", False
|
||||
|
||||
# Gather LoRA files
|
||||
loras = sorted(
|
||||
glob.glob(os.path.join(lora_dir, "*.sft"))
|
||||
+ glob.glob(os.path.join(lora_dir, "*.safetensors"))
|
||||
)
|
||||
loras_names = [Path(x).stem for x in loras]
|
||||
|
||||
# Offload them with no activation
|
||||
offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False)
|
||||
|
||||
# If user gave a preset, apply it
|
||||
default_loras_choices = []
|
||||
default_loras_multis_str = ""
|
||||
default_prompt_prefix = ""
|
||||
preset_applied_full_prompt = False
|
||||
if lora_preset:
|
||||
loras_choices, loras_mult, prefix, full_prompt = extract_preset(lora_preset, lora_dir, loras)
|
||||
default_loras_choices = loras_choices
|
||||
# If user stored loras_mult as a list or string in JSON, unify that to str
|
||||
if isinstance(loras_mult, list):
|
||||
# Just store them in a single line
|
||||
default_loras_multis_str = " ".join([str(x) for x in loras_mult])
|
||||
else:
|
||||
default_loras_multis_str = str(loras_mult)
|
||||
default_prompt_prefix = prefix
|
||||
preset_applied_full_prompt = full_prompt
|
||||
|
||||
return (
|
||||
loras,
|
||||
loras_names,
|
||||
default_loras_choices,
|
||||
default_loras_multis_str,
|
||||
default_prompt_prefix,
|
||||
preset_applied_full_prompt
|
||||
)
|
||||
|
||||
def parse_loras_and_activate(
|
||||
transformer,
|
||||
loras,
|
||||
loras_choices,
|
||||
loras_mult_str,
|
||||
num_inference_steps
|
||||
):
|
||||
"""
|
||||
Activate the chosen LoRAs with multipliers over the pipeline's transformer.
|
||||
Supports stepwise expansions (like "0.5,0.8" for partial steps).
|
||||
"""
|
||||
if not loras or not loras_choices:
|
||||
# no LoRAs selected
|
||||
return
|
||||
|
||||
# Handle multipliers
|
||||
def is_float_or_comma_list(x):
|
||||
"""
|
||||
Example: "0.5", or "0.8,1.0", etc. is valid.
|
||||
"""
|
||||
if not x:
|
||||
return False
|
||||
for chunk in x.split(","):
|
||||
try:
|
||||
float(chunk.strip())
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
# Convert multiline or spaced lines to a single list
|
||||
lines = [
|
||||
line.strip()
|
||||
for line in loras_mult_str.replace("\r", "\n").split("\n")
|
||||
if line.strip() and not line.strip().startswith("#")
|
||||
]
|
||||
# Now combine them by space
|
||||
joined_line = " ".join(lines) # "1.0 2.0,3.0"
|
||||
if not joined_line.strip():
|
||||
multipliers = []
|
||||
else:
|
||||
multipliers = joined_line.split(" ")
|
||||
|
||||
# Expand each item
|
||||
final_multipliers = []
|
||||
for mult in multipliers:
|
||||
mult = mult.strip()
|
||||
if not mult:
|
||||
continue
|
||||
if is_float_or_comma_list(mult):
|
||||
# Could be "0.7" or "0.5,0.6"
|
||||
if "," in mult:
|
||||
# expand over steps
|
||||
chunk_vals = [float(x.strip()) for x in mult.split(",")]
|
||||
expanded = expand_list_over_steps(chunk_vals, num_inference_steps)
|
||||
final_multipliers.append(expanded)
|
||||
else:
|
||||
final_multipliers.append(float(mult))
|
||||
else:
|
||||
raise ValueError(f"Invalid LoRA multiplier: '{mult}'")
|
||||
|
||||
# If fewer multipliers than chosen LoRAs => pad with 1.0
|
||||
needed = len(loras_choices) - len(final_multipliers)
|
||||
if needed > 0:
|
||||
final_multipliers += [1.0]*needed
|
||||
|
||||
# Actually activate them
|
||||
offload.activate_loras(transformer, loras_choices, final_multipliers)
|
||||
|
||||
def expand_list_over_steps(short_list, num_steps):
|
||||
"""
|
||||
If user gave (0.5, 0.8) for example, expand them over `num_steps`.
|
||||
The expansion is simply linear slice across steps.
|
||||
"""
|
||||
result = []
|
||||
inc = len(short_list) / float(num_steps)
|
||||
idxf = 0.0
|
||||
for _ in range(num_steps):
|
||||
value = short_list[int(idxf)]
|
||||
result.append(value)
|
||||
idxf += inc
|
||||
return result
|
||||
|
||||
def download_models_if_needed(transformer_filename_i2v, text_encoder_filename, local_folder=DATA_DIR):
|
||||
"""
|
||||
Checks if all required WAN 2.1 i2v files exist locally under 'ckpts/'.
|
||||
If not, downloads them from a Hugging Face Hub repo.
|
||||
Adjust the 'repo_id' and needed files as appropriate.
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"huggingface_hub is required for automatic model download. "
|
||||
"Please install it via `pip install huggingface_hub`."
|
||||
) from e
|
||||
|
||||
# Identify just the filename portion for each path
|
||||
def basename(path_str):
|
||||
return os.path.basename(path_str)
|
||||
|
||||
repo_id = "DeepBeepMeep/Wan2.1"
|
||||
target_root = local_folder
|
||||
|
||||
# You can customize this list as needed for i2v usage.
|
||||
# At minimum you need:
|
||||
# 1) The requested i2v transformer file
|
||||
# 2) The requested text encoder file
|
||||
# 3) VAE file
|
||||
# 4) The open-clip xlm-roberta-large weights
|
||||
#
|
||||
# If your i2v config references additional files, add them here.
|
||||
needed_files = [
|
||||
"Wan2.1_VAE.pth",
|
||||
"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
|
||||
basename(text_encoder_filename),
|
||||
basename(transformer_filename_i2v),
|
||||
]
|
||||
|
||||
# The original script also downloads an entire "xlm-roberta-large" folder
|
||||
# via snapshot_download. If you require that for your pipeline,
|
||||
# you can add it here, for example:
|
||||
subfolder_name = "xlm-roberta-large"
|
||||
if not Path(os.path.join(target_root, subfolder_name)).exists():
|
||||
snapshot_download(repo_id=repo_id, allow_patterns=subfolder_name + "/*", local_dir=target_root)
|
||||
|
||||
for filename in needed_files:
|
||||
local_path = os.path.join(target_root, filename)
|
||||
if not os.path.isfile(local_path):
|
||||
print(f"File '{filename}' not found locally. Downloading from {repo_id} ...")
|
||||
hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename=filename,
|
||||
local_dir=target_root
|
||||
)
|
||||
else:
|
||||
# Already present
|
||||
pass
|
||||
|
||||
print("All required i2v files are present.")
|
||||
|
||||
|
||||
# --------------------------------------------------
|
||||
# ARGUMENT PARSER
|
||||
# --------------------------------------------------
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Image-to-Video inference using WAN 2.1 i2v"
|
||||
)
|
||||
# Model + Tools
|
||||
parser.add_argument(
|
||||
"--quantize-transformer",
|
||||
action="store_true",
|
||||
help="Use on-the-fly transformer quantization"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compile",
|
||||
action="store_true",
|
||||
help="Enable PyTorch 2.0 compile for the transformer"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention",
|
||||
type=str,
|
||||
default="auto",
|
||||
help="Which attention to use: auto, sdpa, sage, sage2, flash"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Memory usage profile number [1..5]; see original script or use 2 if you have low VRAM"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preload",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Megabytes of the diffusion model to preload in VRAM (only used in some profiles)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Verbosity level [0..5]"
|
||||
)
|
||||
|
||||
# i2v Model
|
||||
parser.add_argument(
|
||||
"--transformer-file",
|
||||
type=str,
|
||||
default=f"{DATA_DIR}/wan2.1_image2video_480p_14B_quanto_int8.safetensors",
|
||||
help="Which i2v model to load"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text-encoder-file",
|
||||
type=str,
|
||||
default=f"{DATA_DIR}/models_t5_umt5-xxl-enc-quanto_int8.safetensors",
|
||||
help="Which text encoder to use"
|
||||
)
|
||||
|
||||
# LoRA
|
||||
parser.add_argument(
|
||||
"--lora-dir",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to a directory containing i2v LoRAs"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-preset",
|
||||
type=str,
|
||||
default="",
|
||||
help="A .lset preset name in the lora_dir to auto-apply"
|
||||
)
|
||||
|
||||
# Generation Options
|
||||
parser.add_argument("--prompt", type=str, default=None, required=True, help="Prompt for generation")
|
||||
parser.add_argument("--negative-prompt", type=str, default="", help="Negative prompt")
|
||||
parser.add_argument("--resolution", type=str, default="832x480", help="WxH")
|
||||
parser.add_argument("--frames", type=int, default=64, help="Number of frames (16=1s if fps=16). Must be multiple of 4 +/- 1 in WAN.")
|
||||
parser.add_argument("--steps", type=int, default=30, help="Number of denoising steps.")
|
||||
parser.add_argument("--guidance-scale", type=float, default=5.0, help="Classifier-free guidance scale")
|
||||
parser.add_argument("--flow-shift", type=float, default=3.0, help="Flow shift parameter. Generally 3.0 for 480p, 5.0 for 720p.")
|
||||
parser.add_argument("--riflex", action="store_true", help="Enable RIFLEx for longer videos")
|
||||
parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.")
|
||||
parser.add_argument("--teacache-start", type=float, default=0.1, help="Teacache start step percentage [0..100]")
|
||||
parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.")
|
||||
parser.add_argument("--slg-layers", type=str, default=None, help="Which layers to use for skip layer guidance")
|
||||
parser.add_argument("--slg-start", type=float, default=0.0, help="Percentage in to start SLG")
|
||||
parser.add_argument("--slg-end", type=float, default=1.0, help="Percentage in to end SLG")
|
||||
|
||||
# LoRA usage
|
||||
parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. Usually you only use the preset.")
|
||||
parser.add_argument("--loras-mult", type=str, default="", help="Multipliers for each chosen LoRA. Example: '1.0 1.2,1.3' etc.")
|
||||
|
||||
# Input
|
||||
parser.add_argument(
|
||||
"--input-image",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to an input image (or multiple)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-file",
|
||||
type=str,
|
||||
default="output.mp4",
|
||||
help="Where to save the resulting video."
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
# --------------------------------------------------
|
||||
# MAIN
|
||||
# --------------------------------------------------
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Setup environment
|
||||
offload.default_verboseLevel = args.verbose
|
||||
installed_attn_modes = get_attention_modes()
|
||||
|
||||
# Decide attention
|
||||
chosen_attention = get_attention_mode(args.attention, installed_attn_modes)
|
||||
offload.shared_state["_attention"] = chosen_attention
|
||||
|
||||
# Determine i2v resolution format
|
||||
if "720" in args.transformer_file:
|
||||
is_720p = True
|
||||
else:
|
||||
is_720p = False
|
||||
|
||||
# Make sure we have the needed models locally
|
||||
download_models_if_needed(args.transformer_file, args.text_encoder_file)
|
||||
|
||||
# Load i2v
|
||||
wan_model, pipe = load_i2v_model(
|
||||
model_filename=args.transformer_file,
|
||||
text_encoder_filename=args.text_encoder_file,
|
||||
is_720p=is_720p
|
||||
)
|
||||
wan_model._interrupt = False
|
||||
|
||||
# Offload / profile
|
||||
# e.g. for your script: offload.profile(pipe, profile_no=args.profile, compile=..., quantizeTransformer=...)
|
||||
# pass the budgets if you want, etc.
|
||||
kwargs = {}
|
||||
if args.profile == 2 or args.profile == 4:
|
||||
# preload is in MB
|
||||
if args.preload == 0:
|
||||
budgets = {"transformer": 100, "text_encoder": 100, "*": 1000}
|
||||
else:
|
||||
budgets = {"transformer": args.preload, "text_encoder": 100, "*": 1000}
|
||||
kwargs["budgets"] = budgets
|
||||
elif args.profile == 3:
|
||||
kwargs["budgets"] = {"*": "70%"}
|
||||
|
||||
compile_choice = "transformer" if args.compile else ""
|
||||
# Create the offload object
|
||||
offloadobj = offload.profile(
|
||||
pipe,
|
||||
profile_no=args.profile,
|
||||
compile=compile_choice,
|
||||
quantizeTransformer=args.quantize_transformer,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# If user wants to use LoRAs
|
||||
(
|
||||
loras,
|
||||
loras_names,
|
||||
default_loras_choices,
|
||||
default_loras_multis_str,
|
||||
preset_prompt_prefix,
|
||||
preset_full_prompt
|
||||
) = setup_loras(pipe, args.lora_dir, args.lora_preset, args.steps)
|
||||
|
||||
# Combine user prompt with preset prompt if the preset indicates so
|
||||
if preset_prompt_prefix:
|
||||
if preset_full_prompt:
|
||||
# Full override
|
||||
user_prompt = preset_prompt_prefix
|
||||
else:
|
||||
# Just prefix
|
||||
user_prompt = preset_prompt_prefix + "\n" + args.prompt
|
||||
else:
|
||||
user_prompt = args.prompt
|
||||
|
||||
# Actually parse user LoRA choices if they did not rely purely on the preset
|
||||
if args.loras_choices:
|
||||
# If user gave e.g. "0,1", we treat that as new additions
|
||||
lora_choice_list = [x.strip() for x in args.loras_choices.split(",")]
|
||||
else:
|
||||
# Use the defaults from the preset
|
||||
lora_choice_list = default_loras_choices
|
||||
|
||||
# Activate them
|
||||
parse_loras_and_activate(
|
||||
pipe["transformer"], loras, lora_choice_list, args.loras_mult or default_loras_multis_str, args.steps
|
||||
)
|
||||
|
||||
# Negative prompt
|
||||
negative_prompt = args.negative_prompt or ""
|
||||
|
||||
# Sanity check resolution
|
||||
if "*" in args.resolution.lower():
|
||||
print("ERROR: resolution must be e.g. 832x480 not '832*480'. Fixing it.")
|
||||
resolution_str = args.resolution.lower().replace("*", "x")
|
||||
else:
|
||||
resolution_str = args.resolution
|
||||
|
||||
try:
|
||||
width, height = [int(x) for x in resolution_str.split("x")]
|
||||
except:
|
||||
raise ValueError(f"Invalid resolution: '{resolution_str}'")
|
||||
|
||||
# Parse slg_layers from comma-separated string to a Python list of ints (or None if not provided)
|
||||
if args.slg_layers:
|
||||
slg_list = [int(x) for x in args.slg_layers.split(",")]
|
||||
else:
|
||||
slg_list = None
|
||||
|
||||
# Additional checks (from your original code).
|
||||
if "480p" in args.transformer_file:
|
||||
# Then we cannot exceed certain area for 480p model
|
||||
if width * height > 832*480:
|
||||
raise ValueError("You must use the 720p i2v model to generate bigger than 832x480.")
|
||||
# etc.
|
||||
|
||||
# Handle random seed
|
||||
if args.seed < 0:
|
||||
args.seed = random.randint(0, 999999999)
|
||||
print(f"Using seed={args.seed}")
|
||||
|
||||
# Setup tea cache if needed
|
||||
trans = wan_model.model
|
||||
trans.enable_cache = (args.teacache > 0)
|
||||
if trans.enable_cache:
|
||||
if "480p" in args.transformer_file:
|
||||
# example from your code
|
||||
trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
|
||||
elif "720p" in args.transformer_file:
|
||||
trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
|
||||
else:
|
||||
raise ValueError("Teacache not supported for this model variant")
|
||||
|
||||
# Attempt generation
|
||||
print("Starting generation ...")
|
||||
start_time = time.time()
|
||||
|
||||
# Read the input image
|
||||
if not os.path.isfile(args.input_image):
|
||||
raise ValueError(f"Input image does not exist: {args.input_image}")
|
||||
|
||||
from PIL import Image
|
||||
input_img = Image.open(args.input_image).convert("RGB")
|
||||
|
||||
# Possibly load more than one image if you want "multiple images" – but here we'll just do single for demonstration
|
||||
|
||||
# Define the generation call
|
||||
# - frames => must be multiple of 4 plus 1 as per original script's note, e.g. 81, 65, ...
|
||||
# You can correct to that if needed:
|
||||
frame_count = (args.frames // 4)*4 + 1 # ensures it's 4*N+1
|
||||
# RIFLEx
|
||||
enable_riflex = args.riflex
|
||||
|
||||
# If teacache => reset counters
|
||||
if trans.enable_cache:
|
||||
trans.teacache_counter = 0
|
||||
trans.cache_multiplier = args.teacache
|
||||
trans.cache_start_step = int(args.teacache_start * args.steps / 100.0)
|
||||
trans.num_steps = args.steps
|
||||
trans.cache_skipped_steps = 0
|
||||
trans.previous_residual_uncond = None
|
||||
trans.previous_residual_cond = None
|
||||
|
||||
# VAE Tiling
|
||||
device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
|
||||
if device_mem_capacity >= 28000: # 81 frames 720p requires about 28 GB VRAM
|
||||
use_vae_config = 1
|
||||
elif device_mem_capacity >= 8000:
|
||||
use_vae_config = 2
|
||||
else:
|
||||
use_vae_config = 3
|
||||
|
||||
if use_vae_config == 1:
|
||||
VAE_tile_size = 0
|
||||
elif use_vae_config == 2:
|
||||
VAE_tile_size = 256
|
||||
else:
|
||||
VAE_tile_size = 128
|
||||
|
||||
print('Using VAE tile size of', VAE_tile_size)
|
||||
|
||||
# Actually run the i2v generation
|
||||
try:
|
||||
sample_frames = wan_model.generate(
|
||||
input_prompt = user_prompt,
|
||||
image_start = input_img,
|
||||
frame_num=frame_count,
|
||||
width=width,
|
||||
height=height,
|
||||
# max_area=MAX_AREA_CONFIGS[f"{width}*{height}"], # or you can pass your custom
|
||||
shift=args.flow_shift,
|
||||
sampling_steps=args.steps,
|
||||
guide_scale=args.guidance_scale,
|
||||
n_prompt=negative_prompt,
|
||||
seed=args.seed,
|
||||
offload_model=False,
|
||||
callback=None, # or define your own callback if you want
|
||||
enable_RIFLEx=enable_riflex,
|
||||
VAE_tile_size=VAE_tile_size,
|
||||
joint_pass=slg_list is None, # set if you want a small speed improvement without SLG
|
||||
slg_layers=slg_list,
|
||||
slg_start=args.slg_start,
|
||||
slg_end=args.slg_end,
|
||||
)
|
||||
except Exception as e:
|
||||
offloadobj.unload_all()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
err_str = f"Generation failed with error: {e}"
|
||||
# Attempt to detect OOM errors
|
||||
s = str(e).lower()
|
||||
if any(keyword in s for keyword in ["memory", "cuda", "alloc"]):
|
||||
raise RuntimeError("Likely out-of-VRAM or out-of-RAM error. " + err_str)
|
||||
else:
|
||||
traceback.print_exc()
|
||||
raise RuntimeError(err_str)
|
||||
|
||||
# After generation
|
||||
offloadobj.unload_all()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if sample_frames is None:
|
||||
raise RuntimeError("No frames were returned (maybe generation was aborted or failed).")
|
||||
|
||||
# If teacache was used, we can see how many steps were skipped
|
||||
if trans.enable_cache:
|
||||
print(f"TeaCache skipped steps: {trans.teacache_skipped_steps} / {args.steps}")
|
||||
|
||||
# Save result
|
||||
sample_frames = sample_frames.cpu() # shape = c, t, h, w => [3, T, H, W]
|
||||
os.makedirs(os.path.dirname(args.output_file) or ".", exist_ok=True)
|
||||
|
||||
# Use the provided helper from your code to store the MP4
|
||||
# By default, you used cache_video(tensor=..., save_file=..., fps=16, ...)
|
||||
# or you can do your own. We'll do the same for consistency:
|
||||
cache_video(
|
||||
tensor=sample_frames[None], # shape => [1, c, T, H, W]
|
||||
save_file=args.output_file,
|
||||
fps=16,
|
||||
nrow=1,
|
||||
normalize=True,
|
||||
value_range=(-1, 1)
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
elapsed_s = end_time - start_time
|
||||
print(f"Done! Output written to {args.output_file}. Generation time: {elapsed_s:.1f} seconds.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
loras_qwen/Readme.txt
Normal file
@ -0,0 +1 @@
|
||||
LTX Video loras
|
||||
2
models/flux/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .flux_main import model_factory
|
||||
from . import flux_handler
|
||||
103
models/flux/flux_handler.py
Normal file
@ -0,0 +1,103 @@
|
||||
import torch
|
||||
|
||||
def get_ltxv_text_encoder_filename(text_encoder_quantization):
|
||||
text_encoder_filename = "ckpts/T5_xxl_1.1/T5_xxl_1.1_enc_bf16.safetensors"
|
||||
if text_encoder_quantization =="int8":
|
||||
text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_bf16_int8")
|
||||
return text_encoder_filename
|
||||
|
||||
class family_handler():
|
||||
@staticmethod
|
||||
def query_model_def(base_model_type, model_def):
|
||||
flux_model = model_def.get("flux-model", "flux-dev")
|
||||
flux_schnell = flux_model == "flux-schnell"
|
||||
model_def_output = {
|
||||
"image_outputs" : True,
|
||||
"no_negative_prompt" : True,
|
||||
}
|
||||
if flux_schnell:
|
||||
model_def_output["no_guidance"] = True
|
||||
else:
|
||||
model_def_output["embedded_guidance"] = True
|
||||
|
||||
|
||||
return model_def_output
|
||||
|
||||
@staticmethod
|
||||
def query_supported_types():
|
||||
return ["flux"]
|
||||
|
||||
@staticmethod
|
||||
def query_family_maps():
|
||||
return {}, {}
|
||||
|
||||
@staticmethod
|
||||
def get_rgb_factors(base_model_type ):
|
||||
from shared.RGB_factors import get_rgb_factors
|
||||
latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("flux")
|
||||
return latent_rgb_factors, latent_rgb_factors_bias
|
||||
|
||||
|
||||
@staticmethod
|
||||
def query_model_family():
|
||||
return "flux"
|
||||
|
||||
@staticmethod
|
||||
def query_family_infos():
|
||||
return {"flux":(30, "Flux 1")}
|
||||
|
||||
@staticmethod
|
||||
def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization):
|
||||
text_encoder_filename = get_ltxv_text_encoder_filename(text_encoder_quantization)
|
||||
return [
|
||||
{
|
||||
"repoId" : "DeepBeepMeep/Flux",
|
||||
"sourceFolderList" : [""],
|
||||
"fileList" : [ ["flux_vae.safetensors"] ]
|
||||
},
|
||||
{
|
||||
"repoId" : "DeepBeepMeep/LTX_Video",
|
||||
"sourceFolderList" : ["T5_xxl_1.1"],
|
||||
"fileList" : [ ["added_tokens.json", "special_tokens_map.json", "spiece.model", "tokenizer_config.json"] + computeList(text_encoder_filename) ]
|
||||
},
|
||||
{
|
||||
"repoId" : "DeepBeepMeep/HunyuanVideo",
|
||||
"sourceFolderList" : [ "clip_vit_large_patch14", ],
|
||||
"fileList" :[
|
||||
["config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"],
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
|
||||
from .flux_main import model_factory
|
||||
|
||||
flux_model = model_factory(
|
||||
checkpoint_dir="ckpts",
|
||||
model_filename=model_filename,
|
||||
model_type = model_type,
|
||||
model_def = model_def,
|
||||
base_model_type=base_model_type,
|
||||
text_encoder_filename= get_ltxv_text_encoder_filename(text_encoder_quantization),
|
||||
quantizeTransformer = quantizeTransformer,
|
||||
dtype = dtype,
|
||||
VAE_dtype = VAE_dtype,
|
||||
mixed_precision_transformer = mixed_precision_transformer,
|
||||
save_quantized = save_quantized
|
||||
)
|
||||
|
||||
pipe = { "transformer": flux_model.model, "vae" : flux_model.vae, "text_encoder" : flux_model.clip, "text_encoder_2" : flux_model.t5}
|
||||
|
||||
return flux_model, pipe
|
||||
|
||||
@staticmethod
|
||||
def update_default_settings(base_model_type, model_def, ui_defaults):
|
||||
ui_defaults.update({
|
||||
"embedded_guidance": 2.5,
|
||||
})
|
||||
if model_def.get("reference_image", False):
|
||||
ui_defaults.update({
|
||||
"video_prompt_type": "KI",
|
||||
})
|
||||
|
||||
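For orientation, the snippet below shows one hypothetical way the manifest returned by `query_model_files` above could be consumed with `huggingface_hub`. The `compute_list` stub, the model filename and the download loop are assumptions for illustration, not code from the repository.

```python
# Hypothetical consumer of the query_model_files() manifest shown above.
import os
from huggingface_hub import hf_hub_download

from models.flux.flux_handler import family_handler

def compute_list(filename):          # stub: just download the file itself
    return [os.path.basename(filename)]

manifest = family_handler.query_model_files(
    compute_list, "flux", "flux_model.safetensors", "int8")

for entry in manifest:
    for folder, files in zip(entry["sourceFolderList"], entry["fileList"]):
        for name in files:
            hf_hub_download(repo_id=entry["repoId"],
                            filename=f"{folder}/{name}" if folder else name,
                            local_dir="ckpts")
```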
@ -5,10 +5,10 @@ from dataclasses import dataclass
|
||||
from glob import iglob
|
||||
from mmgp import offload as offload
|
||||
import torch
|
||||
from wan.utils.utils import calculate_new_dimensions
|
||||
from flux.sampling import denoise, get_schedule, prepare_kontext, unpack
|
||||
from flux.modules.layers import get_linear_split_map
|
||||
from flux.util import (
|
||||
from shared.utils.utils import calculate_new_dimensions
|
||||
from .sampling import denoise, get_schedule, prepare_kontext, unpack
|
||||
from .modules.layers import get_linear_split_map
|
||||
from .util import (
|
||||
aspect_ratio_to_height_width,
|
||||
load_ae,
|
||||
load_clip,
|
||||
@ -146,13 +146,3 @@ class model_factory:
|
||||
x = x.transpose(0, 1)
|
||||
return x
|
||||
|
||||
def query_model_def(model_type, model_def):
|
||||
flux_model = model_def.get("flux-model", "flux-dev")
|
||||
flux_schnell = flux_model == "flux-schnell"
|
||||
model_def_output = {
|
||||
"image_outputs" : True,
|
||||
}
|
||||
if flux_schnell:
|
||||
model_def_output["no_guidance"] = True
|
||||
|
||||
return model_def_output
|
||||
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
from wan.modules.attention import pay_attention
|
||||
from shared.attention import pay_attention
|
||||
|
||||
|
||||
def attention(qkv_list, pe: Tensor) -> Tensor:
|
||||
@ -3,7 +3,7 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from flux.modules.layers import (
|
||||
from .modules.layers import (
|
||||
DoubleStreamBlock,
|
||||
EmbedND,
|
||||
LastLayer,
|
||||
@ -11,7 +11,7 @@ from flux.modules.layers import (
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
from flux.modules.lora import LinearLora, replace_linear_with_lora
|
||||
from .modules.lora import LinearLora, replace_linear_with_lora
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -7,7 +7,7 @@ from safetensors.torch import load_file as load_sft
|
||||
from torch import nn
|
||||
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
|
||||
|
||||
from flux.util import print_load_warning
|
||||
from ..util import print_load_warning
|
||||
|
||||
|
||||
class DepthImageEncoder:
|
||||
@ -5,7 +5,7 @@ import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
|
||||
from flux.math import attention, rope
|
||||
from ..math import attention, rope
|
||||
|
||||
def get_linear_split_map():
|
||||
hidden_size = 3072
|
||||
@ -343,7 +343,7 @@ def denoise(
|
||||
|
||||
updated_num_steps= len(timesteps) -1
|
||||
if callback != None:
|
||||
from wan.utils.loras_mutipliers import update_loras_slists
|
||||
from shared.utils.loras_mutipliers import update_loras_slists
|
||||
update_loras_slists(model, loras_slists, updated_num_steps)
|
||||
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
|
||||
from mmgp import offload
|
||||
@ -11,9 +11,9 @@ from huggingface_hub import hf_hub_download, login
|
||||
from PIL import ExifTags, Image
|
||||
from safetensors.torch import load_file as load_sft
|
||||
|
||||
from flux.model import Flux, FluxLoraWrapper, FluxParams
|
||||
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
|
||||
from flux.modules.conditioner import HFEmbedder
|
||||
from .model import Flux, FluxLoraWrapper, FluxParams
|
||||
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
|
||||
from .modules.conditioner import HFEmbedder
|
||||
|
||||
CHECKPOINTS_DIR = Path("checkpoints")
|
||||
|
||||
2
models/hyvideo/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .hunyuan import HunyuanVideoSampler
|
||||
from . import hunyuan_handler
|
||||
@ -949,15 +949,18 @@ class HunyuanVideoPipeline(DiffusionPipeline):
# width = width or self.transformer.config.sample_size * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks
trans = self.transformer
if trans.enable_cache == "tea":
teacache_multiplier = trans.cache_multiplier
trans.accumulated_rel_l1_distance = 0
trans.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
elif trans.enable_cache == "mag":
trans.compute_magcache_threshold(trans.cache_start_step, num_inference_steps, trans.cache_multiplier)
trans.accumulated_err, trans.accumulated_steps, trans.accumulated_ratio = 0, 0, 1.0
else:
trans.enable_cache == None
skip_steps_cache = trans.cache
if skip_steps_cache != None:
cache_type = skip_steps_cache.cache_type
if cache_type == "tea":
teacache_multiplier = skip_steps_cache.multiplier
skip_steps_cache.accumulated_rel_l1_distance = 0
skip_steps_cache.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
elif cache_type== "mag":
trans.compute_magcache_threshold(skip_steps_cache.start_step, num_inference_steps, skip_steps_cache.multiplier)
skip_steps_cache.accumulated_err, skip_steps_cache.accumulated_steps, skip_steps_cache.accumulated_ratio = 0, 0, 1.0
else:
trans.cache = None
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
@ -1212,8 +1215,8 @@ class HunyuanVideoPipeline(DiffusionPipeline):
if ip_cfg_scale>0:
latent_items += 1

if self.transformer.enable_cache:
self.transformer.previous_residual = [None] * latent_items
if skip_steps_cache != None:
skip_steps_cache.previous_residual = [None] * latent_items

# if is_progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar:
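
The two hunks above are the heart of this refactor: TeaCache/MagCache bookkeeping that previously lived as loose attributes on the transformer (enable_cache, cache_multiplier, accumulated_*) now lives on a dedicated skip-steps cache object reached through trans.cache. The actual class is defined elsewhere in the repository; the dataclass below is only a sketch of the state it appears to carry, with field names taken from the call sites in this diff.

from dataclasses import dataclass, field
from typing import Any, List, Optional

@dataclass
class SkipStepsCache:
    # Assumed shape of the cache object read/written above; the real class may
    # carry additional state (e.g. def_mag_ratios, coefficients, num_steps).
    cache_type: str = "tea"                 # "tea" (TeaCache) or "mag" (MagCache)
    multiplier: float = 1.5                 # requested speed-up factor
    start_step: int = 0                     # first step at which skipping may begin
    rel_l1_thresh: float = 0.1              # TeaCache skip threshold
    accumulated_rel_l1_distance: float = 0.0
    accumulated_err: float = 0.0            # MagCache accumulators
    accumulated_steps: int = 0
    accumulated_ratio: float = 1.0
    skipped_steps: int = 0
    should_calc: bool = True
    previous_residual: List[Optional[Any]] = field(default_factory=list)
    previous_modulated_input: Optional[Any] = None
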
@ -41,9 +41,9 @@ from diffusers.utils import (
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline

from hyvideo.constants import PRECISION_TO_TYPE
from hyvideo.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from hyvideo.text_encoder import TextEncoder
from ...constants import PRECISION_TO_TYPE
from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from ...text_encoder import TextEncoder
from einops import rearrange
from ...modules import HYVideoDiffusionTransformer

@ -934,15 +934,20 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):

transformer = self.transformer

if transformer.enable_cache == "tea":
teacache_multiplier = transformer.cache_multiplier
transformer.accumulated_rel_l1_distance = 0
transformer.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
elif transformer.enable_cache == "mag":
transformer.compute_magcache_threshold(transformer.cache_start_step, num_inference_steps, transformer.cache_multiplier)
transformer.accumulated_err, transformer.accumulated_steps, transformer.accumulated_ratio = 0, 0, 1.0
else:
transformer.enable_cache == None
skip_steps_cache = transformer.cache
cache_type = None
if skip_steps_cache != None:
cache_type = skip_steps_cache.cache_type
if cache_type == "tea":
teacache_multiplier = skip_steps_cache.multiplier
skip_steps_cache.accumulated_rel_l1_distance = 0
skip_steps_cache.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
elif cache_type == "mag":
transformer.compute_magcache_threshold(skip_steps_cache.start_step, num_inference_steps, skip_steps_cache.multiplier)
skip_steps_cache.accumulated_err, skip_steps_cache.accumulated_steps, skip_steps_cache.accumulated_ratio = 0, 0, 1.0
else:
transformer.cache = None
cache_type = None

# 1. Check inputs. Raise error if not correct
self.check_inputs(
@ -1141,16 +1146,16 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
if self._interrupt:
return [None]

if transformer.enable_cache == "tea":
if cache_type == "tea":
cache_size = round( infer_length / frames_per_batch )
transformer.previous_residual = [None] * latent_items
skip_steps_cache.previous_residual = [None] * latent_items
cache_all_previous_residual = [None] * latent_items
cache_all_previous_modulated_input = None
cache_should_calc = [True] * cache_size
cache_accumulated_rel_l1_distance = [0.] * cache_size
cache_teacache_skipped_steps = [0] * cache_size
elif transformer.enable_cache == "mag":
transformer.previous_residual = [None] * latent_items
elif cache_type == "mag":
skip_steps_cache.previous_residual = [None] * latent_items


with self.progress_bar(total=num_inference_steps) as progress_bar:
@ -1187,16 +1192,16 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
img_ref_len = (latent_model_input.shape[-1] // 2) * (latent_model_input.shape[-2] // 2) * ( 1)
img_all_len = (latents_all.shape[-1] // 2) * (latents_all.shape[-2] // 2) * latents_all.shape[-3]

if transformer.enable_cache == "tea" and cache_size > 1:
if cache_type == "tea" and cache_size > 1:
for l in range(latent_items):
if cache_all_previous_residual[l] != None:
bsz = cache_all_previous_residual[l].shape[0]
transformer.previous_residual[l][:, img_ref_len:] = cache_all_previous_residual[l].reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072)
skip_steps_cache.previous_residual[l][:, img_ref_len:] = cache_all_previous_residual[l].reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072)
if cache_all_previous_modulated_input != None:
transformer.previous_modulated_input[:, img_ref_len:] = cache_all_previous_modulated_input.reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072)
transformer.should_calc = cache_should_calc[cache_slot_no]
transformer.accumulated_rel_l1_distance = cache_accumulated_rel_l1_distance[cache_slot_no]
transformer.teacache_skipped_steps = cache_teacache_skipped_steps[cache_slot_no]
skip_steps_cache.previous_modulated_input[:, img_ref_len:] = cache_all_previous_modulated_input.reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072)
skip_steps_cache.should_calc = cache_should_calc[cache_slot_no]
skip_steps_cache.accumulated_rel_l1_distance = cache_accumulated_rel_l1_distance[cache_slot_no]
skip_steps_cache.teacache_skipped_steps = cache_teacache_skipped_steps[cache_slot_no]


if self.do_classifier_free_guidance:
@ -1304,21 +1309,21 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
pred_latents[:, :, p] += latents[:, :, iii]
counter[:, :, p] += 1

if transformer.enable_cache == "tea" and cache_size > 1:
if cache_type == "tea" and cache_size > 1:
for l in range(latent_items):
if transformer.previous_residual[l] != None:
bsz = transformer.previous_residual[l].shape[0]
if skip_steps_cache.previous_residual[l] != None:
bsz = skip_steps_cache.previous_residual[l].shape[0]
if cache_all_previous_residual[l] == None:
cache_all_previous_residual[l] = torch.zeros((bsz, img_all_len, 3072 ), device=transformer.previous_residual[l].device, dtype=transformer.previous_residual[l].dtype)
cache_all_previous_residual[l].reshape(bsz, -1, embedded_hw)[:, idx_list] = transformer.previous_residual[l][:, img_ref_len:].reshape(bsz, -1, embedded_hw)
cache_all_previous_residual[l] = torch.zeros((bsz, img_all_len, 3072 ), device=skip_steps_cache.previous_residual[l].device, dtype=skip_steps_cache.previous_residual[l].dtype)
cache_all_previous_residual[l].reshape(bsz, -1, embedded_hw)[:, idx_list] = skip_steps_cache.previous_residual[l][:, img_ref_len:].reshape(bsz, -1, embedded_hw)

if transformer.previous_modulated_input != None:
if skip_steps_cache.previous_modulated_input != None:
if cache_all_previous_modulated_input == None:
cache_all_previous_modulated_input = torch.zeros((1, img_all_len, 3072 ), device=transformer.previous_modulated_input.device, dtype=transformer.previous_modulated_input.dtype)
cache_all_previous_modulated_input.reshape(1, -1, embedded_hw)[:, idx_list] = transformer.previous_modulated_input[:, img_ref_len:].reshape(1, -1, embedded_hw)
cache_should_calc[cache_slot_no] = transformer.should_calc
cache_accumulated_rel_l1_distance[cache_slot_no] = transformer.accumulated_rel_l1_distance
cache_teacache_skipped_steps[cache_slot_no] = transformer.teacache_skipped_steps
cache_all_previous_modulated_input = torch.zeros((1, img_all_len, 3072 ), device=skip_steps_cache.previous_modulated_input.device, dtype=skip_steps_cache.previous_modulated_input.dtype)
cache_all_previous_modulated_input.reshape(1, -1, embedded_hw)[:, idx_list] = skip_steps_cache.previous_modulated_input[:, img_ref_len:].reshape(1, -1, embedded_hw)
cache_should_calc[cache_slot_no] = skip_steps_cache.should_calc
cache_accumulated_rel_l1_distance[cache_slot_no] = skip_steps_cache.accumulated_rel_l1_distance
cache_teacache_skipped_steps[cache_slot_no] = skip_steps_cache.teacache_skipped_steps

cache_slot_no += 1

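
Because this pipeline denoises a long clip in frame batches, the two hunks above restore the per-slot TeaCache state before a batch runs and snapshot it again afterwards; the refactor simply redirects that bookkeeping from transformer attributes to the skip_steps_cache object. Below is a schematic of the save/restore idea with the tensor reshuffling stripped out (illustrative only; the SlotCache name is invented here, not part of the codebase):

class SlotCache:
    # Holds one TeaCache snapshot per frame-batch slot (schematic).
    def __init__(self, num_slots):
        self.should_calc = [True] * num_slots
        self.accumulated_rel_l1_distance = [0.0] * num_slots
        self.skipped_steps = [0] * num_slots

    def restore(self, cache, slot):
        cache.should_calc = self.should_calc[slot]
        cache.accumulated_rel_l1_distance = self.accumulated_rel_l1_distance[slot]
        cache.teacache_skipped_steps = self.skipped_steps[slot]

    def save(self, cache, slot):
        self.should_calc[slot] = cache.should_calc
        self.accumulated_rel_l1_distance[slot] = cache.accumulated_rel_l1_distance
        self.skipped_steps[slot] = cache.teacache_skipped_steps
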
@ -8,24 +8,24 @@ from pathlib import Path
from einops import rearrange
import torch
import torch.distributed as dist
from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE, NEGATIVE_PROMPT_I2V
from hyvideo.vae import load_vae
from hyvideo.modules import load_model
from hyvideo.text_encoder import TextEncoder
from hyvideo.utils.data_utils import align_to, get_closest_ratio, generate_crop_size_list
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed, get_nd_rotary_pos_embed_new
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
from hyvideo.diffusion.pipelines import HunyuanVideoAudioPipeline
from .constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE, NEGATIVE_PROMPT_I2V
from .vae import load_vae
from .modules import load_model
from .text_encoder import TextEncoder
from .utils.data_utils import align_to, get_closest_ratio, generate_crop_size_list
from .modules.posemb_layers import get_nd_rotary_pos_embed, get_nd_rotary_pos_embed_new
from .diffusion.schedulers import FlowMatchDiscreteScheduler
from .diffusion.pipelines import HunyuanVideoPipeline
from .diffusion.pipelines import HunyuanVideoAudioPipeline
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import cv2
from wan.utils.utils import calculate_new_dimensions, convert_tensor_to_image
from hyvideo.data_kits.audio_preprocessor import encode_audio, get_facemask
from shared.utils.utils import calculate_new_dimensions, convert_tensor_to_image
from .data_kits.audio_preprocessor import encode_audio, get_facemask
from transformers import WhisperModel
from transformers import AutoFeatureExtractor
from hyvideo.data_kits.face_align import AlignImage
from .data_kits.face_align import AlignImage
import librosa

def get_audio_feature(feature_extractor, audio_path, duration):
@ -66,174 +66,174 @@ def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1):
|
||||
|
||||
|
||||
|
||||
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
||||
num_images, num_image_patches, embed_dim = image_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
||||
# 1. Create a mask to know where special image tokens are
|
||||
special_image_token_mask = input_ids == self.config.image_token_index
|
||||
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
||||
# Compute the maximum embed dimension
|
||||
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
|
||||
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
|
||||
# def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
||||
# num_images, num_image_patches, embed_dim = image_features.shape
|
||||
# batch_size, sequence_length = input_ids.shape
|
||||
# left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
||||
# # 1. Create a mask to know where special image tokens are
|
||||
# special_image_token_mask = input_ids == self.config.image_token_index
|
||||
# num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
||||
# # Compute the maximum embed dimension
|
||||
# max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
|
||||
# batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
|
||||
|
||||
# 2. Compute the positions where text should be written
|
||||
# Calculate new positions for text tokens in merged image-text sequence.
|
||||
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
|
||||
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
||||
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
||||
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
|
||||
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
|
||||
if left_padding:
|
||||
new_token_positions += nb_image_pad[:, None] # offset for left padding
|
||||
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
||||
# # 2. Compute the positions where text should be written
|
||||
# # Calculate new positions for text tokens in merged image-text sequence.
|
||||
# # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
|
||||
# # `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
||||
# # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
||||
# new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
|
||||
# nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
|
||||
# if left_padding:
|
||||
# new_token_positions += nb_image_pad[:, None] # offset for left padding
|
||||
# text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
||||
|
||||
# 3. Create the full embedding, already padded to the maximum position
|
||||
final_embedding = torch.zeros(
|
||||
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
||||
)
|
||||
final_attention_mask = torch.zeros(
|
||||
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
||||
)
|
||||
if labels is not None:
|
||||
final_labels = torch.full(
|
||||
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
||||
)
|
||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||
# set the corresponding tensors into their correct target device.
|
||||
target_device = inputs_embeds.device
|
||||
batch_indices, non_image_indices, text_to_overwrite = (
|
||||
batch_indices.to(target_device),
|
||||
non_image_indices.to(target_device),
|
||||
text_to_overwrite.to(target_device),
|
||||
)
|
||||
attention_mask = attention_mask.to(target_device)
|
||||
# # 3. Create the full embedding, already padded to the maximum position
|
||||
# final_embedding = torch.zeros(
|
||||
# batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
||||
# )
|
||||
# final_attention_mask = torch.zeros(
|
||||
# batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
||||
# )
|
||||
# if labels is not None:
|
||||
# final_labels = torch.full(
|
||||
# (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
||||
# )
|
||||
# # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||
# # set the corresponding tensors into their correct target device.
|
||||
# target_device = inputs_embeds.device
|
||||
# batch_indices, non_image_indices, text_to_overwrite = (
|
||||
# batch_indices.to(target_device),
|
||||
# non_image_indices.to(target_device),
|
||||
# text_to_overwrite.to(target_device),
|
||||
# )
|
||||
# attention_mask = attention_mask.to(target_device)
|
||||
|
||||
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
|
||||
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
||||
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
||||
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
||||
if labels is not None:
|
||||
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
||||
# # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
|
||||
# # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
||||
# final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
||||
# final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
||||
# if labels is not None:
|
||||
# final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
||||
|
||||
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
|
||||
image_to_overwrite = torch.full(
|
||||
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
|
||||
)
|
||||
image_to_overwrite[batch_indices, text_to_overwrite] = False
|
||||
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
||||
# # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
|
||||
# image_to_overwrite = torch.full(
|
||||
# (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
|
||||
# )
|
||||
# image_to_overwrite[batch_indices, text_to_overwrite] = False
|
||||
# image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
||||
|
||||
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
||||
raise ValueError(
|
||||
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
|
||||
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
|
||||
)
|
||||
# if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
||||
# raise ValueError(
|
||||
# f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
|
||||
# f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
|
||||
# )
|
||||
|
||||
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
||||
final_attention_mask |= image_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||
# final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
||||
# final_attention_mask |= image_to_overwrite
|
||||
# position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||
|
||||
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
||||
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
|
||||
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
||||
# # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
||||
# batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
|
||||
# indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
||||
|
||||
final_embedding[batch_indices, indices_to_mask] = 0
|
||||
# final_embedding[batch_indices, indices_to_mask] = 0
|
||||
|
||||
if labels is None:
|
||||
final_labels = None
|
||||
# if labels is None:
|
||||
# final_labels = None
|
||||
|
||||
return final_embedding, final_attention_mask, final_labels, position_ids
|
||||
# return final_embedding, final_attention_mask, final_labels, position_ids
|
||||
|
||||
def patched_llava_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[int] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
):
|
||||
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
|
||||
# def patched_llava_forward(
|
||||
# self,
|
||||
# input_ids: torch.LongTensor = None,
|
||||
# pixel_values: torch.FloatTensor = None,
|
||||
# attention_mask: Optional[torch.Tensor] = None,
|
||||
# position_ids: Optional[torch.LongTensor] = None,
|
||||
# past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
# inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
# vision_feature_layer: Optional[int] = None,
|
||||
# vision_feature_select_strategy: Optional[str] = None,
|
||||
# labels: Optional[torch.LongTensor] = None,
|
||||
# use_cache: Optional[bool] = None,
|
||||
# output_attentions: Optional[bool] = None,
|
||||
# output_hidden_states: Optional[bool] = None,
|
||||
# return_dict: Optional[bool] = None,
|
||||
# cache_position: Optional[torch.LongTensor] = None,
|
||||
# num_logits_to_keep: int = 0,
|
||||
# ):
|
||||
# from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
|
||||
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
# output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
# output_hidden_states = (
|
||||
# output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
# )
|
||||
# return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
# vision_feature_layer = (
|
||||
# vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
# )
|
||||
# vision_feature_select_strategy = (
|
||||
# vision_feature_select_strategy
|
||||
# if vision_feature_select_strategy is not None
|
||||
# else self.config.vision_feature_select_strategy
|
||||
# )
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
# if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
# raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
# if pixel_values is not None and inputs_embeds is not None:
|
||||
# raise ValueError(
|
||||
# "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
# )
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
# if inputs_embeds is None:
|
||||
# inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
image_features = None
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
)
|
||||
# image_features = None
|
||||
# if pixel_values is not None:
|
||||
# image_features = self.get_image_features(
|
||||
# pixel_values=pixel_values,
|
||||
# vision_feature_layer=vision_feature_layer,
|
||||
# vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
# )
|
||||
|
||||
|
||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
|
||||
# inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||
# image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
# )
|
||||
# cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
|
||||
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
)
|
||||
# outputs = self.language_model(
|
||||
# attention_mask=attention_mask,
|
||||
# position_ids=position_ids,
|
||||
# past_key_values=past_key_values,
|
||||
# inputs_embeds=inputs_embeds,
|
||||
# use_cache=use_cache,
|
||||
# output_attentions=output_attentions,
|
||||
# output_hidden_states=output_hidden_states,
|
||||
# return_dict=return_dict,
|
||||
# cache_position=cache_position,
|
||||
# num_logits_to_keep=num_logits_to_keep,
|
||||
# )
|
||||
|
||||
logits = outputs[0]
|
||||
# logits = outputs[0]
|
||||
|
||||
loss = None
|
||||
# loss = None
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
# if not return_dict:
|
||||
# output = (logits,) + outputs[1:]
|
||||
# return (loss,) + output if loss is not None else output
|
||||
|
||||
return LlavaCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
# return LlavaCausalLMOutputWithPast(
|
||||
# loss=loss,
|
||||
# logits=logits,
|
||||
# past_key_values=outputs.past_key_values,
|
||||
# hidden_states=outputs.hidden_states,
|
||||
# attentions=outputs.attentions,
|
||||
# image_hidden_states=image_features if pixel_values is not None else None,
|
||||
# )
|
||||
|
||||
def adapt_model(model, audio_block_name):
|
||||
modules_dict= { k: m for k, m in model.named_modules()}
|
||||
@ -320,8 +320,8 @@ class Inference(object):
device = "cuda"

import transformers
transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.forward = patched_llava_forward # force legacy behaviour to be able to use transformers v>(4.47)
transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features = _merge_input_ids_with_image_features
# transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.forward = patched_llava_forward # force legacy behaviour to be able to use transformers v>(4.47)
# transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features = _merge_input_ids_with_image_features

torch.set_grad_enabled(False)
text_len = 512
@ -778,7 +778,7 @@ class HunyuanVideoSampler(Inference):
raise ValueError(
f"Seed must be an integer, a list of integers, or None, got {seed}."
)
from wan.utils.utils import seed_everything
from shared.utils.utils import seed_everything
seed_everything(seed)
generator = [torch.Generator("cuda").manual_seed(seed) for seed in seeds]
# generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
@ -956,7 +956,7 @@ class HunyuanVideoSampler(Inference):
# out_latents= ref_latents / self.vae.config.scaling_factor
# image = self.vae.decode(out_latents, return_dict=False, generator=generator)[0]
# image = image.clamp(-1, 1)
# from wan.utils.utils import cache_video
# from shared.utils.utils import cache_video
# cache_video( tensor=image, save_file="decode.mp4", fps=25, nrow=1, normalize=True, value_range=(-1, 1))

motion_pose = np.array([25] * 4)
@ -1040,5 +1040,4 @@ class HunyuanVideoSampler(Inference):
return samples


def query_model_def(model_type, model_def):
return None

167
models/hyvideo/hunyuan_handler.py
Normal file
@ -0,0 +1,167 @@
|
||||
import torch
|
||||
|
||||
def get_hunyuan_text_encoder_filename(text_encoder_quantization):
|
||||
if text_encoder_quantization =="int8":
|
||||
text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_quanto_int8.safetensors"
|
||||
else:
|
||||
text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_fp16.safetensors"
|
||||
|
||||
return text_encoder_filename
|
||||
|
||||
class family_handler():
|
||||
|
||||
@staticmethod
|
||||
def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_steps_cache):
|
||||
resolution = inputs["resolution"]
|
||||
width, height = resolution.split("x")
|
||||
pixels = int(width) * int(height)
|
||||
|
||||
if cache_type == "mag":
|
||||
skip_steps_cache.update({
|
||||
"magcache_thresh" : 0,
|
||||
"magcache_K" : 2,
|
||||
})
|
||||
if pixels >= 1280* 720:
|
||||
skip_steps_cache.def_mag_ratios = [1.0754, 1.27807, 1.11596, 1.09504, 1.05188, 1.00844, 1.05779, 1.00657, 1.04142, 1.03101, 1.00679, 1.02556, 1.00908, 1.06949, 1.05438, 1.02214, 1.02321, 1.03019, 1.00779, 1.03381, 1.01886, 1.01161, 1.02968, 1.00544, 1.02822, 1.00689, 1.02119, 1.0105, 1.01044, 1.01572, 1.02972, 1.0094, 1.02368, 1.0226, 0.98965, 1.01588, 1.02146, 1.0018, 1.01687, 0.99436, 1.00283, 1.01139, 0.97122, 0.98251, 0.94513, 0.97656, 0.90943, 0.85703, 0.75456]
|
||||
else:
|
||||
skip_steps_cache.def_mag_ratios = [1.06971, 1.29073, 1.11245, 1.09596, 1.05233, 1.01415, 1.05672, 1.00848, 1.03632, 1.02974, 1.00984, 1.03028, 1.00681, 1.06614, 1.05022, 1.02592, 1.01776, 1.02985, 1.00726, 1.03727, 1.01502, 1.00992, 1.03371, 0.9976, 1.02742, 1.0093, 1.01869, 1.00815, 1.01461, 1.01152, 1.03082, 1.0061, 1.02162, 1.01999, 0.99063, 1.01186, 1.0217, 0.99947, 1.01711, 0.9904, 1.00258, 1.00878, 0.97039, 0.97686, 0.94315, 0.97728, 0.91154, 0.86139, 0.76592]
|
||||
else:
|
||||
skip_steps_cache.coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
|
||||
|
||||
@staticmethod
|
||||
def query_model_def(base_model_type, model_def):
|
||||
extra_model_def = {}
|
||||
|
||||
if base_model_type in ["hunyuan_avatar", "hunyuan_custom_audio"]:
|
||||
fps = 25
|
||||
elif base_model_type in ["hunyuan", "hunyuan_i2v", "hunyuan_custom_edit", "hunyuan_custom"]:
|
||||
fps = 24
|
||||
else:
|
||||
fps = 16
|
||||
extra_model_def["fps"] = fps
|
||||
extra_model_def["frames_minimum"] = 5
|
||||
extra_model_def["frames_steps"] = 4
|
||||
extra_model_def["sliding_window"] = False
|
||||
extra_model_def["embedded_guidance"] = base_model_type in ["hunyuan", "hunyuan_i2v"]
|
||||
extra_model_def["cfg_star"] = base_model_type in [ "hunyuan_avatar", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_custom"]
|
||||
extra_model_def["tea_cache"] = True
|
||||
extra_model_def["mag_cache"] = True
|
||||
return extra_model_def
|
||||
|
||||
@staticmethod
|
||||
def query_supported_types():
|
||||
return ["hunyuan", "hunyuan_i2v", "hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_avatar"]
|
||||
|
||||
@staticmethod
|
||||
def query_family_maps():
|
||||
models_eqv_map = {
|
||||
}
|
||||
|
||||
models_comp_map = {
|
||||
"hunyuan_custom": ["hunyuan_custom_edit", "hunyuan_custom_audio"],
|
||||
}
|
||||
|
||||
return models_eqv_map, models_comp_map
|
||||
|
||||
@staticmethod
|
||||
def query_model_family():
|
||||
return "hunyuan"
|
||||
|
||||
@staticmethod
|
||||
def query_family_infos():
|
||||
return {"hunyuan":(20, "Hunyuan Video")}
|
||||
|
||||
@staticmethod
|
||||
def get_rgb_factors(base_model_type ):
|
||||
from shared.RGB_factors import get_rgb_factors
|
||||
latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("hunyuan")
|
||||
return latent_rgb_factors, latent_rgb_factors_bias
|
||||
|
||||
@staticmethod
|
||||
def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization):
|
||||
text_encoder_filename = get_hunyuan_text_encoder_filename(text_encoder_quantization)
|
||||
return {
|
||||
"repoId" : "DeepBeepMeep/HunyuanVideo",
|
||||
"sourceFolderList" : [ "llava-llama-3-8b", "clip_vit_large_patch14", "whisper-tiny" , "det_align", "" ],
|
||||
"fileList" :[ ["config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "preprocessor_config.json"] + computeList(text_encoder_filename) ,
|
||||
["config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"],
|
||||
["config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json"],
|
||||
["detface.pt"],
|
||||
[ "hunyuan_video_720_quanto_int8_map.json", "hunyuan_video_custom_VAE_fp32.safetensors", "hunyuan_video_custom_VAE_config.json", "hunyuan_video_VAE_fp32.safetensors", "hunyuan_video_VAE_config.json" , "hunyuan_video_720_quanto_int8_map.json" ] + computeList(model_filename)
|
||||
]
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def load_model(model_filename, model_type = None, base_model_type = None, model_def = None, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
|
||||
from .hunyuan import HunyuanVideoSampler
|
||||
from mmgp import offload
|
||||
|
||||
hunyuan_model = HunyuanVideoSampler.from_pretrained(
|
||||
model_filepath = model_filename,
|
||||
model_type = model_type,
|
||||
base_model_type = base_model_type,
|
||||
text_encoder_filepath = get_hunyuan_text_encoder_filename(text_encoder_quantization),
|
||||
dtype = dtype,
|
||||
quantizeTransformer = quantizeTransformer,
|
||||
VAE_dtype = VAE_dtype,
|
||||
mixed_precision_transformer = mixed_precision_transformer,
|
||||
save_quantized = save_quantized
|
||||
)
|
||||
|
||||
pipe = { "transformer" : hunyuan_model.model, "text_encoder" : hunyuan_model.text_encoder, "text_encoder_2" : hunyuan_model.text_encoder_2, "vae" : hunyuan_model.vae }
|
||||
|
||||
if hunyuan_model.wav2vec != None:
|
||||
pipe["wav2vec"] = hunyuan_model.wav2vec
|
||||
|
||||
|
||||
# if hunyuan_model.align_instance != None:
|
||||
# pipe["align_instance"] = hunyuan_model.align_instance.facedet.model
|
||||
|
||||
|
||||
from .modules.models import get_linear_split_map
|
||||
|
||||
split_linear_modules_map = get_linear_split_map()
|
||||
hunyuan_model.model.split_linear_modules_map = split_linear_modules_map
|
||||
offload.split_linear_modules(hunyuan_model.model, split_linear_modules_map )
|
||||
|
||||
|
||||
return hunyuan_model, pipe
|
||||
|
||||
@staticmethod
|
||||
def update_default_settings(base_model_type, model_def, ui_defaults):
|
||||
ui_defaults["embedded_guidance_scale"]= 6.0
|
||||
|
||||
if base_model_type in ["hunyuan","hunyuan_i2v"]:
|
||||
ui_defaults.update({
|
||||
"guidance_scale": 7.0,
|
||||
})
|
||||
|
||||
elif base_model_type in ["hunyuan_custom"]:
|
||||
ui_defaults.update({
|
||||
"guidance_scale": 7.5,
|
||||
"flow_shift": 13,
|
||||
"resolution": "1280x720",
|
||||
"video_prompt_type": "I",
|
||||
})
|
||||
elif base_model_type in ["hunyuan_custom_audio"]:
|
||||
ui_defaults.update({
|
||||
"guidance_scale": 7.5,
|
||||
"flow_shift": 13,
|
||||
"video_prompt_type": "I",
|
||||
})
|
||||
elif base_model_type in ["hunyuan_custom_edit"]:
|
||||
ui_defaults.update({
|
||||
"guidance_scale": 7.5,
|
||||
"flow_shift": 13,
|
||||
"video_prompt_type": "MVAI",
|
||||
"sliding_window_size": 129,
|
||||
})
|
||||
elif base_model_type in ["hunyuan_avatar"]:
|
||||
ui_defaults.update({
|
||||
"guidance_scale": 7.5,
|
||||
"flow_shift": 5,
|
||||
"remove_background_images_ref": 0,
|
||||
"skip_steps_start_step_perc": 25,
|
||||
"video_length": 129,
|
||||
"video_prompt_type": "I",
|
||||
})
|
||||
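
The new hunyuan_handler.py above gives the Hunyuan family the same handler surface the other model families expose (query_model_files, load_model, update_default_settings, cache tuning). How the host application drives it is not part of this file; the snippet below is only an illustrative sketch of that flow, and the checkpoint path is a made-up placeholder.

import torch
from models.hyvideo.hunyuan_handler import family_handler

ui_defaults = {}
family_handler.update_default_settings("hunyuan", model_def=None, ui_defaults=ui_defaults)
print(ui_defaults)  # {'embedded_guidance_scale': 6.0, 'guidance_scale': 7.0}

# Loading requires the checkpoints listed by query_model_files() to be present locally.
model, pipe = family_handler.load_model(
    "ckpts/hunyuan_video_720_quanto_int8.safetensors",  # hypothetical local path
    model_type="hunyuan",
    base_model_type="hunyuan",
    dtype=torch.bfloat16,
)
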
@ -18,7 +18,7 @@ from .modulate_layers import ModulateDiT, modulate, modulate_ , apply_gate, appl
from .token_refiner import SingleTokenRefiner
import numpy as np
from mmgp import offload
from wan.modules.attention import pay_attention
from shared.attention import pay_attention
from .audio_adapters import AudioProjNet2, PerceiverAttentionCA

def get_linear_split_map():
@ -794,6 +794,8 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
block.disable_deterministic()

def compute_magcache_threshold(self, start_step, num_inference_steps = 0, speed_factor =0):
skips_step_cache = self.cache

def nearest_interp(src_array, target_length):
src_length = len(src_array)
if target_length == 1:
@ -801,11 +803,11 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
scale = (src_length - 1) / (target_length - 1)
mapped_indices = np.round(np.arange(target_length) * scale).astype(int)
return src_array[mapped_indices]

if len(self.def_mag_ratios) != num_inference_steps:
self.mag_ratios = nearest_interp(self.def_mag_ratios, num_inference_steps)
def_mag_ratios = np.array([1.0]+ skips_step_cache.def_mag_ratios)
if len(def_mag_ratios) != num_inference_steps:
skips_step_cache.mag_ratios = nearest_interp(def_mag_ratios, num_inference_steps)
else:
self.mag_ratios = self.def_mag_ratios
skips_step_cache.mag_ratios = def_mag_ratios

best_deltas = None
best_threshold = 0.01
@ -821,12 +823,12 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
if i<=start_step:
skip = False
else:
cur_mag_ratio = self.mag_ratios[i] # conditional and unconditional in one list
cur_mag_ratio = skips_step_cache.mag_ratios[i] # conditional and unconditional in one list
accumulated_ratio *= cur_mag_ratio # magnitude ratio between current step and the cached step
accumulated_steps += 1 # skip steps plus 1
cur_skip_err = np.abs(1-accumulated_ratio) # skip error of current steps
accumulated_err += cur_skip_err # accumulated error of multiple steps
if accumulated_err<threshold and accumulated_steps<=self.magcache_K:
if accumulated_err<threshold and accumulated_steps<=skips_step_cache.magcache_K:
skip = True
else:
skip = False
@ -842,7 +844,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
elif diff > best_diff:
break
threshold += 0.01
self.magcache_thresh = best_threshold
skips_step_cache.magcache_thresh = best_threshold
print(f"Mag Cache, best threshold found:{best_threshold:0.2f} with gain x{num_inference_steps/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
return best_threshold

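
The nearest_interp helper above resamples the calibrated per-step magnitude ratios onto however many inference steps were actually requested. A small self-contained example of that resampling (the ratio values are dummies, and the target_length == 1 branch is filled in here only as an assumption, since its body falls outside the hunk):

import numpy as np

def nearest_interp(src_array, target_length):
    # Nearest-neighbour resampling of a per-step ratio table (mirrors the diff).
    src_length = len(src_array)
    if target_length == 1:
        return np.array([src_array[-1]])  # assumption: single-step fallback
    scale = (src_length - 1) / (target_length - 1)
    mapped_indices = np.round(np.arange(target_length) * scale).astype(int)
    return src_array[mapped_indices]

ratios = np.array([1.0, 1.08, 1.05, 1.02, 0.99, 0.95])  # 6 calibrated steps (dummy values)
print(nearest_interp(ratios, 4))                         # [1.   1.05 1.02 0.95]
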
@ -969,23 +971,24 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
|
||||
attn_mask = None
|
||||
|
||||
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
|
||||
|
||||
|
||||
if self.enable_cache:
|
||||
should_calc = True
|
||||
skip_steps_cache = self.cache
|
||||
if skip_steps_cache is not None:
|
||||
cache_type = skip_steps_cache.cache_type
|
||||
if x_id == 0:
|
||||
self.should_calc = True
|
||||
if self.enable_cache == "mag":
|
||||
if step_no > self.cache_start_step:
|
||||
cur_mag_ratio = self.mag_ratios[step_no]
|
||||
self.accumulated_ratio = self.accumulated_ratio*cur_mag_ratio
|
||||
cur_skip_err = np.abs(1-self.accumulated_ratio)
|
||||
self.accumulated_err += cur_skip_err
|
||||
self.accumulated_steps += 1
|
||||
if self.accumulated_err<=self.magcache_thresh and self.accumulated_steps<=self.magcache_K:
|
||||
self.should_calc = False
|
||||
self.cache_skipped_steps += 1
|
||||
skip_steps_cache.should_calc = True
|
||||
if cache_type == "mag":
|
||||
if step_no > skip_steps_cache.start_step:
|
||||
cur_mag_ratio = skip_steps_cache.mag_ratios[step_no]
|
||||
skip_steps_cache.accumulated_ratio = skip_steps_cache.accumulated_ratio*cur_mag_ratio
|
||||
cur_skip_err = np.abs(1-skip_steps_cache.accumulated_ratio)
|
||||
skip_steps_cache.accumulated_err += cur_skip_err
|
||||
skip_steps_cache.accumulated_steps += 1
|
||||
if skip_steps_cache.accumulated_err<=skip_steps_cache.magcache_thresh and skip_steps_cache.accumulated_steps<=skip_steps_cache.magcache_K:
|
||||
skip_steps_cache.should_calc = False
|
||||
skip_steps_cache.skipped_steps += 1
|
||||
else:
|
||||
self.accumulated_ratio, self.accumulated_steps, self.accumulated_err = 1.0, 0, 0
|
||||
skip_steps_cache.accumulated_ratio, skip_steps_cache.accumulated_steps, skip_steps_cache.accumulated_err = 1.0, 0, 0
|
||||
else:
|
||||
inp = img[0:1]
|
||||
vec_ = vec[0:1]
|
||||
@ -994,26 +997,24 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
|
||||
normed_inp = normed_inp.to(torch.bfloat16)
|
||||
modulated_inp = modulate( normed_inp, shift=img_mod1_shift, scale=img_mod1_scale )
|
||||
del normed_inp, img_mod1_shift, img_mod1_scale
|
||||
if step_no <= self.cache_start_step or step_no == self.num_steps-1:
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
if step_no <= skip_steps_cache.start_step or step_no == skip_steps_cache.num_steps-1:
|
||||
skip_steps_cache.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
|
||||
rescale_func = np.poly1d(coefficients)
|
||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
||||
self.should_calc = False
|
||||
self.cache_skipped_steps += 1
|
||||
rescale_func = np.poly1d(skip_steps_cache.coefficients)
|
||||
skip_steps_cache.accumulated_rel_l1_distance += rescale_func(((modulated_inp-skip_steps_cache.previous_modulated_input).abs().mean() / skip_steps_cache.previous_modulated_input.abs().mean()).cpu().item())
|
||||
if skip_steps_cache.accumulated_rel_l1_distance < skip_steps_cache.rel_l1_thresh:
|
||||
skip_steps_cache.should_calc = False
|
||||
skip_steps_cache.skipped_steps += 1
|
||||
else:
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = modulated_inp
|
||||
else:
|
||||
self.should_calc = True
|
||||
skip_steps_cache.accumulated_rel_l1_distance = 0
|
||||
skip_steps_cache.previous_modulated_input = modulated_inp
|
||||
should_calc = skip_steps_cache.should_calc
|
||||
|
||||
if not self.should_calc:
|
||||
img += self.previous_residual[x_id]
|
||||
if not should_calc:
|
||||
img += skip_steps_cache.previous_residual[x_id]
|
||||
else:
|
||||
if self.enable_cache:
|
||||
self.previous_residual[x_id] = None
|
||||
if skip_steps_cache is not None:
|
||||
skip_steps_cache.previous_residual[x_id] = None
|
||||
ori_img = img[0:1].clone()
|
||||
# --------------------- Pass through DiT blocks ------------------------
|
||||
for layer_num, block in enumerate(self.double_blocks):
|
||||
@ -1076,10 +1077,10 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
|
||||
single_block_args = None
|
||||
|
||||
# img = x[:, :img_seq_len, ...]
|
||||
if self.enable_cache:
|
||||
if skip_steps_cache is not None:
|
||||
if len(img) > 1:
|
||||
self.previous_residual[0] = torch.empty_like(img)
|
||||
for i, (x, residual) in enumerate(zip(img, self.previous_residual[0])):
|
||||
skip_steps_cache.previous_residual[0] = torch.empty_like(img)
|
||||
for i, (x, residual) in enumerate(zip(img, skip_steps_cache.previous_residual[0])):
|
||||
if i < len(img) - 1:
|
||||
residual[...] = torch.sub(x, ori_img)
|
||||
else:
|
||||
@ -1087,8 +1088,8 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
|
||||
torch.sub(x, ori_img, out=residual)
|
||||
x = None
|
||||
else:
|
||||
self.previous_residual[x_id] = ori_img
|
||||
torch.sub(img, ori_img, out=self.previous_residual[x_id])
|
||||
skip_steps_cache.previous_residual[x_id] = ori_img
|
||||
torch.sub(img, ori_img, out=skip_steps_cache.previous_residual[x_id])
|
||||
|
||||
|
||||
if ref_length != None:
|
||||
@ -15,6 +15,7 @@ from transformers.utils import ModelOutput

from ..constants import TEXT_ENCODER_PATH, TOKENIZER_PATH
from ..constants import PRECISION_TO_TYPE
from .llava.modeling_llava import LlavaForConditionalGeneration


def use_default(value, default):
@ -188,11 +189,17 @@ class TextEncoder(nn.Module):

if "llm" in text_encoder_type:
from mmgp import offload
forcedConfigPath= None if "i2v" in text_encoder_type else "ckpts/llava-llama-3-8b/config.json"
self.model= offload.fast_load_transformers_model(self.model_path, modelPrefix="language_model" if forcedConfigPath != None else None, forcedConfigPath=forcedConfigPath)
if forcedConfigPath != None:
# forcedConfigPath= None if "i2v" in text_encoder_type else "ckpts/llava-llama-3-8b/config.json"
# self.model= offload.fast_load_transformers_model(self.model_path, modelPrefix="language_model" if forcedConfigPath != None else None, forcedConfigPath=forcedConfigPath)

if "i2v" in text_encoder_type:
self.model= offload.fast_load_transformers_model(self.model_path, modelClass= LlavaForConditionalGeneration)
else:
self.model= offload.fast_load_transformers_model(self.model_path, modelPrefix="language_model", forcedConfigPath = "ckpts/llava-llama-3-8b/config.json")
self.model.final_layer_norm = self.model.model.norm



else:
self.model, self.model_path = load_text_encoder(
text_encoder_type=self.text_encoder_type,
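
The rewritten branch above loads the full vendored LLaVA model for i2v text encoding, and only the LLaMA language tower (under the "language_model" prefix, with the local config forced) for the text-only case. A condensed restatement of that decision, reusing exactly the arguments that appear in the hunk; treat it as illustrative, since error handling and the surrounding TextEncoder state are omitted, and in the diff the final_layer_norm alias is attached after the non-i2v load:

from mmgp import offload

def load_llm_text_encoder(model_path, text_encoder_type):
    if "i2v" in text_encoder_type:
        # i2v needs the whole vision-language model.
        from models.hyvideo.text_encoder.llava.modeling_llava import LlavaForConditionalGeneration
        return offload.fast_load_transformers_model(model_path, modelClass=LlavaForConditionalGeneration)
    # t2v only needs the LLaMA tower, loaded under its prefix with the vendored config.
    model = offload.fast_load_transformers_model(
        model_path,
        modelPrefix="language_model",
        forcedConfigPath="ckpts/llava-llama-3-8b/config.json",
    )
    model.final_layer_norm = model.model.norm
    return model
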
29
models/hyvideo/text_encoder/llava/__init__.py
Normal file
@ -0,0 +1,29 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from typing import TYPE_CHECKING

# from ...utils import _LazyModule
# from ...utils.import_utils import define_import_structure


# if TYPE_CHECKING:
# from .configuration_llava import *
# from .image_processing_llava_fast import *
# from .modeling_llava import *
# from .processing_llava import *
# else:
# import sys

# _file = globals()["__file__"]
# sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
137
models/hyvideo/text_encoder/llava/configuration_llava.py
Normal file
@ -0,0 +1,137 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Llava model configuration"""
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
from transformers.models.auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LlavaConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate an
|
||||
Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the Llava-9B.
|
||||
|
||||
e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`):
|
||||
The config object or dictionary of the vision backbone.
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
|
||||
The config object or dictionary of the text backbone.
|
||||
image_token_index (`int`, *optional*, defaults to 32000):
|
||||
The image token index to encode the image prompt.
|
||||
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||
The activation function used by the multimodal projector.
|
||||
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`.
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -2):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
image_seq_length (`int`, *optional*, defaults to 576):
|
||||
Sequence length of one image embedding.
|
||||
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use bias in the multimodal projector.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig
|
||||
|
||||
>>> # Initializing a CLIP-vision config
|
||||
>>> vision_config = CLIPVisionConfig()
|
||||
|
||||
>>> # Initializing a Llama config
|
||||
>>> text_config = LlamaConfig()
|
||||
|
||||
>>> # Initializing a Llava llava-1.5-7b style configuration
|
||||
>>> configuration = LlavaConfig(vision_config, text_config)
|
||||
|
||||
>>> # Initializing a model from the llava-1.5-7b style configuration
|
||||
>>> model = LlavaForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "llava"
|
||||
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
|
||||
is_composition = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
image_token_index=32000,
|
||||
projector_hidden_act="gelu",
|
||||
vision_feature_select_strategy="default",
|
||||
vision_feature_layer=-2,
|
||||
image_seq_length=576,
|
||||
multimodal_projector_bias=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.image_token_index = image_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.image_seq_length = image_seq_length
|
||||
|
||||
if vision_feature_select_strategy not in ["default", "full"]:
|
||||
raise ValueError(
|
||||
"vision_feature_select_strategy should be one of 'default', 'full'."
|
||||
f"Got: {vision_feature_select_strategy}"
|
||||
)
|
||||
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
self.vision_feature_layer = vision_feature_layer
|
||||
|
||||
if isinstance(vision_config, dict):
|
||||
vision_config["model_type"] = (
|
||||
vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
|
||||
)
|
||||
vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
|
||||
elif vision_config is None:
|
||||
vision_config = CONFIG_MAPPING["clip_vision_model"](
|
||||
intermediate_size=4096,
|
||||
hidden_size=1024,
|
||||
patch_size=14,
|
||||
image_size=336,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
vocab_size=32000,
|
||||
projection_dim=768,
|
||||
)
|
||||
|
||||
self.vision_config = vision_config
|
||||
|
||||
if isinstance(text_config, dict):
|
||||
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
|
||||
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
||||
elif text_config is None:
|
||||
text_config = CONFIG_MAPPING["llama"]()
|
||||
|
||||
self.text_config = text_config
|
||||
self.multimodal_projector_bias = multimodal_projector_bias
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
__all__ = ["LlavaConfig"]
|
||||
436
models/hyvideo/text_encoder/llava/image_processing_llava.py
Normal file
@ -0,0 +1,436 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Image processor class for LLaVa."""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import (
|
||||
convert_to_rgb,
|
||||
get_resize_output_image_size,
|
||||
resize,
|
||||
to_channel_dimension_format,
|
||||
)
|
||||
from ...image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_kwargs,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, is_vision_available, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
|
||||
|
||||
class LlavaImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a LLaVa image processor.
|
||||
|
||||
Args:
|
||||
do_pad (`bool`, *optional*, defaults to `False`):
|
||||
Whether to pad the image to a square based on the longest edge.
|
||||
The padding value is determined by the `image_mean` parameter.
|
||||
Can be overridden by `do_pad` in the `preprocess` method.
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
||||
`do_resize` in the `preprocess` method.
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
|
||||
Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
|
||||
the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
|
||||
method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||
do_center_crop (`bool`, *optional*, defaults to `True`):
|
||||
Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
|
||||
`preprocess` method.
|
||||
crop_size (`Dict[str, int]` *optional*, defaults to 224):
|
||||
Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
|
||||
method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||
the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
||||
method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_pad: bool = False,
|
||||
do_resize: bool = True,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
do_center_crop: bool = True,
|
||||
crop_size: Dict[str, int] = None,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
size = size if size is not None else {"shortest_edge": 224}
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
|
||||
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
|
||||
|
||||
self.do_pad = do_pad
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.resample = resample
|
||||
self.do_center_crop = do_center_crop
|
||||
self.crop_size = crop_size
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
||||
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self._valid_processor_keys = [
|
||||
"images",
|
||||
"do_pad",
|
||||
"do_resize",
|
||||
"size",
|
||||
"resample",
|
||||
"do_center_crop",
|
||||
"crop_size",
|
||||
"do_rescale",
|
||||
"rescale_factor",
|
||||
"do_normalize",
|
||||
"image_mean",
|
||||
"image_std",
|
||||
"do_convert_rgb",
|
||||
"return_tensors",
|
||||
"data_format",
|
||||
"input_data_format",
|
||||
]
|
||||
|
||||
def pad_to_square(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
background_color: Union[int, Tuple[int, int, int]] = 0,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.array:
|
||||
"""
|
||||
Pads an image to a square based on the longest edge.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
The image to pad.
|
||||
background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
|
||||
The color to use for the padding. Can be an integer for single channel or a
|
||||
tuple of integers for multi-channel images. If passed as an integer
|
||||
in multi-channel mode, it will default to `0` in subsequent channels.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
If unset, will use the same format as the input image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
If unset, will use the inferred format of the input image.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The padded image.
|
||||
"""
|
||||
height, width = get_image_size(image, input_data_format)
|
||||
num_channels = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1]
|
||||
|
||||
if height == width:
|
||||
image = (
|
||||
to_channel_dimension_format(image, data_format, input_data_format)
|
||||
if data_format is not None
|
||||
else image
|
||||
)
|
||||
return image
|
||||
|
||||
max_dim = max(height, width)
|
||||
|
||||
# Ensure background_color is the correct shape
|
||||
if isinstance(background_color, int):
|
||||
background_color = [background_color]
|
||||
elif len(background_color) != num_channels:
|
||||
raise ValueError(
f"background_color must have exactly {num_channels} elements to match the number of channels"
)
|
||||
|
||||
if input_data_format == ChannelDimension.FIRST:
|
||||
result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype)
|
||||
for i, color in enumerate(background_color):
|
||||
result[i, :, :] = color
|
||||
if width > height:
|
||||
start = (max_dim - height) // 2
|
||||
result[:, start : start + height, :] = image
|
||||
else:
|
||||
start = (max_dim - width) // 2
|
||||
result[:, :, start : start + width] = image
|
||||
else:
|
||||
result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype)
|
||||
for i, color in enumerate(background_color):
|
||||
result[:, :, i] = color
|
||||
if width > height:
|
||||
start = (max_dim - height) // 2
|
||||
result[start : start + height, :, :] = image
|
||||
else:
|
||||
start = (max_dim - width) // 2
|
||||
result[:, start : start + width, :] = image
|
||||
|
||||
image = (
|
||||
to_channel_dimension_format(result, data_format, input_data_format) if data_format is not None else result
|
||||
)
|
||||
return image
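# --- Illustrative sketch, not part of the diff ---------------------------------------
# What pad_to_square() above does for a channels-last landscape image: a 4x6 image
# becomes 6x6, the original rows centred vertically and the new rows filled with the
# per-channel background colour (preprocess() passes image_mean * 255, roughly
# (122, 116, 104) for OPENAI_CLIP_MEAN).
import numpy as np

img = np.full((4, 6, 3), 200, dtype=np.uint8)              # H=4, W=6, channels_last
max_dim = max(img.shape[:2])                                # 6
canvas = np.zeros((max_dim, max_dim, 3), dtype=img.dtype)
canvas[...] = (122, 116, 104)
start = (max_dim - img.shape[0]) // 2                       # 1
canvas[start:start + img.shape[0], :, :] = img
assert canvas.shape == (6, 6, 3)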
|
||||
|
||||
# Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize
|
||||
def resize(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
|
||||
resized to keep the input aspect ratio.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
Size of the output image.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
Resampling filter to use when resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
default_to_square = True
|
||||
if "shortest_edge" in size:
|
||||
size = size["shortest_edge"]
|
||||
default_to_square = False
|
||||
elif "height" in size and "width" in size:
|
||||
size = (size["height"], size["width"])
|
||||
else:
|
||||
raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
|
||||
|
||||
output_size = get_resize_output_image_size(
|
||||
image,
|
||||
size=size,
|
||||
default_to_square=default_to_square,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
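# --- Illustrative sketch, not part of the diff ---------------------------------------
# The shortest-edge rule implemented by resize() above for size={"shortest_edge": 224}:
# the shorter side maps to 224 and the longer side scales to keep the aspect ratio
# (exact rounding is delegated to get_resize_output_image_size in the real code).
def shortest_edge_size(height: int, width: int, shortest_edge: int = 224) -> tuple:
    scale = shortest_edge / min(height, width)
    return int(height * scale), int(width * scale)

assert shortest_edge_size(480, 640) == (224, 298)           # 640 * 224 / 480 -> 298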
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_pad: Optional[bool] = None,
|
||||
do_resize: Optional[bool] = None,
|
||||
size: Optional[Dict[str, int]] = None,
|
||||
resample: Optional[PILImageResampling] = None,
|
||||
do_center_crop: Optional[bool] = None,
|
||||
crop_size: Optional[int] = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: Optional[bool] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Preprocess an image or batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
|
||||
Whether to pad the image to a square based on the longest edge.
|
||||
The padding value is determined by the `image_mean` parameter.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
|
||||
the longest edge resized to keep the input aspect ratio.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
|
||||
Whether to center crop the image.
|
||||
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
|
||||
Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_pad = do_pad if do_pad is not None else self.do_pad
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, param_name="size", default_to_square=False)
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
|
||||
crop_size = crop_size if crop_size is not None else self.crop_size
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
|
||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
|
||||
|
||||
images = make_list_of_images(images)
|
||||
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
# we don't pass `do_pad` here since LLaVa uses a custom padding to a square
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_center_crop=do_center_crop,
|
||||
crop_size=crop_size,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
)
|
||||
|
||||
if do_convert_rgb:
|
||||
images = [convert_to_rgb(image) for image in images]
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if is_scaled_image(images[0]) and do_rescale:
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
processed_images = []
|
||||
for image in images:
|
||||
if do_pad:
|
||||
image = self.pad_to_square(
|
||||
image=image,
|
||||
background_color=tuple(int(x * 255) for x in self.image_mean),
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
|
||||
if do_resize:
|
||||
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
|
||||
if do_center_crop:
|
||||
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
|
||||
|
||||
if do_rescale:
|
||||
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
|
||||
if do_normalize:
|
||||
image = self.normalize(
|
||||
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
|
||||
)
|
||||
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
processed_images.append(image)
|
||||
|
||||
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
|
||||
|
||||
|
||||
__all__ = ["LlavaImageProcessor"]
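# --- Illustrative sketch, not part of the diff ---------------------------------------
# The preprocess() pipeline above runs pad -> resize -> center_crop -> rescale ->
# normalize, then moves channels first. The last three steps by hand, using the
# OPENAI_CLIP_MEAN/STD values quoted in the class docstring:
import numpy as np

CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073])
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711])

img = np.random.randint(0, 256, (224, 224, 3)).astype(np.float32)  # as if pad/resize/crop already ran
img = img * (1 / 255)                      # do_rescale, rescale_factor = 1/255
img = (img - CLIP_MEAN) / CLIP_STD         # do_normalize
img = img.transpose(2, 0, 1)               # ChannelDimension.FIRST -> (C, H, W)
assert img.shape == (3, 224, 224)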
|
||||
201
models/hyvideo/text_encoder/llava/image_processing_llava_fast.py
Normal file
@ -0,0 +1,201 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fast Image processor class for LLaVa."""
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from ...image_processing_utils import (
|
||||
BatchFeature,
|
||||
)
|
||||
from ...image_processing_utils_fast import (
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
||||
BaseImageProcessorFast,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from ...image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
get_image_size,
|
||||
)
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
add_start_docstrings,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
is_vision_available,
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from ...image_utils import PILImageResampling
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchvision_available():
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
else:
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
class LlavaFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||
do_pad: Optional[bool]
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Constructs a fast Llava image processor.",
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||
"""
|
||||
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
|
||||
Whether to pad the image to a square based on the longest edge. Can be overridden by the `do_pad` parameter in the `preprocess` method.
|
||||
""",
|
||||
)
|
||||
class LlavaImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = OPENAI_CLIP_MEAN
|
||||
image_std = OPENAI_CLIP_STD
|
||||
size = {"shortest_edge": 224}
|
||||
default_to_square = False
|
||||
crop_size = {"height": 224, "width": 224}
|
||||
do_pad = False
|
||||
do_resize = True
|
||||
do_center_crop = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
valid_kwargs = LlavaFastImageProcessorKwargs
|
||||
|
||||
def __init__(self, **kwargs: Unpack[LlavaFastImageProcessorKwargs]) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@add_start_docstrings(
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
||||
"""
|
||||
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
|
||||
Whether to pad the image to a square based on the longest edge. Can be overridden by the `do_pad` parameter in the `preprocess` method.
|
||||
""",
|
||||
)
|
||||
def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaFastImageProcessorKwargs]) -> BatchFeature:
|
||||
return super().preprocess(images, **kwargs)
|
||||
|
||||
def pad_to_square(
|
||||
self,
|
||||
images: "torch.Tensor",
|
||||
background_color: Union[int, Tuple[int, int, int]] = 0,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Pads an image to a square based on the longest edge.
|
||||
|
||||
Args:
|
||||
images (`torch.Tensor`):
|
||||
The images to pad.
|
||||
background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
|
||||
The color to use for the padding. Can be an integer for single channel or a
|
||||
tuple of integers for multi-channel images. If passed as an integer
|
||||
in multi-channel mode, it will default to `0` in subsequent channels.
|
||||
Returns:
|
||||
`torch.Tensor`: The padded images.
|
||||
"""
|
||||
height, width = get_image_size(images, ChannelDimension.FIRST)
|
||||
|
||||
if height == width:
|
||||
return images
|
||||
|
||||
num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0]
|
||||
if isinstance(background_color, int):
|
||||
background_color = [background_color] + [0] * (num_channels - 1)
|
||||
elif len(background_color) != num_channels:
|
||||
raise ValueError(
f"background_color must have exactly {num_channels} elements to match the number of channels"
)
|
||||
|
||||
max_dim = max(height, width)
|
||||
paste_x_left = (max_dim - width) // 2
|
||||
paste_y_left = (max_dim - height) // 2
|
||||
paste_x_right = max_dim - width - paste_x_left
|
||||
paste_y_right = max_dim - height - paste_y_left
|
||||
padded_images = F.pad(
|
||||
images, padding=[paste_x_left, paste_y_left, paste_x_right, paste_y_right], fill=background_color
|
||||
)
|
||||
|
||||
return padded_images
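# --- Illustrative sketch, not part of the diff ---------------------------------------
# The padding amounts computed above for a 480x640 (HxW) batch. torchvision's F.pad
# takes [left, top, right, bottom], which matches the
# [paste_x_left, paste_y_left, paste_x_right, paste_y_right] order used here.
height, width = 480, 640
max_dim = max(height, width)                # 640
left = (max_dim - width) // 2               # 0
top = (max_dim - height) // 2               # 80
right = max_dim - width - left              # 0
bottom = max_dim - height - top             # 80
assert (left, top, right, bottom) == (0, 80, 0, 80)   # 480x640 -> 640x640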
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: List["torch.Tensor"],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_pad: bool,
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: Optional[Union[float, List[float]]],
|
||||
image_std: Optional[Union[float, List[float]]],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
) -> BatchFeature:
|
||||
# Group images by size for batched resizing
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images)
|
||||
resized_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_pad:
|
||||
stacked_images = self.pad_to_square(
|
||||
images=stacked_images, background_color=tuple(int(x * 255) for x in self.image_mean)
|
||||
)
|
||||
resized_images_grouped[shape] = stacked_images
|
||||
padded_images = reorder_images(resized_images_grouped, grouped_images_index)
|
||||
|
||||
# Group images by size for batched resizing
|
||||
# Needed in case do_pad is False, or padding returns images with different sizes
|
||||
grouped_images, grouped_images_index = group_images_by_shape(padded_images)
|
||||
resized_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_resize:
|
||||
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
|
||||
resized_images_grouped[shape] = stacked_images
|
||||
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
|
||||
|
||||
# Group images by size for further processing
|
||||
# Needed in case do_resize is False, or resize returns images with different sizes
|
||||
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
|
||||
processed_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_center_crop:
|
||||
stacked_images = self.center_crop(stacked_images, crop_size)
|
||||
# Fused rescale and normalize
|
||||
stacked_images = self.rescale_and_normalize(
|
||||
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||
)
|
||||
processed_images_grouped[shape] = stacked_images
|
||||
|
||||
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
|
||||
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
|
||||
|
||||
|
||||
__all__ = ["LlavaImageProcessorFast"]
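# --- Illustrative sketch, not part of the diff ---------------------------------------
# The group-by-shape pattern _preprocess() above relies on: same-shaped images are
# stacked so resize/crop/normalize run once per shape instead of once per image, and
# an index map restores the original order afterwards. Simplified re-implementation,
# not the real group_images_by_shape/reorder_images helpers.
import torch

def group_by_shape(images):
    grouped, index = {}, []
    for img in images:
        shape = tuple(img.shape)
        grouped.setdefault(shape, []).append(img)
        index.append((shape, len(grouped[shape]) - 1))
    return {s: torch.stack(v) for s, v in grouped.items()}, index

imgs = [torch.zeros(3, 224, 224), torch.zeros(3, 336, 336), torch.zeros(3, 224, 224)]
grouped, index = group_by_shape(imgs)
assert grouped[(3, 224, 224)].shape == (2, 3, 224, 224) and len(index) == 3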
|
||||
531
models/hyvideo/text_encoder/llava/modeling_llava.py
Normal file
@ -0,0 +1,531 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Llava model."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.generation import GenerationMixin
|
||||
from transformers.modeling_outputs import ModelOutput
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
||||
from .configuration_llava import LlavaConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "LlavaConfig"
|
||||
|
||||
# Base docstring
|
||||
_CHECKPOINT_FOR_DOC = "llava-hf/llava-1.5-7b-hf"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LlavaCausalLMOutputWithPast(ModelOutput):
|
||||
"""
|
||||
Base class for Llava causal language model (or autoregressive) outputs.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
image_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
class LlavaMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: LlavaConfig):
|
||||
super().__init__()
|
||||
# We have hidden_size * the number of vision feature layers
|
||||
num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
|
||||
self.linear_1 = nn.Linear(
|
||||
config.vision_config.hidden_size * num_feature_layers,
|
||||
config.text_config.hidden_size,
|
||||
bias=config.multimodal_projector_bias,
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = nn.Linear(
|
||||
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear_1(image_features)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
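# --- Illustrative sketch, not part of the diff ---------------------------------------
# Shapes through LlavaMultiModalProjector above, assuming a single vision feature layer
# (num_feature_layers == 1), the CLIP-L width 1024 from the config defaults, an
# illustrative Llama hidden size of 4096, and projector_hidden_act="gelu".
import torch
import torch.nn.functional as F
from torch import nn

vision_hidden, text_hidden, num_tokens = 1024, 4096, 576
linear_1 = nn.Linear(vision_hidden, text_hidden)
linear_2 = nn.Linear(text_hidden, text_hidden)

image_features = torch.randn(2, num_tokens, vision_hidden)   # (num_images, tokens, vision_hidden)
out = linear_2(F.gelu(linear_1(image_features)))
assert out.shape == (2, num_tokens, text_hidden)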
|
||||
|
||||
|
||||
LLAVA_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`LlavaConfig`] or [`LlavaVisionConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
||||
LLAVA_START_DOCSTRING,
|
||||
)
|
||||
class LlavaPreTrainedModel(PreTrainedModel):
|
||||
config_class = LlavaConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["LlavaVisionAttention"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of Llava isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
|
||||
# https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.text_config.initializer_range
|
||||
)
|
||||
|
||||
if hasattr(module, "class_embedding"):
|
||||
module.class_embedding.data.normal_(mean=0.0, std=std)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
LLAVA_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||
The tensors corresponding to the input images. Pixel values can be obtained using
|
||||
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses
|
||||
[`CLIPImageProcessor`] for processing images).
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
||||
`past_key_values`).
|
||||
|
||||
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
||||
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
||||
information on the default strategy.
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -2):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The LLAVA model which consists of a vision backbone and a language model.""",
|
||||
LLAVA_START_DOCSTRING,
|
||||
)
|
||||
class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
|
||||
def __init__(self, config: LlavaConfig):
|
||||
super().__init__(config)
|
||||
self.vision_tower = AutoModel.from_config(config.vision_config)
|
||||
|
||||
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
|
||||
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
|
||||
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.language_model.get_output_embeddings()
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.language_model.set_output_embeddings(new_embeddings)
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.language_model.set_decoder(decoder)
|
||||
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
|
||||
The tensors corresponding to the input images.
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
|
||||
"""
|
||||
if vision_feature_select_strategy not in ["default", "full"]:
|
||||
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
|
||||
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
# this is not memory efficient at all: output_hidden_states=True will save all the hidden states.
|
||||
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)
|
||||
|
||||
# If we have one vision feature layer, return the corresponding hidden states,
|
||||
# otherwise, select the hidden states of each feature layer and concatenate them
|
||||
if isinstance(vision_feature_layer, int):
|
||||
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
||||
if vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
else:
|
||||
hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
|
||||
# For default; crop CLS from each hidden state in the hidden state pool
|
||||
if vision_feature_select_strategy == "default":
|
||||
hs_pool = [hs[:, 1:] for hs in hs_pool]
|
||||
selected_image_feature = torch.cat(hs_pool, dim=-1)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
return image_features
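# --- Illustrative sketch, not part of the diff ---------------------------------------
# What the two selection strategies in get_image_features() above do to a CLIP hidden
# state of shape (num_images, 1 + num_patches, hidden): "default" drops the CLS token,
# "full" keeps it.
import torch

hidden = torch.randn(2, 577, 1024)     # 1 CLS + 24 * 24 patches, CLIP-L width
default = hidden[:, 1:]                # vision_feature_select_strategy == "default"
full = hidden                          # vision_feature_select_strategy == "full"
assert default.shape == (2, 576, 1024) and full.shape == (2, 577, 1024)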
|
||||
|
||||
|
||||
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
||||
num_images, num_image_patches, embed_dim = image_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
||||
# 1. Create a mask to know where special image tokens are
|
||||
special_image_token_mask = input_ids == self.config.image_token_index
|
||||
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
||||
# Compute the maximum embed dimension
|
||||
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
|
||||
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
|
||||
|
||||
# 2. Compute the positions where text should be written
|
||||
# Calculate new positions for text tokens in merged image-text sequence.
|
||||
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
|
||||
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
||||
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
||||
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
|
||||
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
|
||||
if left_padding:
|
||||
new_token_positions += nb_image_pad[:, None] # offset for left padding
|
||||
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
||||
|
||||
# 3. Create the full embedding, already padded to the maximum position
|
||||
final_embedding = torch.zeros(
|
||||
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
||||
)
|
||||
final_attention_mask = torch.zeros(
|
||||
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
||||
)
|
||||
if labels is not None:
|
||||
final_labels = torch.full(
|
||||
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
||||
)
|
||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||
# set the corresponding tensors into their correct target device.
|
||||
target_device = inputs_embeds.device
|
||||
batch_indices, non_image_indices, text_to_overwrite = (
|
||||
batch_indices.to(target_device),
|
||||
non_image_indices.to(target_device),
|
||||
text_to_overwrite.to(target_device),
|
||||
)
|
||||
attention_mask = attention_mask.to(target_device)
|
||||
|
||||
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
|
||||
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
||||
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
||||
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
||||
if labels is not None:
|
||||
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
||||
|
||||
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
|
||||
image_to_overwrite = torch.full(
|
||||
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
|
||||
)
|
||||
image_to_overwrite[batch_indices, text_to_overwrite] = False
|
||||
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
||||
|
||||
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
||||
raise ValueError(
f"The inputs provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of images given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
)
|
||||
|
||||
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
||||
final_attention_mask |= image_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||
|
||||
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
||||
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
|
||||
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
||||
|
||||
final_embedding[batch_indices, indices_to_mask] = 0
|
||||
|
||||
if labels is None:
|
||||
final_labels = None
|
||||
|
||||
return final_embedding, final_attention_mask, final_labels, position_ids
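# --- Illustrative sketch, not part of the diff ---------------------------------------
# The cumsum bookkeeping from step 2 above for one sequence [text, <image>, text] with
# num_image_patches = 4 (token ids are illustrative). The image token expands into 4
# positions, so the text token after it shifts right by 3.
import torch

image_token_id, num_image_patches = 32000, 4
input_ids = torch.tensor([[101, image_token_id, 102]])
special = input_ids == image_token_id                                       # [False, True, False]
new_positions = torch.cumsum(special * (num_image_patches - 1) + 1, -1) - 1
assert new_positions.tolist() == [[0, 4, 5]]     # text at 0 and 5, image features fill 1..4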
|
||||
|
||||
# @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
# @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
|
||||
# @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[int] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
):
|
||||
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
|
||||
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
image_features = None
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
)
|
||||
|
||||
|
||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
|
||||
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
|
||||
loss = None
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return LlavaCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
cache_position=None,
|
||||
logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
|
||||
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if cache_position[0] == 0:
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
|
||||
return model_inputs
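# --- Illustrative sketch, not part of the diff ---------------------------------------
# The pixel_values gating in prepare_inputs_for_generation() above: during prefill the
# first cache position is 0 and images are forwarded; during cached decoding it points
# past the prompt, so pixel_values is left out of model_inputs.
import torch

def wants_pixel_values(cache_position: torch.Tensor) -> bool:
    return bool(cache_position[0] == 0)

assert wants_pixel_values(torch.arange(0, 12)) is True     # prefill over a 12-token prompt
assert wants_pixel_values(torch.arange(12, 13)) is False    # single-token decode step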
|
||||
|
||||
|
||||
__all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel"]
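# --- Illustrative sketch, not part of the diff ---------------------------------------
# End-to-end usage with the upstream transformers classes for the checkpoint named in
# the docstrings above ("llava-hf/llava-1.5-7b-hf"). This is the stock Hugging Face
# API, shown only for orientation; the files in this diff are a vendored, trimmed copy
# used as HunyuanVideo's text encoder, and the checkpoint must be downloaded first.
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")

image = Image.new("RGB", (336, 336), color=(128, 128, 128))
prompt = "USER: <image>\nDescribe this image. ASSISTANT:"
inputs = processor(images=image, text=prompt, return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.decode(output_ids[0], skip_special_tokens=True))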
|
||||
203
models/hyvideo/text_encoder/llava/processing_llava.py
Normal file
@ -0,0 +1,203 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for Llava.
|
||||
"""
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, get_image_size, to_numpy_array
|
||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LlavaProcessorKwargs(ProcessingKwargs, total=False):
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
},
|
||||
"images_kwargs": {},
|
||||
}
|
||||
|
||||
|
||||
class LlavaProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a LLaVa processor which wraps a LLaVa image processor and a LLaMa tokenizer into a single processor.
|
||||
|
||||
[`LlavaProcessor`] offers all the functionalities of [`LlavaImageProcessor`] and [`LlamaTokenizerFast`]. See the
|
||||
[`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`LlavaImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
patch_size (`int`, *optional*):
|
||||
Patch size from the vision tower.
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Should be the same as in the model's config.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
image_token (`str`, *optional*, defaults to `"<image>"`):
|
||||
Special token used to denote image location.
|
||||
num_additional_image_tokens (`int`, *optional*, defaults to 0):
|
||||
Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other
|
||||
extra tokens appended, no need to set this arg.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = [
|
||||
"chat_template",
|
||||
"patch_size",
|
||||
"vision_feature_select_strategy",
|
||||
"image_token",
|
||||
"num_additional_image_tokens",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
patch_size=None,
|
||||
vision_feature_select_strategy=None,
|
||||
chat_template=None,
|
||||
image_token="<image>", # set the default and let users change if they have peculiar special tokens in rare cases
|
||||
num_additional_image_tokens=0,
|
||||
**kwargs,
|
||||
):
|
||||
self.patch_size = patch_size
|
||||
self.num_additional_image_tokens = num_additional_image_tokens
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
|
||||
self.image_token_id = (
|
||||
tokenizer.image_token_id
|
||||
if getattr(tokenizer, "image_token_id", None)
|
||||
else tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
)
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        audio=None,
        videos=None,
        **kwargs: Unpack[LlavaProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
        LlavaImageProcessor's [`~LlavaImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
        of the above two methods for more information.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")

        # check if images and text inputs are reversed for BC
        images, text = _validate_images_text_input_order(images, text)

        output_kwargs = self._merge_kwargs(
            LlavaProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
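        # output_kwargs now holds the merged per-modality kwargs ("text_kwargs", "images_kwargs", ...),
        # combining the class defaults above, the tokenizer's init kwargs and anything passed to this call.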
        if images is not None:
            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
        else:
            image_inputs = {}

        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, list) or not isinstance(text[0], str):
            raise ValueError("Invalid input text. Please provide a string, or a list of strings.")

        # try to expand inputs in processing if we have the necessary parts
        prompt_strings = text
        if image_inputs.get("pixel_values") is not None:
            # Replace the image token with the expanded image token sequence
            pixel_values = image_inputs["pixel_values"]
            height, width = get_image_size(to_numpy_array(pixel_values[0]))
            num_image_tokens = (height // self.patch_size) * (
                width // self.patch_size
            ) + self.num_additional_image_tokens
            if self.vision_feature_select_strategy == "default":
                num_image_tokens -= 1
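            # Worked example (illustrative numbers): a 336x336 image with patch_size=14 yields
            # (336 // 14) * (336 // 14) = 576 patch tokens; with num_additional_image_tokens=1 (a CLS token)
            # that is 577, and the "default" strategy drops one token again, leaving 576 copies of the
            # image placeholder token per image.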

            prompt_strings = []
            for sample in text:
                sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
                prompt_strings.append(sample)

        text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
        return BatchFeature(data={**text_inputs, **image_inputs})

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))


__all__ = ["LlavaProcessor"]
2
models/ltx_video/__init__.py
Normal file
@ -0,0 +1,2 @@
from .ltxv import LTXV
from . import ltxv_handler
@ -7,7 +7,7 @@ from pathlib import Path
from diffusers.utils import logging
from typing import Optional, List, Union
import yaml
from wan.utils.utils import calculate_new_dimensions
from shared.utils.utils import calculate_new_dimensions
import imageio
import json
import numpy as np
@ -605,16 +605,4 @@ def load_media_file(
        raise Exception("video format not supported")
    return media_tensor

def query_model_def(model_type, model_def):
    LTXV_config = model_def.get("LTXV_config", "")
    distilled = "distilled" in LTXV_config
    model_def_output = {
        "no_guidance": True,
    }
    if distilled:
        model_def_output.update({
            "lock_inference_steps": True,
            "no_negative_prompt": True,
        })

    return model_def_output
92
models/ltx_video/ltxv_handler.py
Normal file
@ -0,0 +1,92 @@
import torch


def get_ltxv_text_encoder_filename(text_encoder_quantization):
    text_encoder_filename = "ckpts/T5_xxl_1.1/T5_xxl_1.1_enc_bf16.safetensors"
    if text_encoder_quantization == "int8":
        text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_bf16_int8")
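        # e.g. the default bf16 checkpoint becomes
        # "ckpts/T5_xxl_1.1/T5_xxl_1.1_enc_quanto_bf16_int8.safetensors"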
    return text_encoder_filename

class family_handler():
    @staticmethod
    def query_model_def(base_model_type, model_def):
        LTXV_config = model_def.get("LTXV_config", "")
        distilled = "distilled" in LTXV_config
        extra_model_def = {
            "no_guidance": True,
        }
        if distilled:
            extra_model_def.update({
                "lock_inference_steps": True,
                "no_negative_prompt": True,
            })

        extra_model_def["fps"] = 30
        extra_model_def["frames_minimum"] = 17
        extra_model_def["frames_steps"] = 8
        extra_model_def["sliding_window"] = True
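        # LTX Video runs at 30 fps; valid frame counts start at 17 and advance in steps of 8
        # (17, 25, 33, ...), matching the temporal compression of the video latents.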

        return extra_model_def

    @staticmethod
    def query_supported_types():
        return ["ltxv_13B"]

    @staticmethod
    def query_family_maps():
        return {}, {}

    @staticmethod
    def get_rgb_factors(base_model_type):
        from shared.RGB_factors import get_rgb_factors
        latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("ltxv")
        return latent_rgb_factors, latent_rgb_factors_bias

    @staticmethod
    def query_model_family():
        return "ltxv"

    @staticmethod
    def query_family_infos():
        return {"ltxv": (10, "LTX Video")}

    @staticmethod
    def get_vae_block_size(base_model_type):
        return 32

    @staticmethod
    def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization):
        text_encoder_filename = get_ltxv_text_encoder_filename(text_encoder_quantization)
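        # Tokenizer assets and the (possibly quantized) T5 text encoder are fetched from the "T5_xxl_1.1"
        # subfolder; the VAE, spatial upscaler and scheduler config come from the repository root.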
        return {
            "repoId": "DeepBeepMeep/LTX_Video",
            "sourceFolderList": ["T5_xxl_1.1", ""],
            "fileList": [
                ["added_tokens.json", "special_tokens_map.json", "spiece.model", "tokenizer_config.json"]
                + computeList(text_encoder_filename),
                ["ltxv_0.9.7_VAE.safetensors", "ltxv_0.9.7_spatial_upscaler.safetensors", "ltxv_scheduler.json"]
                + computeList(model_filename),
            ],
        }

    @staticmethod
    def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer=False, text_encoder_quantization=None, dtype=torch.bfloat16, VAE_dtype=torch.float32, mixed_precision_transformer=False, save_quantized=False):
        from .ltxv import LTXV

        ltxv_model = LTXV(
            model_filepath=model_filename,
            text_encoder_filepath=get_ltxv_text_encoder_filename(text_encoder_quantization),
            model_type=model_type,
            base_model_type=base_model_type,
            model_def=model_def,
            dtype=dtype,
            # quantizeTransformer = quantizeTransformer,
            VAE_dtype=VAE_dtype,
            mixed_precision_transformer=mixed_precision_transformer,
        )

        pipeline = ltxv_model.pipeline
        pipe = {
            "transformer": pipeline.video_pipeline.transformer,
            "vae": pipeline.vae,
            "text_encoder": pipeline.video_pipeline.text_encoder,
            "latent_upsampler": pipeline.latent_upsampler,
        }
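        # Expose the individual sub-modules (transformer, VAE, text encoder, latent upsampler) so the caller
        # can manage them separately, e.g. for offloading or quantization.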

        return ltxv_model, pipe

    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        pass