diff --git a/.gitignore b/.gitignore index d95eb33..b42c7cd 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,7 @@ *.pth *.ckpt *.safetensors -*.json +#*.json # *.txt *.backup *.pkl diff --git a/LICENSE.txt b/LICENSE.txt index ada4a22..1262c13 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,17 +1,46 @@ -FREE for Non Commercial USE +WanGP NON-COMMERCIAL EVALUATION LICENSE 1.0 -You are free to: -- Share — copy and redistribute the material in any medium or format -- Adapt — remix, transform, and build upon the material -The licensor cannot revoke these freedoms as long as you follow the license terms. +Definitions +1.1 “Software” means the source code, binaries, libraries, utilities and UI released under this license. +1.2 “Output” means images, videos or other media produced by running the Software. +1.3 “Commercial Use” means: +a) selling, sublicensing, renting, leasing, or otherwise distributing the Software, in whole or in part, for a fee or other consideration; or +b) offering the Software (or any derivative) as part of a paid product or hosted service; or +c) using the Software (or any derivative) to provide cloud-based or backend services, where end users access or pay for those services. -Under the following terms: -- Attribution — You must give appropriate credit , provide a link to the license, and indicate if changes were made . You may do so in any reasonable manner, but not in any way that suggests the licensor endorses you or your use. -NonCommercial — You may not use the material for commercial purposes . +License Grant +Subject to Section 3: +a) You are granted a worldwide, non-exclusive, royalty-free, revocable license to use, reproduce, modify and distribute the Software for non-commercial purposes only. +b) You are granted a worldwide, non-exclusive, royalty-free, irrevocable license to use, reproduce, modify and distribute the Output for any purpose, including commercial sale, provided that any commercial distribution of the Output includes a clear notice that the Output was produced (in whole or in part) using WanGP, along with a hyperlink to the WanGP application’s About tab or repository. -- No additional restrictions — You may not apply legal terms or technological measures that legally restrict others from doing anything the license permits. -Notices: +Restrictions +3.1 You MAY NOT distribute, sublicense or otherwise make available the Software (or any derivative) for Commercial Use. +3.2 You MAY sell, license or otherwise commercially exploit the Output without restriction. +3.3 If you wish to use the Software for Commercial Use, you must obtain a separate commercial license from the Licensor. -- You do not have to comply with the license for elements of the material in the public domain or where your use is permitted by an applicable exception or limitation . +Third-Party Components 4.1 The Software includes components licensed under various open-source licenses (e.g., Apache 2.0, MIT, BSD). 4.2 You must comply with all applicable terms of those third-party licenses, including preservation of copyright notices, inclusion of required license texts, and patent-grant provisions. 4.3 You can find the full text of each third-party license via the “About” tab in the WanGP application, which provides links to their original GitHub repositories. + +Attribution +5.1 You must give appropriate credit by including: +• a copy of this license (or a link to it), and +• a notice that your use is based on “WanGP”. 
+5.2 You may do so in any reasonable manner, but not in any way that suggests the Licensor endorses you or your use. + +Disclaimer of Warranty & Liability +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE. + +Commercial Licensing The Licensor may offer commercial licenses for the Software, which grant rights to use the Software for Commercial Use. Please contact [deepbeepmeep@yahoo.com] for terms and pricing. + +Effective Date & Previous Versions +8.1 This license is effective as of the date the LICENSE file is updated in the WanGP repository. +8.2 Any copies of the Software obtained under prior license terms before this Effective Date remain governed by those prior terms; such granted rights are irrevocable. +8.3 Use of the Software after the release of any subsequent version by the Licensor is subject to the terms of the then-current license, unless a separate agreement is in place. + +Acceptable Use / Moral Clause +9.1 You MAY NOT use the Software or the Output to facilitate or produce content that is illegal, harmful, violent, harassing, defamatory, fraudulent, or otherwise violates applicable laws or fundamental human rights. +9.2 You MAY NOT deploy the Software or Output in contexts that promote hate speech, extremist ideology, human rights abuses, or other actions that could foreseeably cause significant harm to individuals or groups. +9.3 The Licensor reserves the right to terminate the rights granted under this license if a licensee materially breaches this Acceptable Use clause. + +END OF LICENSE -No warranties are given. The license may not give you all of the permissions necessary for your intended use. For example, other rights such as publicity, privacy, or moral rights may limit how you use the material. \ No newline at end of file diff --git a/README.md b/README.md index 0ba6bd6..df859e6 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,59 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : + +### August 12 2025: WanGP v7.7777 - Lucky Day(s) + +This is your lucky day! Thanks to new configuration options that let you store generated Videos and Images in lossless compressed formats, you will find that they in fact look two times better without doing anything! + +Just kidding, they will be only marginally better, but at least this opens the way to professional editing. + +Support: +- Video: x264, x264 lossless, x265 +- Images: jpeg, png, webp, webp lossless +Generation Settings are stored in each of the above regardless of the format (that was the hard part); a short sketch of reading them back follows the update notes below. + +Also you can now choose different output directories for images and videos. + +Unexpected luck: fixed Lightning 8 steps for Qwen, and Lightning 4 steps for Wan 2.2; now you just need a 1x multiplier, no weird numbers. +*update 7.777 : oops, got a crash with FastWan? Luck comes and goes, try a new update, maybe you will have a better chance this time* +*update 7.7777 : Sometimes good luck seems to last forever. 
For instance, what if Qwen Lightning 4 steps could also work with WanGP?* +- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors (Qwen Lightning 4 steps) +- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors (new improved version of Qwen Lightning 8 steps) + + +### August 10 2025: WanGP v7.76 - Faster than the VAE ... +We have a funny one here today: FastWan 2.2 5B, the fastest Video Generator, only 20s to generate 121 frames at 720p. The snag is that the VAE is twice as slow... +Thanks to Kijai for extracting the Lora that is used to build the corresponding finetune. + +*WanGP 7.76: fixed the mess I made of the i2v models (the loras path was wrong for Wan2.2 and Clip was broken)* + +### August 9 2025: WanGP v7.74 - Qwen Rebirth part 2 +Added support for the Qwen Lightning lora for an 8-step generation (https://huggingface.co/lightx2v/Qwen-Image-Lightning/blob/main/Qwen-Image-Lightning-8steps-V1.0.safetensors). The Lora is not normalized, so use a multiplier around 0.1. + +Mag Cache support for all the Wan2.2 models. Don't forget to set guidance to 1 and 8 denoising steps; your gen will be 7x faster! + +### August 8 2025: WanGP v7.73 - Qwen Rebirth +Ever wondered what impact not using Guidance has on a model that expects it? Just look at Qwen Image in WanGP 7.71, whose outputs were erratic. Somehow I had convinced myself that Qwen was a distilled model. In fact Qwen was dying for a negative prompt, and in WanGP 7.72 there is at last one for him. + +As Qwen is not so picky after all, I have also added a quantized text encoder, which reduces Qwen's RAM requirements by 10 GB (the quantized text encoder produced garbage before). + +Unfortunately the Sage bug is still there for older GPU architectures; an Sdpa fallback has been added for these architectures. + +*7.73 update: still a Sage / Sage2 bug for GPUs before RTX40xx. I have added a detection mechanism that forces Sdpa attention in that case* + + +### August 6 2025: WanGP v7.71 - Picky, picky + +This release comes with two new models: +- Qwen Image: a commercial grade Image generator capable of injecting full sentences into the generated Image while still offering incredible visuals +- Wan 2.2 TextImage to Video 5B: the last Wan 2.2 model needed if you want to complete your Wan 2.2 collection (loras for this model can be stored in "\loras\5B") + +There is a catch though: they are very picky if you want to get good generations. First, they both need lots of steps (50?) to show what they have to offer. Then for Qwen Image I had to hardcode the supported resolutions, because if you try anything else, you will get garbage. Likewise Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p. + +*7.71 update: Added VAE Tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, for low VRAM usage during a whole gen.* + + ### August 4 2025: WanGP v7.6 - Remuxed With this new version you won't have any excuse if there is no sound in your video. 
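The v7.7777 notes above mention that generation settings are now embedded in every supported output format. As a purely illustrative sketch (the exact key and container WanGP uses are not specified in this changelog, so every name below is an assumption), image outputs could be inspected for such metadata with Pillow:

```python
# Hypothetical helper: dump whatever metadata a WanGP image output carries.
# The key under which WanGP stores its generation settings is an assumption,
# so this simply prints every entry it finds (PNG text chunks, WebP/JPEG info, EXIF).
from PIL import Image

def dump_image_metadata(path: str) -> dict:
    img = Image.open(path)
    meta = dict(img.info)          # PNG text chunks and format-specific info end up here
    exif = img.getexif()
    if len(exif):
        meta["exif"] = dict(exif)  # raw EXIF tag ids -> values
    return meta

if __name__ == "__main__":
    # "outputs/example.png" is a placeholder path, not a file shipped with WanGP.
    for key, value in dump_image_metadata("outputs/example.png").items():
        print(f"{key}: {str(value)[:120]}")
```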
@@ -180,7 +233,7 @@ git clone https://github.com/deepbeepmeep/Wan2GP.git cd Wan2GP conda create -n wan2gp python=3.10.9 conda activate wan2gp -pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124 +pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128 pip install -r requirements.txt ``` diff --git a/assets/comp_effic.png b/assets/comp_effic.png deleted file mode 100644 index ea0e3b2..0000000 Binary files a/assets/comp_effic.png and /dev/null differ diff --git a/assets/data_for_diff_stage.jpg b/assets/data_for_diff_stage.jpg deleted file mode 100644 index af98046..0000000 Binary files a/assets/data_for_diff_stage.jpg and /dev/null differ diff --git a/assets/i2v_res.png b/assets/i2v_res.png deleted file mode 100644 index fb13d61..0000000 Binary files a/assets/i2v_res.png and /dev/null differ diff --git a/assets/logo.png b/assets/logo.png deleted file mode 100644 index 0c55854..0000000 Binary files a/assets/logo.png and /dev/null differ diff --git a/assets/t2v_res.jpg b/assets/t2v_res.jpg deleted file mode 100644 index 6a58388..0000000 Binary files a/assets/t2v_res.jpg and /dev/null differ diff --git a/assets/vben_vs_sota.png b/assets/vben_vs_sota.png deleted file mode 100644 index 4f09de6..0000000 Binary files a/assets/vben_vs_sota.png and /dev/null differ diff --git a/assets/video_dit_arch.jpg b/assets/video_dit_arch.jpg deleted file mode 100644 index a13e499..0000000 Binary files a/assets/video_dit_arch.jpg and /dev/null differ diff --git a/assets/video_vae_res.jpg b/assets/video_vae_res.jpg deleted file mode 100644 index e1bfb11..0000000 Binary files a/assets/video_vae_res.jpg and /dev/null differ diff --git a/configs/i2v_2_2_multitalk.json b/configs/i2v_2_2_multitalk.json new file mode 100644 index 0000000..7206fdc --- /dev/null +++ b/configs/i2v_2_2_multitalk.json @@ -0,0 +1,15 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.33.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 36, + "model_type": "i2v2_2", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + "multitalk_output_dim": 768 +} \ No newline at end of file diff --git a/configs/qwen_image_20B.json b/configs/qwen_image_20B.json new file mode 100644 index 0000000..4bff1e5 --- /dev/null +++ b/configs/qwen_image_20B.json @@ -0,0 +1,18 @@ +{ + "_class_name": "QwenImageTransformer2DModel", + "_diffusers_version": "0.34.0.dev0", + "attention_head_dim": 128, + "axes_dims_rope": [ + 16, + 56, + 56 + ], + "guidance_embeds": false, + "in_channels": 64, + "joint_attention_dim": 3584, + "num_attention_heads": 24, + "num_layers": 60, + "out_channels": 16, + "patch_size": 2, + "pooled_projection_dim": 768 +} diff --git a/configs/ti2v_2_2.json b/configs/ti2v_2_2.json new file mode 100644 index 0000000..d58edcc --- /dev/null +++ b/configs/ti2v_2_2.json @@ -0,0 +1,14 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.33.0", + "dim": 3072, + "eps": 1e-06, + "ffn_dim": 14336, + "freq_dim": 256, + "in_dim": 48, + "model_type": "ti2v2_2", + "num_heads": 24, + "num_layers": 30, + "out_dim": 48, + "text_len": 512 +} diff --git a/defaults/fantasy.json b/defaults/fantasy.json index dbab1b2..99c38af 100644 --- a/defaults/fantasy.json +++ b/defaults/fantasy.json @@ -5,8 +5,7 @@ "architecture" : "fantasy", "modules": ["fantasy"], "description": "The Fantasy Talking model corresponds to the original Wan image 2 video model combined with the Fantasy Speaking module to process 
an audio Input.", - "URLs": "i2v_720p", - "teacache_coefficients" : [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] + "URLs": "i2v_720p" }, "resolution": "1280x720" } diff --git a/defaults/i2v_2_2_multitalk.json b/defaults/i2v_2_2_multitalk.json new file mode 100644 index 0000000..9326469 --- /dev/null +++ b/defaults/i2v_2_2_multitalk.json @@ -0,0 +1,18 @@ +{ + "model": + { + "name": "Wan2.2 Multitalk 14B", + "architecture" : "i2v_2_2_multitalk", + "description": "The Multitalk module of Wan 2.1 has been combined with the Wan 2.2 image 2 video model. It lets up to two people have a conversation.", + "modules": ["multitalk"], + "URLs": "i2v_2_2", + "URLs2": "i2v_2_2", + "group": "wan2_2", + "visible": false + }, + "switch_threshold" : 900, + "guidance_scale" : 3.5, + "guidance2_scale" : 3.5, + "flow_shift" : 5 + +} \ No newline at end of file diff --git a/defaults/ltxv_13B.json b/defaults/ltxv_13B.json index dc61e31..639442e 100644 --- a/defaults/ltxv_13B.json +++ b/defaults/ltxv_13B.json @@ -13,7 +13,7 @@ "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv-097-ic-lora-depth-control-diffusers.safetensors", "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv-097-ic-lora-canny-control-diffusers.safetensors" ], - "LTXV_config": "ltx_video/configs/ltxv-13b-0.9.8-dev.yaml" + "LTXV_config": "models/ltx_video/configs/ltxv-13b-0.9.8-dev.yaml" }, "num_inference_steps": 30 } diff --git a/defaults/ltxv_distilled.json b/defaults/ltxv_distilled.json index 8973b11..c570057 100644 --- a/defaults/ltxv_distilled.json +++ b/defaults/ltxv_distilled.json @@ -9,7 +9,7 @@ "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.8_13B_distilled_quanto_bf16_int8.safetensors" ], "preload_URLs" : "ltxv_13B", - "LTXV_config": "ltx_video/configs/ltxv-13b-0.9.8-distilled.yaml" + "LTXV_config": "models/ltx_video/configs/ltxv-13b-0.9.8-distilled.yaml" }, "num_inference_steps": 6 } diff --git a/defaults/qwen_image_20B.json b/defaults/qwen_image_20B.json new file mode 100644 index 0000000..691afee --- /dev/null +++ b/defaults/qwen_image_20B.json @@ -0,0 +1,21 @@ +{ + "model": { + "name": "Qwen Image 20B", + "architecture": "qwen_image_20B", + "description": "Qwen Image is a generative model that produces very high quality images. 
It is one of the few models capable of generating very long texts within the image.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_20B_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_20B_quanto_bf16_int8.safetensors" + ], + "resolutions": [ ["1328x1328 (1:1)", "1328x1328"], + ["1664x928 (16:9)", "1664x928"], + ["928x1664 (9:16)", "928x1664"], + ["1472x1140 (4:3)", "1472x1140"], + ["1140x1472 (3:4)", "1140x1472"]], + "attention": {"<89" : "sdpa"}, + "image_outputs": true + }, + "prompt": "draw a hat", + "resolution": "1280x720", + "batch_size": 1 +} \ No newline at end of file diff --git a/defaults/ti2v_2_2.json b/defaults/ti2v_2_2.json new file mode 100644 index 0000000..ac329fa --- /dev/null +++ b/defaults/ti2v_2_2.json @@ -0,0 +1,17 @@ +{ + "model": { + "name": "Wan2.2 TextImage2video 5B", + "architecture": "ti2v_2_2", + "description": "Wan 2.2 Text 2 Video model 5B", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_5B_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_5B_quanto_mbf16_int8.safetensors" + ], + "group": "wan2_2" + }, + "video_length": 121, + "guidance_scale": 5, + "flow_shift": 5, + "num_inference_steps": 50, + "resolution": "1280x720" +} \ No newline at end of file diff --git a/defaults/ti2v_2_2_fastwan.json b/defaults/ti2v_2_2_fastwan.json new file mode 100644 index 0000000..064c2b4 --- /dev/null +++ b/defaults/ti2v_2_2_fastwan.json @@ -0,0 +1,15 @@ +{ + "model": { + "name": "Wan2.2 FastWan TextImage2video 5B", + "architecture": "ti2v_2_2", + "description": "FastWan2.2-TI2V-5B-Full-Diffusers is built upon Wan-AI/Wan2.2-TI2V-5B-Diffusers. It supports efficient 3-step inference and produces high-quality videos at 121×704×1280 resolution.", + "URLs": "ti2v_2_2", + "loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"], + "group": "wan2_2" + }, + "video_length": 121, + "guidance_scale": 1, + "flow_shift": 3, + "num_inference_steps": 3, + "resolution": "1280x720" +} \ No newline at end of file diff --git a/docs/INSTALLATION.md b/docs/INSTALLATION.md index fa4c3a6..9f66422 100644 --- a/docs/INSTALLATION.md +++ b/docs/INSTALLATION.md @@ -8,9 +8,9 @@ This guide covers installation for different GPU generations and operating systems - Conda or Python venv - Compatible GPU (RTX 10XX or newer recommended) -## Installation for RTX 10XX to RTX 40XX (Stable) +## Installation for RTX 10XX to RTX 50XX (Stable) -This installation uses PyTorch 2.6.0 which is well-tested and stable. +This installation uses PyTorch 2.7.0 which is well-tested and stable. 
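Before moving on to Step 1, a quick sanity check (not part of the official instructions; just a sketch assuming a standard PyTorch install) can confirm that the PyTorch 2.7.0 / cu128 build is the one actually loaded and report the GPU's compute capability, which determines the attention back-ends listed further down:

```python
# Optional check, run inside the wan2gp environment once the dependencies are installed.
import torch

print("torch version :", torch.__version__)        # expected: 2.7.0
print("cuda build    :", torch.version.cuda)       # expected: 12.8 for the cu128 wheel
print("cuda available:", torch.cuda.is_available())

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    print(f"gpu: {torch.cuda.get_device_name(0)} (sm_{major}{minor})")
    # sm_120 (RTX 50xx) should stick to SDPA or Sage2; Sage 1 is not compatible with it.
```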
### Step 1: Download and Setup Environment @@ -27,8 +27,8 @@ conda activate wan2gp ### Step 2: Install PyTorch ```shell -# Install PyTorch 2.6.0 with CUDA 12.4 -pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124 +# Install PyTorch 2.7.0 with CUDA 12.4 +pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128 ``` ### Step 3: Install Dependencies @@ -40,7 +40,7 @@ pip install -r requirements.txt ### Step 4: Optional Performance Optimizations -#### Sage Attention (30% faster) +#### Sage Attention (30% faster), don't install with RTX 50xx as it is not compatible ```shell # Windows only: Install Triton @@ -58,6 +58,7 @@ pip install triton-windows pip install https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu126torch2.6.0-cp310-cp310-win_amd64.whl # Linux (manual compilation required) +python -m pip install "setuptools<=75.8.2" --force-reinstall git clone https://github.com/thu-ml/SageAttention cd SageAttention pip install -e . @@ -70,61 +71,7 @@ pip install -e . pip install flash-attn==2.7.2.post1 ``` -## Installation for RTX 50XX (Beta) - -RTX 50XX GPUs require PyTorch 2.7.0 (beta). This version may be less stable. - -⚠️ **Important:** Use Python 3.10 for compatibility with pip wheels. - -### Step 1: Setup Environment - -```shell -# Clone and setup (same as above) -git clone https://github.com/deepbeepmeep/Wan2GP.git -cd Wan2GP -conda create -n wan2gp python=3.10.9 -conda activate wan2gp -``` - -### Step 2: Install PyTorch Beta - -```shell -# Install PyTorch 2.7.0 with CUDA 12.8 -pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128 -``` - -### Step 3: Install Dependencies - -```shell -pip install -r requirements.txt -``` - -### Step 4: Optional Optimizations for RTX 50XX - -#### Sage Attention - -```shell -# Windows -pip install triton-windows -pip install sageattention==1.0.6 - -# Linux -pip install sageattention==1.0.6 -``` - -#### Sage 2 Attention - -```shell -# Windows -pip install triton-windows -pip install https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu128torch2.7.0-cp310-cp310-win_amd64.whl - -# Linux (manual compilation) -git clone https://github.com/thu-ml/SageAttention -cd SageAttention -pip install -e . 
-``` - + ## Attention Modes WanGP supports several attention implementations: @@ -134,6 +81,12 @@ WanGP supports several attention implementations: - **Sage2**: 40% speed boost - **Flash**: Good performance, may be complex to install on Windows +### Attention GPU Compatibility + +- RTX 10XX, 20XX: SDPA +- RTX 30XX, 40XX: SDPA, Flash Attention, Xformers, Sage, Sage2 +- RTX 50XX: SDPA, Flash Attention, Xformers, Sage2 + ## Performance Profiles Choose a profile based on your hardware: @@ -161,10 +114,5 @@ If Sage attention doesn't work: - Use Profile 4 for lower VRAM usage - Consider using 1.3B models instead of 14B models -### GPU Compatibility -- RTX 10XX, 20XX: Supported with SDPA attention -- RTX 30XX, 40XX: Full feature support -- RTX 50XX: Beta support with PyTorch 2.7.0 - -For more troubleshooting, see [TROUBLESHOOTING.md](TROUBLESHOOTING.md) \ No newline at end of file +For more troubleshooting, see [TROUBLESHOOTING.md](TROUBLESHOOTING.md) diff --git a/docs/LORAS.md b/docs/LORAS.md index e20f7a3..73a88ac 100644 --- a/docs/LORAS.md +++ b/docs/LORAS.md @@ -65,11 +65,16 @@ For dynamic effects over generation steps, use comma-separated values: With models like Wan 2.2 that use two diffusion models internally (*High noise* / *Low Noise*), you can specify which Loras should be applied for a specific phase by separating each phase with a ";". -For instance, if you want to disable a lora for phase *High Noise* and enablesit only for phase *Low Noise*: +For instance, if you want to disable a lora for phase *High Noise* and enable it only for phase *Low Noise*: ``` 0;1 ``` +Also with Wan 2.2, if you have two loras and you want the first one to be applied only during the High noise phase and the second one during the Low noise phase: +``` +1;0 0;1 +``` + As usual, you can use any float for the multiplier and have a multiplier vary throughout one phase for one Lora: ``` 0.9,0.8;1.2,1.1,1 diff --git a/flux/__init__.py b/flux/__init__.py deleted file mode 100644 index dddc6a3..0000000 --- a/flux/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -try: - from ._version import ( - version as __version__, # type: ignore - version_tuple, - ) -except ImportError: - __version__ = "unknown (no version information available)" - version_tuple = (0, 0, "unknown", "noinfo") - -from pathlib import Path - -PACKAGE = __package__.replace("_", "-") -PACKAGE_ROOT = Path(__file__).parent diff --git a/i2v_inference.py b/i2v_inference.py deleted file mode 100644 index f833868..0000000 --- a/i2v_inference.py +++ /dev/null @@ -1,682 +0,0 @@ -import os -import time -import argparse -import json -import torch -import traceback -import gc -import random - -# These imports rely on your existing code structure -# They must match the location of your WAN code, etc. 
-import wan -from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS -from wan.modules.attention import get_attention_modes -from wan.utils.utils import cache_video -from mmgp import offload, safetensors2, profile_type - -try: - import triton -except ImportError: - pass - -DATA_DIR = "ckpts" - -# -------------------------------------------------- -# HELPER FUNCTIONS -# -------------------------------------------------- - -def sanitize_file_name(file_name): - """Clean up file name from special chars.""" - return ( - file_name.replace("/", "") - .replace("\\", "") - .replace(":", "") - .replace("|", "") - .replace("?", "") - .replace("<", "") - .replace(">", "") - .replace('"', "") - ) - -def extract_preset(lset_name, lora_dir, loras): - """ - Load a .lset JSON that lists the LoRA files to apply, plus multipliers - and possibly a suggested prompt prefix. - """ - lset_name = sanitize_file_name(lset_name) - if not lset_name.endswith(".lset"): - lset_name_filename = os.path.join(lora_dir, lset_name + ".lset") - else: - lset_name_filename = os.path.join(lora_dir, lset_name) - - if not os.path.isfile(lset_name_filename): - raise ValueError(f"Preset '{lset_name}' not found in {lora_dir}") - - with open(lset_name_filename, "r", encoding="utf-8") as reader: - text = reader.read() - lset = json.loads(text) - - loras_choices_files = lset["loras"] - loras_choices = [] - missing_loras = [] - for lora_file in loras_choices_files: - # Build absolute path and see if it is in loras - full_lora_path = os.path.join(lora_dir, lora_file) - if full_lora_path in loras: - idx = loras.index(full_lora_path) - loras_choices.append(str(idx)) - else: - missing_loras.append(lora_file) - - if len(missing_loras) > 0: - missing_list = ", ".join(missing_loras) - raise ValueError(f"Missing LoRA files for preset: {missing_list}") - - loras_mult_choices = lset["loras_mult"] - prompt_prefix = lset.get("prompt", "") - full_prompt = lset.get("full_prompt", False) - return loras_choices, loras_mult_choices, prompt_prefix, full_prompt - -def get_attention_mode(args_attention, installed_modes): - """ - Decide which attention mode to use: either the user choice or auto fallback. - """ - if args_attention == "auto": - for candidate in ["sage2", "sage", "sdpa"]: - if candidate in installed_modes: - return candidate - return "sdpa" # last fallback - elif args_attention in installed_modes: - return args_attention - else: - raise ValueError( - f"Requested attention mode '{args_attention}' not installed. " - f"Installed modes: {installed_modes}" - ) - -def load_i2v_model(model_filename, text_encoder_filename, is_720p): - """ - Load the i2v model with a specific size config and text encoder. - """ - if is_720p: - print("Loading 14B-720p i2v model ...") - cfg = WAN_CONFIGS['i2v-14B'] - wan_model = wan.WanI2V( - config=cfg, - checkpoint_dir=DATA_DIR, - model_filename=model_filename, - text_encoder_filename=text_encoder_filename - ) - else: - print("Loading 14B-480p i2v model ...") - cfg = WAN_CONFIGS['i2v-14B'] - wan_model = wan.WanI2V( - config=cfg, - checkpoint_dir=DATA_DIR, - model_filename=model_filename, - text_encoder_filename=text_encoder_filename - ) - # Pipe structure - pipe = { - "transformer": wan_model.model, - "text_encoder": wan_model.text_encoder.model, - "text_encoder_2": wan_model.clip.model, - "vae": wan_model.vae.model - } - return wan_model, pipe - -def setup_loras(pipe, lora_dir, lora_preset, num_inference_steps): - """ - Load loras from a directory, optionally apply a preset. 
- """ - from pathlib import Path - import glob - - if not lora_dir or not Path(lora_dir).is_dir(): - print("No valid --lora-dir provided or directory doesn't exist, skipping LoRA setup.") - return [], [], [], "", "", False - - # Gather LoRA files - loras = sorted( - glob.glob(os.path.join(lora_dir, "*.sft")) - + glob.glob(os.path.join(lora_dir, "*.safetensors")) - ) - loras_names = [Path(x).stem for x in loras] - - # Offload them with no activation - offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False) - - # If user gave a preset, apply it - default_loras_choices = [] - default_loras_multis_str = "" - default_prompt_prefix = "" - preset_applied_full_prompt = False - if lora_preset: - loras_choices, loras_mult, prefix, full_prompt = extract_preset(lora_preset, lora_dir, loras) - default_loras_choices = loras_choices - # If user stored loras_mult as a list or string in JSON, unify that to str - if isinstance(loras_mult, list): - # Just store them in a single line - default_loras_multis_str = " ".join([str(x) for x in loras_mult]) - else: - default_loras_multis_str = str(loras_mult) - default_prompt_prefix = prefix - preset_applied_full_prompt = full_prompt - - return ( - loras, - loras_names, - default_loras_choices, - default_loras_multis_str, - default_prompt_prefix, - preset_applied_full_prompt - ) - -def parse_loras_and_activate( - transformer, - loras, - loras_choices, - loras_mult_str, - num_inference_steps -): - """ - Activate the chosen LoRAs with multipliers over the pipeline's transformer. - Supports stepwise expansions (like "0.5,0.8" for partial steps). - """ - if not loras or not loras_choices: - # no LoRAs selected - return - - # Handle multipliers - def is_float_or_comma_list(x): - """ - Example: "0.5", or "0.8,1.0", etc. is valid. - """ - if not x: - return False - for chunk in x.split(","): - try: - float(chunk.strip()) - except ValueError: - return False - return True - - # Convert multiline or spaced lines to a single list - lines = [ - line.strip() - for line in loras_mult_str.replace("\r", "\n").split("\n") - if line.strip() and not line.strip().startswith("#") - ] - # Now combine them by space - joined_line = " ".join(lines) # "1.0 2.0,3.0" - if not joined_line.strip(): - multipliers = [] - else: - multipliers = joined_line.split(" ") - - # Expand each item - final_multipliers = [] - for mult in multipliers: - mult = mult.strip() - if not mult: - continue - if is_float_or_comma_list(mult): - # Could be "0.7" or "0.5,0.6" - if "," in mult: - # expand over steps - chunk_vals = [float(x.strip()) for x in mult.split(",")] - expanded = expand_list_over_steps(chunk_vals, num_inference_steps) - final_multipliers.append(expanded) - else: - final_multipliers.append(float(mult)) - else: - raise ValueError(f"Invalid LoRA multiplier: '{mult}'") - - # If fewer multipliers than chosen LoRAs => pad with 1.0 - needed = len(loras_choices) - len(final_multipliers) - if needed > 0: - final_multipliers += [1.0]*needed - - # Actually activate them - offload.activate_loras(transformer, loras_choices, final_multipliers) - -def expand_list_over_steps(short_list, num_steps): - """ - If user gave (0.5, 0.8) for example, expand them over `num_steps`. - The expansion is simply linear slice across steps. 
- """ - result = [] - inc = len(short_list) / float(num_steps) - idxf = 0.0 - for _ in range(num_steps): - value = short_list[int(idxf)] - result.append(value) - idxf += inc - return result - -def download_models_if_needed(transformer_filename_i2v, text_encoder_filename, local_folder=DATA_DIR): - """ - Checks if all required WAN 2.1 i2v files exist locally under 'ckpts/'. - If not, downloads them from a Hugging Face Hub repo. - Adjust the 'repo_id' and needed files as appropriate. - """ - import os - from pathlib import Path - - try: - from huggingface_hub import hf_hub_download, snapshot_download - except ImportError as e: - raise ImportError( - "huggingface_hub is required for automatic model download. " - "Please install it via `pip install huggingface_hub`." - ) from e - - # Identify just the filename portion for each path - def basename(path_str): - return os.path.basename(path_str) - - repo_id = "DeepBeepMeep/Wan2.1" - target_root = local_folder - - # You can customize this list as needed for i2v usage. - # At minimum you need: - # 1) The requested i2v transformer file - # 2) The requested text encoder file - # 3) VAE file - # 4) The open-clip xlm-roberta-large weights - # - # If your i2v config references additional files, add them here. - needed_files = [ - "Wan2.1_VAE.pth", - "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", - basename(text_encoder_filename), - basename(transformer_filename_i2v), - ] - - # The original script also downloads an entire "xlm-roberta-large" folder - # via snapshot_download. If you require that for your pipeline, - # you can add it here, for example: - subfolder_name = "xlm-roberta-large" - if not Path(os.path.join(target_root, subfolder_name)).exists(): - snapshot_download(repo_id=repo_id, allow_patterns=subfolder_name + "/*", local_dir=target_root) - - for filename in needed_files: - local_path = os.path.join(target_root, filename) - if not os.path.isfile(local_path): - print(f"File '{filename}' not found locally. 
Downloading from {repo_id} ...") - hf_hub_download( - repo_id=repo_id, - filename=filename, - local_dir=target_root - ) - else: - # Already present - pass - - print("All required i2v files are present.") - - -# -------------------------------------------------- -# ARGUMENT PARSER -# -------------------------------------------------- - -def parse_args(): - parser = argparse.ArgumentParser( - description="Image-to-Video inference using WAN 2.1 i2v" - ) - # Model + Tools - parser.add_argument( - "--quantize-transformer", - action="store_true", - help="Use on-the-fly transformer quantization" - ) - parser.add_argument( - "--compile", - action="store_true", - help="Enable PyTorch 2.0 compile for the transformer" - ) - parser.add_argument( - "--attention", - type=str, - default="auto", - help="Which attention to use: auto, sdpa, sage, sage2, flash" - ) - parser.add_argument( - "--profile", - type=int, - default=4, - help="Memory usage profile number [1..5]; see original script or use 2 if you have low VRAM" - ) - parser.add_argument( - "--preload", - type=int, - default=0, - help="Megabytes of the diffusion model to preload in VRAM (only used in some profiles)" - ) - parser.add_argument( - "--verbose", - type=int, - default=1, - help="Verbosity level [0..5]" - ) - - # i2v Model - parser.add_argument( - "--transformer-file", - type=str, - default=f"{DATA_DIR}/wan2.1_image2video_480p_14B_quanto_int8.safetensors", - help="Which i2v model to load" - ) - parser.add_argument( - "--text-encoder-file", - type=str, - default=f"{DATA_DIR}/models_t5_umt5-xxl-enc-quanto_int8.safetensors", - help="Which text encoder to use" - ) - - # LoRA - parser.add_argument( - "--lora-dir", - type=str, - default="", - help="Path to a directory containing i2v LoRAs" - ) - parser.add_argument( - "--lora-preset", - type=str, - default="", - help="A .lset preset name in the lora_dir to auto-apply" - ) - - # Generation Options - parser.add_argument("--prompt", type=str, default=None, required=True, help="Prompt for generation") - parser.add_argument("--negative-prompt", type=str, default="", help="Negative prompt") - parser.add_argument("--resolution", type=str, default="832x480", help="WxH") - parser.add_argument("--frames", type=int, default=64, help="Number of frames (16=1s if fps=16). Must be multiple of 4 +/- 1 in WAN.") - parser.add_argument("--steps", type=int, default=30, help="Number of denoising steps.") - parser.add_argument("--guidance-scale", type=float, default=5.0, help="Classifier-free guidance scale") - parser.add_argument("--flow-shift", type=float, default=3.0, help="Flow shift parameter. Generally 3.0 for 480p, 5.0 for 720p.") - parser.add_argument("--riflex", action="store_true", help="Enable RIFLEx for longer videos") - parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.") - parser.add_argument("--teacache-start", type=float, default=0.1, help="Teacache start step percentage [0..100]") - parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.") - parser.add_argument("--slg-layers", type=str, default=None, help="Which layers to use for skip layer guidance") - parser.add_argument("--slg-start", type=float, default=0.0, help="Percentage in to start SLG") - parser.add_argument("--slg-end", type=float, default=1.0, help="Percentage in to end SLG") - - # LoRA usage - parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. 
Usually you only use the preset.") - parser.add_argument("--loras-mult", type=str, default="", help="Multipliers for each chosen LoRA. Example: '1.0 1.2,1.3' etc.") - - # Input - parser.add_argument( - "--input-image", - type=str, - default=None, - required=True, - help="Path to an input image (or multiple)." - ) - parser.add_argument( - "--output-file", - type=str, - default="output.mp4", - help="Where to save the resulting video." - ) - - return parser.parse_args() - -# -------------------------------------------------- -# MAIN -# -------------------------------------------------- - -def main(): - args = parse_args() - - # Setup environment - offload.default_verboseLevel = args.verbose - installed_attn_modes = get_attention_modes() - - # Decide attention - chosen_attention = get_attention_mode(args.attention, installed_attn_modes) - offload.shared_state["_attention"] = chosen_attention - - # Determine i2v resolution format - if "720" in args.transformer_file: - is_720p = True - else: - is_720p = False - - # Make sure we have the needed models locally - download_models_if_needed(args.transformer_file, args.text_encoder_file) - - # Load i2v - wan_model, pipe = load_i2v_model( - model_filename=args.transformer_file, - text_encoder_filename=args.text_encoder_file, - is_720p=is_720p - ) - wan_model._interrupt = False - - # Offload / profile - # e.g. for your script: offload.profile(pipe, profile_no=args.profile, compile=..., quantizeTransformer=...) - # pass the budgets if you want, etc. - kwargs = {} - if args.profile == 2 or args.profile == 4: - # preload is in MB - if args.preload == 0: - budgets = {"transformer": 100, "text_encoder": 100, "*": 1000} - else: - budgets = {"transformer": args.preload, "text_encoder": 100, "*": 1000} - kwargs["budgets"] = budgets - elif args.profile == 3: - kwargs["budgets"] = {"*": "70%"} - - compile_choice = "transformer" if args.compile else "" - # Create the offload object - offloadobj = offload.profile( - pipe, - profile_no=args.profile, - compile=compile_choice, - quantizeTransformer=args.quantize_transformer, - **kwargs - ) - - # If user wants to use LoRAs - ( - loras, - loras_names, - default_loras_choices, - default_loras_multis_str, - preset_prompt_prefix, - preset_full_prompt - ) = setup_loras(pipe, args.lora_dir, args.lora_preset, args.steps) - - # Combine user prompt with preset prompt if the preset indicates so - if preset_prompt_prefix: - if preset_full_prompt: - # Full override - user_prompt = preset_prompt_prefix - else: - # Just prefix - user_prompt = preset_prompt_prefix + "\n" + args.prompt - else: - user_prompt = args.prompt - - # Actually parse user LoRA choices if they did not rely purely on the preset - if args.loras_choices: - # If user gave e.g. "0,1", we treat that as new additions - lora_choice_list = [x.strip() for x in args.loras_choices.split(",")] - else: - # Use the defaults from the preset - lora_choice_list = default_loras_choices - - # Activate them - parse_loras_and_activate( - pipe["transformer"], loras, lora_choice_list, args.loras_mult or default_loras_multis_str, args.steps - ) - - # Negative prompt - negative_prompt = args.negative_prompt or "" - - # Sanity check resolution - if "*" in args.resolution.lower(): - print("ERROR: resolution must be e.g. 832x480 not '832*480'. 
Fixing it.") - resolution_str = args.resolution.lower().replace("*", "x") - else: - resolution_str = args.resolution - - try: - width, height = [int(x) for x in resolution_str.split("x")] - except: - raise ValueError(f"Invalid resolution: '{resolution_str}'") - - # Parse slg_layers from comma-separated string to a Python list of ints (or None if not provided) - if args.slg_layers: - slg_list = [int(x) for x in args.slg_layers.split(",")] - else: - slg_list = None - - # Additional checks (from your original code). - if "480p" in args.transformer_file: - # Then we cannot exceed certain area for 480p model - if width * height > 832*480: - raise ValueError("You must use the 720p i2v model to generate bigger than 832x480.") - # etc. - - # Handle random seed - if args.seed < 0: - args.seed = random.randint(0, 999999999) - print(f"Using seed={args.seed}") - - # Setup tea cache if needed - trans = wan_model.model - trans.enable_cache = (args.teacache > 0) - if trans.enable_cache: - if "480p" in args.transformer_file: - # example from your code - trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01] - elif "720p" in args.transformer_file: - trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] - else: - raise ValueError("Teacache not supported for this model variant") - - # Attempt generation - print("Starting generation ...") - start_time = time.time() - - # Read the input image - if not os.path.isfile(args.input_image): - raise ValueError(f"Input image does not exist: {args.input_image}") - - from PIL import Image - input_img = Image.open(args.input_image).convert("RGB") - - # Possibly load more than one image if you want "multiple images" – but here we'll just do single for demonstration - - # Define the generation call - # - frames => must be multiple of 4 plus 1 as per original script's note, e.g. 81, 65, ... 
- # You can correct to that if needed: - frame_count = (args.frames // 4)*4 + 1 # ensures it's 4*N+1 - # RIFLEx - enable_riflex = args.riflex - - # If teacache => reset counters - if trans.enable_cache: - trans.teacache_counter = 0 - trans.cache_multiplier = args.teacache - trans.cache_start_step = int(args.teacache_start * args.steps / 100.0) - trans.num_steps = args.steps - trans.cache_skipped_steps = 0 - trans.previous_residual_uncond = None - trans.previous_residual_cond = None - - # VAE Tiling - device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576 - if device_mem_capacity >= 28000: # 81 frames 720p requires about 28 GB VRAM - use_vae_config = 1 - elif device_mem_capacity >= 8000: - use_vae_config = 2 - else: - use_vae_config = 3 - - if use_vae_config == 1: - VAE_tile_size = 0 - elif use_vae_config == 2: - VAE_tile_size = 256 - else: - VAE_tile_size = 128 - - print('Using VAE tile size of', VAE_tile_size) - - # Actually run the i2v generation - try: - sample_frames = wan_model.generate( - input_prompt = user_prompt, - image_start = input_img, - frame_num=frame_count, - width=width, - height=height, - # max_area=MAX_AREA_CONFIGS[f"{width}*{height}"], # or you can pass your custom - shift=args.flow_shift, - sampling_steps=args.steps, - guide_scale=args.guidance_scale, - n_prompt=negative_prompt, - seed=args.seed, - offload_model=False, - callback=None, # or define your own callback if you want - enable_RIFLEx=enable_riflex, - VAE_tile_size=VAE_tile_size, - joint_pass=slg_list is None, # set if you want a small speed improvement without SLG - slg_layers=slg_list, - slg_start=args.slg_start, - slg_end=args.slg_end, - ) - except Exception as e: - offloadobj.unload_all() - gc.collect() - torch.cuda.empty_cache() - - err_str = f"Generation failed with error: {e}" - # Attempt to detect OOM errors - s = str(e).lower() - if any(keyword in s for keyword in ["memory", "cuda", "alloc"]): - raise RuntimeError("Likely out-of-VRAM or out-of-RAM error. " + err_str) - else: - traceback.print_exc() - raise RuntimeError(err_str) - - # After generation - offloadobj.unload_all() - gc.collect() - torch.cuda.empty_cache() - - if sample_frames is None: - raise RuntimeError("No frames were returned (maybe generation was aborted or failed).") - - # If teacache was used, we can see how many steps were skipped - if trans.enable_cache: - print(f"TeaCache skipped steps: {trans.teacache_skipped_steps} / {args.steps}") - - # Save result - sample_frames = sample_frames.cpu() # shape = c, t, h, w => [3, T, H, W] - os.makedirs(os.path.dirname(args.output_file) or ".", exist_ok=True) - - # Use the provided helper from your code to store the MP4 - # By default, you used cache_video(tensor=..., save_file=..., fps=16, ...) - # or you can do your own. We'll do the same for consistency: - cache_video( - tensor=sample_frames[None], # shape => [1, c, T, H, W] - save_file=args.output_file, - fps=16, - nrow=1, - normalize=True, - value_range=(-1, 1) - ) - - end_time = time.time() - elapsed_s = end_time - start_time - print(f"Done! Output written to {args.output_file}. 
Generation time: {elapsed_s:.1f} seconds.") - -if __name__ == "__main__": - main() diff --git a/loras_qwen/Readme.txt b/loras_qwen/Readme.txt new file mode 100644 index 0000000..14a70a8 --- /dev/null +++ b/loras_qwen/Readme.txt @@ -0,0 +1 @@ +LTX Video loras \ No newline at end of file diff --git a/hyvideo/__init__.py b/models/__init__.py similarity index 100% rename from hyvideo/__init__.py rename to models/__init__.py diff --git a/models/flux/__init__.py b/models/flux/__init__.py new file mode 100644 index 0000000..d0a07ae --- /dev/null +++ b/models/flux/__init__.py @@ -0,0 +1,2 @@ +from .flux_main import model_factory +from . import flux_handler diff --git a/flux/__main__.py b/models/flux/__main__.py similarity index 100% rename from flux/__main__.py rename to models/flux/__main__.py diff --git a/flux/_version.py b/models/flux/_version.py similarity index 100% rename from flux/_version.py rename to models/flux/_version.py diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py new file mode 100644 index 0000000..9bdc0cf --- /dev/null +++ b/models/flux/flux_handler.py @@ -0,0 +1,103 @@ +import torch + +def get_ltxv_text_encoder_filename(text_encoder_quantization): + text_encoder_filename = "ckpts/T5_xxl_1.1/T5_xxl_1.1_enc_bf16.safetensors" + if text_encoder_quantization =="int8": + text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_bf16_int8") + return text_encoder_filename + +class family_handler(): + @staticmethod + def query_model_def(base_model_type, model_def): + flux_model = model_def.get("flux-model", "flux-dev") + flux_schnell = flux_model == "flux-schnell" + model_def_output = { + "image_outputs" : True, + "no_negative_prompt" : True, + } + if flux_schnell: + model_def_output["no_guidance"] = True + else: + model_def_output["embedded_guidance"] = True + + + return model_def_output + + @staticmethod + def query_supported_types(): + return ["flux"] + + @staticmethod + def query_family_maps(): + return {}, {} + + @staticmethod + def get_rgb_factors(base_model_type ): + from shared.RGB_factors import get_rgb_factors + latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("flux") + return latent_rgb_factors, latent_rgb_factors_bias + + + @staticmethod + def query_model_family(): + return "flux" + + @staticmethod + def query_family_infos(): + return {"flux":(30, "Flux 1")} + + @staticmethod + def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization): + text_encoder_filename = get_ltxv_text_encoder_filename(text_encoder_quantization) + return [ + { + "repoId" : "DeepBeepMeep/Flux", + "sourceFolderList" : [""], + "fileList" : [ ["flux_vae.safetensors"] ] + }, + { + "repoId" : "DeepBeepMeep/LTX_Video", + "sourceFolderList" : ["T5_xxl_1.1"], + "fileList" : [ ["added_tokens.json", "special_tokens_map.json", "spiece.model", "tokenizer_config.json"] + computeList(text_encoder_filename) ] + }, + { + "repoId" : "DeepBeepMeep/HunyuanVideo", + "sourceFolderList" : [ "clip_vit_large_patch14", ], + "fileList" :[ + ["config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"], + ] + } + ] + + @staticmethod + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + from .flux_main import model_factory + + flux_model = model_factory( + 
checkpoint_dir="ckpts", + model_filename=model_filename, + model_type = model_type, + model_def = model_def, + base_model_type=base_model_type, + text_encoder_filename= get_ltxv_text_encoder_filename(text_encoder_quantization), + quantizeTransformer = quantizeTransformer, + dtype = dtype, + VAE_dtype = VAE_dtype, + mixed_precision_transformer = mixed_precision_transformer, + save_quantized = save_quantized + ) + + pipe = { "transformer": flux_model.model, "vae" : flux_model.vae, "text_encoder" : flux_model.clip, "text_encoder_2" : flux_model.t5} + + return flux_model, pipe + + @staticmethod + def update_default_settings(base_model_type, model_def, ui_defaults): + ui_defaults.update({ + "embedded_guidance": 2.5, + }) + if model_def.get("reference_image", False): + ui_defaults.update({ + "video_prompt_type": "KI", + }) + diff --git a/flux/flux_main.py b/models/flux/flux_main.py similarity index 90% rename from flux/flux_main.py rename to models/flux/flux_main.py index 303765a..16659fc 100644 --- a/flux/flux_main.py +++ b/models/flux/flux_main.py @@ -5,10 +5,10 @@ from dataclasses import dataclass from glob import iglob from mmgp import offload as offload import torch -from wan.utils.utils import calculate_new_dimensions -from flux.sampling import denoise, get_schedule, prepare_kontext, unpack -from flux.modules.layers import get_linear_split_map -from flux.util import ( +from shared.utils.utils import calculate_new_dimensions +from .sampling import denoise, get_schedule, prepare_kontext, unpack +from .modules.layers import get_linear_split_map +from .util import ( aspect_ratio_to_height_width, load_ae, load_clip, @@ -146,13 +146,3 @@ class model_factory: x = x.transpose(0, 1) return x -def query_model_def(model_type, model_def): - flux_model = model_def.get("flux-model", "flux-dev") - flux_schnell = flux_model == "flux-schnell" - model_def_output = { - "image_outputs" : True, - } - if flux_schnell: - model_def_output["no_guidance"] = True - - return model_def_output \ No newline at end of file diff --git a/flux/math.py b/models/flux/math.py similarity index 97% rename from flux/math.py rename to models/flux/math.py index 9e8aa59..a249f19 100644 --- a/flux/math.py +++ b/models/flux/math.py @@ -1,7 +1,7 @@ import torch from einops import rearrange from torch import Tensor -from wan.modules.attention import pay_attention +from shared.attention import pay_attention def attention(qkv_list, pe: Tensor) -> Tensor: diff --git a/flux/model.py b/models/flux/model.py similarity index 98% rename from flux/model.py rename to models/flux/model.py index d6c1b6c..d84ceb3 100644 --- a/flux/model.py +++ b/models/flux/model.py @@ -3,7 +3,7 @@ from dataclasses import dataclass import torch from torch import Tensor, nn -from flux.modules.layers import ( +from .modules.layers import ( DoubleStreamBlock, EmbedND, LastLayer, @@ -11,7 +11,7 @@ from flux.modules.layers import ( SingleStreamBlock, timestep_embedding, ) -from flux.modules.lora import LinearLora, replace_linear_with_lora +from .modules.lora import LinearLora, replace_linear_with_lora @dataclass diff --git a/flux/modules/autoencoder.py b/models/flux/modules/autoencoder.py similarity index 100% rename from flux/modules/autoencoder.py rename to models/flux/modules/autoencoder.py diff --git a/flux/modules/conditioner.py b/models/flux/modules/conditioner.py similarity index 100% rename from flux/modules/conditioner.py rename to models/flux/modules/conditioner.py diff --git a/flux/modules/image_embedders.py b/models/flux/modules/image_embedders.py similarity 
index 98% rename from flux/modules/image_embedders.py rename to models/flux/modules/image_embedders.py index aa26d9b..011f840 100644 --- a/flux/modules/image_embedders.py +++ b/models/flux/modules/image_embedders.py @@ -7,7 +7,7 @@ from safetensors.torch import load_file as load_sft from torch import nn from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel -from flux.util import print_load_warning +from ..util import print_load_warning class DepthImageEncoder: diff --git a/flux/modules/layers copy.py b/models/flux/modules/layers copy.py similarity index 100% rename from flux/modules/layers copy.py rename to models/flux/modules/layers copy.py diff --git a/flux/modules/layers.py b/models/flux/modules/layers.py similarity index 99% rename from flux/modules/layers.py rename to models/flux/modules/layers.py index d0273b7..b937143 100644 --- a/flux/modules/layers.py +++ b/models/flux/modules/layers.py @@ -5,7 +5,7 @@ import torch from einops import rearrange from torch import Tensor, nn -from flux.math import attention, rope +from ..math import attention, rope def get_linear_split_map(): hidden_size = 3072 diff --git a/flux/modules/lora.py b/models/flux/modules/lora.py similarity index 100% rename from flux/modules/lora.py rename to models/flux/modules/lora.py diff --git a/flux/sampling.py b/models/flux/sampling.py similarity index 99% rename from flux/sampling.py rename to models/flux/sampling.py index 5a15c5e..23cfcf3 100644 --- a/flux/sampling.py +++ b/models/flux/sampling.py @@ -343,7 +343,7 @@ def denoise( updated_num_steps= len(timesteps) -1 if callback != None: - from wan.utils.loras_mutipliers import update_loras_slists + from shared.utils.loras_mutipliers import update_loras_slists update_loras_slists(model, loras_slists, updated_num_steps) callback(-1, None, True, override_num_inference_steps = updated_num_steps) from mmgp import offload diff --git a/flux/util.py b/models/flux/util.py similarity index 99% rename from flux/util.py rename to models/flux/util.py index a9815af..a729dd7 100644 --- a/flux/util.py +++ b/models/flux/util.py @@ -11,9 +11,9 @@ from huggingface_hub import hf_hub_download, login from PIL import ExifTags, Image from safetensors.torch import load_file as load_sft -from flux.model import Flux, FluxLoraWrapper, FluxParams -from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams -from flux.modules.conditioner import HFEmbedder +from .model import Flux, FluxLoraWrapper, FluxParams +from .modules.autoencoder import AutoEncoder, AutoEncoderParams +from .modules.conditioner import HFEmbedder CHECKPOINTS_DIR = Path("checkpoints") diff --git a/models/hyvideo/__init__.py b/models/hyvideo/__init__.py new file mode 100644 index 0000000..d3a3700 --- /dev/null +++ b/models/hyvideo/__init__.py @@ -0,0 +1,2 @@ +from .hunyuan import HunyuanVideoSampler +from . 
import hunyuan_handler \ No newline at end of file diff --git a/hyvideo/config.py b/models/hyvideo/config.py similarity index 100% rename from hyvideo/config.py rename to models/hyvideo/config.py diff --git a/hyvideo/constants.py b/models/hyvideo/constants.py similarity index 100% rename from hyvideo/constants.py rename to models/hyvideo/constants.py diff --git a/hyvideo/data_kits/audio_dataset.py b/models/hyvideo/data_kits/audio_dataset.py similarity index 100% rename from hyvideo/data_kits/audio_dataset.py rename to models/hyvideo/data_kits/audio_dataset.py diff --git a/hyvideo/data_kits/audio_preprocessor.py b/models/hyvideo/data_kits/audio_preprocessor.py similarity index 100% rename from hyvideo/data_kits/audio_preprocessor.py rename to models/hyvideo/data_kits/audio_preprocessor.py diff --git a/hyvideo/data_kits/data_tools.py b/models/hyvideo/data_kits/data_tools.py similarity index 100% rename from hyvideo/data_kits/data_tools.py rename to models/hyvideo/data_kits/data_tools.py diff --git a/hyvideo/data_kits/face_align/__init__.py b/models/hyvideo/data_kits/face_align/__init__.py similarity index 100% rename from hyvideo/data_kits/face_align/__init__.py rename to models/hyvideo/data_kits/face_align/__init__.py diff --git a/hyvideo/data_kits/face_align/align.py b/models/hyvideo/data_kits/face_align/align.py similarity index 100% rename from hyvideo/data_kits/face_align/align.py rename to models/hyvideo/data_kits/face_align/align.py diff --git a/hyvideo/data_kits/face_align/detface.py b/models/hyvideo/data_kits/face_align/detface.py similarity index 100% rename from hyvideo/data_kits/face_align/detface.py rename to models/hyvideo/data_kits/face_align/detface.py diff --git a/hyvideo/diffusion/__init__.py b/models/hyvideo/diffusion/__init__.py similarity index 100% rename from hyvideo/diffusion/__init__.py rename to models/hyvideo/diffusion/__init__.py diff --git a/hyvideo/diffusion/pipelines/__init__.py b/models/hyvideo/diffusion/pipelines/__init__.py similarity index 100% rename from hyvideo/diffusion/pipelines/__init__.py rename to models/hyvideo/diffusion/pipelines/__init__.py diff --git a/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py b/models/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py similarity index 98% rename from hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py rename to models/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py index 22f652e..ed91f9c 100644 --- a/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py +++ b/models/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py @@ -949,15 +949,18 @@ class HunyuanVideoPipeline(DiffusionPipeline): # width = width or self.transformer.config.sample_size * self.vae_scale_factor # to deal with lora scaling and other possible forward hooks trans = self.transformer - if trans.enable_cache == "tea": - teacache_multiplier = trans.cache_multiplier - trans.accumulated_rel_l1_distance = 0 - trans.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15 - elif trans.enable_cache == "mag": - trans.compute_magcache_threshold(trans.cache_start_step, num_inference_steps, trans.cache_multiplier) - trans.accumulated_err, trans.accumulated_steps, trans.accumulated_ratio = 0, 0, 1.0 - else: - trans.enable_cache == None + skip_steps_cache = trans.cache + if skip_steps_cache != None: + cache_type = skip_steps_cache.cache_type + if cache_type == "tea": + teacache_multiplier = skip_steps_cache.multiplier + skip_steps_cache.accumulated_rel_l1_distance = 0 + skip_steps_cache.rel_l1_thresh = 0.1 if teacache_multiplier < 2 
else 0.15 + elif cache_type== "mag": + trans.compute_magcache_threshold(skip_steps_cache.start_step, num_inference_steps, skip_steps_cache.multiplier) + skip_steps_cache.accumulated_err, skip_steps_cache.accumulated_steps, skip_steps_cache.accumulated_ratio = 0, 0, 1.0 + else: + trans.cache = None # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, @@ -1212,8 +1215,8 @@ class HunyuanVideoPipeline(DiffusionPipeline): if ip_cfg_scale>0: latent_items += 1 - if self.transformer.enable_cache: - self.transformer.previous_residual = [None] * latent_items + if skip_steps_cache != None: + skip_steps_cache.previous_residual = [None] * latent_items # if is_progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar: diff --git a/hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py b/models/hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py similarity index 95% rename from hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py rename to models/hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py index 191f9ab..c043a12 100644 --- a/hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py +++ b/models/hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py @@ -41,9 +41,9 @@ from diffusers.utils import ( from diffusers.utils.torch_utils import randn_tensor from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from hyvideo.constants import PRECISION_TO_TYPE -from hyvideo.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D -from hyvideo.text_encoder import TextEncoder +from ...constants import PRECISION_TO_TYPE +from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D +from ...text_encoder import TextEncoder from einops import rearrange from ...modules import HYVideoDiffusionTransformer @@ -934,15 +934,20 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline): transformer = self.transformer - if transformer.enable_cache == "tea": - teacache_multiplier = transformer.cache_multiplier - transformer.accumulated_rel_l1_distance = 0 - transformer.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15 - elif transformer.enable_cache == "mag": - transformer.compute_magcache_threshold(transformer.cache_start_step, num_inference_steps, transformer.cache_multiplier) - transformer.accumulated_err, transformer.accumulated_steps, transformer.accumulated_ratio = 0, 0, 1.0 - else: - transformer.enable_cache == None + skip_steps_cache = transformer.cache + cache_type = None + if skip_steps_cache != None: + cache_type = skip_steps_cache.cache_type + if cache_type == "tea": + teacache_multiplier = skip_steps_cache.multiplier + skip_steps_cache.accumulated_rel_l1_distance = 0 + skip_steps_cache.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15 + elif cache_type == "mag": + transformer.compute_magcache_threshold(skip_steps_cache.start_step, num_inference_steps, skip_steps_cache.multiplier) + skip_steps_cache.accumulated_err, skip_steps_cache.accumulated_steps, skip_steps_cache.accumulated_ratio = 0, 0, 1.0 + else: + transformer.cache = None + cache_type = None # 1. Check inputs. 
Raise error if not correct self.check_inputs( @@ -1141,16 +1146,16 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline): if self._interrupt: return [None] - if transformer.enable_cache == "tea": + if cache_type == "tea": cache_size = round( infer_length / frames_per_batch ) - transformer.previous_residual = [None] * latent_items + skip_steps_cache.previous_residual = [None] * latent_items cache_all_previous_residual = [None] * latent_items cache_all_previous_modulated_input = None cache_should_calc = [True] * cache_size cache_accumulated_rel_l1_distance = [0.] * cache_size cache_teacache_skipped_steps = [0] * cache_size - elif transformer.enable_cache == "mag": - transformer.previous_residual = [None] * latent_items + elif cache_type == "mag": + skip_steps_cache.previous_residual = [None] * latent_items with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -1187,16 +1192,16 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline): img_ref_len = (latent_model_input.shape[-1] // 2) * (latent_model_input.shape[-2] // 2) * ( 1) img_all_len = (latents_all.shape[-1] // 2) * (latents_all.shape[-2] // 2) * latents_all.shape[-3] - if transformer.enable_cache == "tea" and cache_size > 1: + if cache_type == "tea" and cache_size > 1: for l in range(latent_items): if cache_all_previous_residual[l] != None: bsz = cache_all_previous_residual[l].shape[0] - transformer.previous_residual[l][:, img_ref_len:] = cache_all_previous_residual[l].reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072) + skip_steps_cache.previous_residual[l][:, img_ref_len:] = cache_all_previous_residual[l].reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072) if cache_all_previous_modulated_input != None: - transformer.previous_modulated_input[:, img_ref_len:] = cache_all_previous_modulated_input.reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072) - transformer.should_calc = cache_should_calc[cache_slot_no] - transformer.accumulated_rel_l1_distance = cache_accumulated_rel_l1_distance[cache_slot_no] - transformer.teacache_skipped_steps = cache_teacache_skipped_steps[cache_slot_no] + skip_steps_cache.previous_modulated_input[:, img_ref_len:] = cache_all_previous_modulated_input.reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072) + skip_steps_cache.should_calc = cache_should_calc[cache_slot_no] + skip_steps_cache.accumulated_rel_l1_distance = cache_accumulated_rel_l1_distance[cache_slot_no] + skip_steps_cache.teacache_skipped_steps = cache_teacache_skipped_steps[cache_slot_no] if self.do_classifier_free_guidance: @@ -1304,21 +1309,21 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline): pred_latents[:, :, p] += latents[:, :, iii] counter[:, :, p] += 1 - if transformer.enable_cache == "tea" and cache_size > 1: + if cache_type == "tea" and cache_size > 1: for l in range(latent_items): - if transformer.previous_residual[l] != None: - bsz = transformer.previous_residual[l].shape[0] + if skip_steps_cache.previous_residual[l] != None: + bsz = skip_steps_cache.previous_residual[l].shape[0] if cache_all_previous_residual[l] == None: - cache_all_previous_residual[l] = torch.zeros((bsz, img_all_len, 3072 ), device=transformer.previous_residual[l].device, dtype=transformer.previous_residual[l].dtype) - cache_all_previous_residual[l].reshape(bsz, -1, embedded_hw)[:, idx_list] = transformer.previous_residual[l][:, img_ref_len:].reshape(bsz, -1, embedded_hw) + cache_all_previous_residual[l] = torch.zeros((bsz, img_all_len, 3072 ), device=skip_steps_cache.previous_residual[l].device, 
dtype=skip_steps_cache.previous_residual[l].dtype) + cache_all_previous_residual[l].reshape(bsz, -1, embedded_hw)[:, idx_list] = skip_steps_cache.previous_residual[l][:, img_ref_len:].reshape(bsz, -1, embedded_hw) - if transformer.previous_modulated_input != None: + if skip_steps_cache.previous_modulated_input != None: if cache_all_previous_modulated_input == None: - cache_all_previous_modulated_input = torch.zeros((1, img_all_len, 3072 ), device=transformer.previous_modulated_input.device, dtype=transformer.previous_modulated_input.dtype) - cache_all_previous_modulated_input.reshape(1, -1, embedded_hw)[:, idx_list] = transformer.previous_modulated_input[:, img_ref_len:].reshape(1, -1, embedded_hw) - cache_should_calc[cache_slot_no] = transformer.should_calc - cache_accumulated_rel_l1_distance[cache_slot_no] = transformer.accumulated_rel_l1_distance - cache_teacache_skipped_steps[cache_slot_no] = transformer.teacache_skipped_steps + cache_all_previous_modulated_input = torch.zeros((1, img_all_len, 3072 ), device=skip_steps_cache.previous_modulated_input.device, dtype=skip_steps_cache.previous_modulated_input.dtype) + cache_all_previous_modulated_input.reshape(1, -1, embedded_hw)[:, idx_list] = skip_steps_cache.previous_modulated_input[:, img_ref_len:].reshape(1, -1, embedded_hw) + cache_should_calc[cache_slot_no] = skip_steps_cache.should_calc + cache_accumulated_rel_l1_distance[cache_slot_no] = skip_steps_cache.accumulated_rel_l1_distance + cache_teacache_skipped_steps[cache_slot_no] = skip_steps_cache.teacache_skipped_steps cache_slot_no += 1 diff --git a/hyvideo/diffusion/schedulers/__init__.py b/models/hyvideo/diffusion/schedulers/__init__.py similarity index 100% rename from hyvideo/diffusion/schedulers/__init__.py rename to models/hyvideo/diffusion/schedulers/__init__.py diff --git a/hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py b/models/hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py similarity index 100% rename from hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py rename to models/hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py diff --git a/hyvideo/hunyuan.py b/models/hyvideo/hunyuan.py similarity index 79% rename from hyvideo/hunyuan.py rename to models/hyvideo/hunyuan.py index 380ec77..a38a7bd 100644 --- a/hyvideo/hunyuan.py +++ b/models/hyvideo/hunyuan.py @@ -8,24 +8,24 @@ from pathlib import Path from einops import rearrange import torch import torch.distributed as dist -from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE, NEGATIVE_PROMPT_I2V -from hyvideo.vae import load_vae -from hyvideo.modules import load_model -from hyvideo.text_encoder import TextEncoder -from hyvideo.utils.data_utils import align_to, get_closest_ratio, generate_crop_size_list -from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed, get_nd_rotary_pos_embed_new -from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler -from hyvideo.diffusion.pipelines import HunyuanVideoPipeline -from hyvideo.diffusion.pipelines import HunyuanVideoAudioPipeline +from .constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE, NEGATIVE_PROMPT_I2V +from .vae import load_vae +from .modules import load_model +from .text_encoder import TextEncoder +from .utils.data_utils import align_to, get_closest_ratio, generate_crop_size_list +from .modules.posemb_layers import get_nd_rotary_pos_embed, get_nd_rotary_pos_embed_new +from .diffusion.schedulers import FlowMatchDiscreteScheduler +from .diffusion.pipelines 
import HunyuanVideoPipeline +from .diffusion.pipelines import HunyuanVideoAudioPipeline from PIL import Image import numpy as np import torchvision.transforms as transforms import cv2 -from wan.utils.utils import calculate_new_dimensions, convert_tensor_to_image -from hyvideo.data_kits.audio_preprocessor import encode_audio, get_facemask +from shared.utils.utils import calculate_new_dimensions, convert_tensor_to_image +from .data_kits.audio_preprocessor import encode_audio, get_facemask from transformers import WhisperModel from transformers import AutoFeatureExtractor -from hyvideo.data_kits.face_align import AlignImage +from .data_kits.face_align import AlignImage import librosa def get_audio_feature(feature_extractor, audio_path, duration): @@ -66,174 +66,174 @@ def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1): -def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): - num_images, num_image_patches, embed_dim = image_features.shape - batch_size, sequence_length = input_ids.shape - left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) - # 1. Create a mask to know where special image tokens are - special_image_token_mask = input_ids == self.config.image_token_index - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) - # Compute the maximum embed dimension - max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length - batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) +# def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): +# num_images, num_image_patches, embed_dim = image_features.shape +# batch_size, sequence_length = input_ids.shape +# left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) +# # 1. Create a mask to know where special image tokens are +# special_image_token_mask = input_ids == self.config.image_token_index +# num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) +# # Compute the maximum embed dimension +# max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length +# batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) - # 2. Compute the positions where text should be written - # Calculate new positions for text tokens in merged image-text sequence. - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. - # `torch.cumsum` computes how each image token shifts subsequent text token positions. - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. - new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 - nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] - if left_padding: - new_token_positions += nb_image_pad[:, None] # offset for left padding - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] +# # 2. Compute the positions where text should be written +# # Calculate new positions for text tokens in merged image-text sequence. +# # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. +# # `torch.cumsum` computes how each image token shifts subsequent text token positions. 
+# # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. +# new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 +# nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] +# if left_padding: +# new_token_positions += nb_image_pad[:, None] # offset for left padding +# text_to_overwrite = new_token_positions[batch_indices, non_image_indices] - # 3. Create the full embedding, already padded to the maximum position - final_embedding = torch.zeros( - batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) - final_attention_mask = torch.zeros( - batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device - ) - if labels is not None: - final_labels = torch.full( - (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device - ) - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually - # set the corresponding tensors into their correct target device. - target_device = inputs_embeds.device - batch_indices, non_image_indices, text_to_overwrite = ( - batch_indices.to(target_device), - non_image_indices.to(target_device), - text_to_overwrite.to(target_device), - ) - attention_mask = attention_mask.to(target_device) +# # 3. Create the full embedding, already padded to the maximum position +# final_embedding = torch.zeros( +# batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device +# ) +# final_attention_mask = torch.zeros( +# batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device +# ) +# if labels is not None: +# final_labels = torch.full( +# (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device +# ) +# # In case the Vision model or the Language model has been offloaded to CPU, we need to manually +# # set the corresponding tensors into their correct target device. +# target_device = inputs_embeds.device +# batch_indices, non_image_indices, text_to_overwrite = ( +# batch_indices.to(target_device), +# non_image_indices.to(target_device), +# text_to_overwrite.to(target_device), +# ) +# attention_mask = attention_mask.to(target_device) - # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] - if labels is not None: - final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] +# # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] +# # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features +# final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] +# final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] +# if labels is not None: +# final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] - # 5. Fill the embeddings corresponding to the images. 
Anything that is not `text_positions` needs filling (#29835) - image_to_overwrite = torch.full( - (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device - ) - image_to_overwrite[batch_indices, text_to_overwrite] = False - image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) +# # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) +# image_to_overwrite = torch.full( +# (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device +# ) +# image_to_overwrite[batch_indices, text_to_overwrite] = False +# image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) - if image_to_overwrite.sum() != image_features.shape[:-1].numel(): - raise ValueError( - f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" - f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." - ) +# if image_to_overwrite.sum() != image_features.shape[:-1].numel(): +# raise ValueError( +# f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" +# f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." +# ) - final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) - final_attention_mask |= image_to_overwrite - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) +# final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) +# final_attention_mask |= image_to_overwrite +# position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) - # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. - batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) - indices_to_mask = new_token_positions[batch_indices, pad_indices] +# # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. 
+# batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) +# indices_to_mask = new_token_positions[batch_indices, pad_indices] - final_embedding[batch_indices, indices_to_mask] = 0 +# final_embedding[batch_indices, indices_to_mask] = 0 - if labels is None: - final_labels = None +# if labels is None: +# final_labels = None - return final_embedding, final_attention_mask, final_labels, position_ids +# return final_embedding, final_attention_mask, final_labels, position_ids -def patched_llava_forward( - self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - vision_feature_layer: Optional[int] = None, - vision_feature_select_strategy: Optional[str] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, -): - from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast +# def patched_llava_forward( +# self, +# input_ids: torch.LongTensor = None, +# pixel_values: torch.FloatTensor = None, +# attention_mask: Optional[torch.Tensor] = None, +# position_ids: Optional[torch.LongTensor] = None, +# past_key_values: Optional[List[torch.FloatTensor]] = None, +# inputs_embeds: Optional[torch.FloatTensor] = None, +# vision_feature_layer: Optional[int] = None, +# vision_feature_select_strategy: Optional[str] = None, +# labels: Optional[torch.LongTensor] = None, +# use_cache: Optional[bool] = None, +# output_attentions: Optional[bool] = None, +# output_hidden_states: Optional[bool] = None, +# return_dict: Optional[bool] = None, +# cache_position: Optional[torch.LongTensor] = None, +# num_logits_to_keep: int = 0, +# ): +# from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vision_feature_layer = ( - vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer - ) - vision_feature_select_strategy = ( - vision_feature_select_strategy - if vision_feature_select_strategy is not None - else self.config.vision_feature_select_strategy - ) +# output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions +# output_hidden_states = ( +# output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states +# ) +# return_dict = return_dict if return_dict is not None else self.config.use_return_dict +# vision_feature_layer = ( +# vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer +# ) +# vision_feature_select_strategy = ( +# vision_feature_select_strategy +# if vision_feature_select_strategy is not None +# else self.config.vision_feature_select_strategy +# ) - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") +# if (input_ids 
is None) ^ (inputs_embeds is not None): +# raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) +# if pixel_values is not None and inputs_embeds is not None: +# raise ValueError( +# "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" +# ) - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) +# if inputs_embeds is None: +# inputs_embeds = self.get_input_embeddings()(input_ids) - image_features = None - if pixel_values is not None: - image_features = self.get_image_features( - pixel_values=pixel_values, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - ) +# image_features = None +# if pixel_values is not None: +# image_features = self.get_image_features( +# pixel_values=pixel_values, +# vision_feature_layer=vision_feature_layer, +# vision_feature_select_strategy=vision_feature_select_strategy, +# ) - inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, attention_mask, labels - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) +# inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( +# image_features, inputs_embeds, input_ids, attention_mask, labels +# ) +# cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - outputs = self.language_model( - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, - ) +# outputs = self.language_model( +# attention_mask=attention_mask, +# position_ids=position_ids, +# past_key_values=past_key_values, +# inputs_embeds=inputs_embeds, +# use_cache=use_cache, +# output_attentions=output_attentions, +# output_hidden_states=output_hidden_states, +# return_dict=return_dict, +# cache_position=cache_position, +# num_logits_to_keep=num_logits_to_keep, +# ) - logits = outputs[0] +# logits = outputs[0] - loss = None +# loss = None - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output +# if not return_dict: +# output = (logits,) + outputs[1:] +# return (loss,) + output if loss is not None else output - return LlavaCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) +# return LlavaCausalLMOutputWithPast( +# loss=loss, +# logits=logits, +# past_key_values=outputs.past_key_values, +# hidden_states=outputs.hidden_states, +# attentions=outputs.attentions, +# image_hidden_states=image_features if pixel_values is not None else None, +# ) def adapt_model(model, audio_block_name): modules_dict= { k: m for k, m in model.named_modules()} @@ -320,8 +320,8 @@ class Inference(object): device = "cuda" import transformers - transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.forward = 
patched_llava_forward # force legacy behaviour to be able to use tansformers v>(4.47) - transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features = _merge_input_ids_with_image_features + # transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.forward = patched_llava_forward # force legacy behaviour to be able to use tansformers v>(4.47) + # transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features = _merge_input_ids_with_image_features torch.set_grad_enabled(False) text_len = 512 @@ -778,7 +778,7 @@ class HunyuanVideoSampler(Inference): raise ValueError( f"Seed must be an integer, a list of integers, or None, got {seed}." ) - from wan.utils.utils import seed_everything + from shared.utils.utils import seed_everything seed_everything(seed) generator = [torch.Generator("cuda").manual_seed(seed) for seed in seeds] # generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds] @@ -956,7 +956,7 @@ class HunyuanVideoSampler(Inference): # out_latents= ref_latents / self.vae.config.scaling_factor # image = self.vae.decode(out_latents, return_dict=False, generator=generator)[0] # image = image.clamp(-1, 1) - # from wan.utils.utils import cache_video + # from shared.utils.utils import cache_video # cache_video( tensor=image, save_file="decode.mp4", fps=25, nrow=1, normalize=True, value_range=(-1, 1)) motion_pose = np.array([25] * 4) @@ -1040,5 +1040,4 @@ class HunyuanVideoSampler(Inference): return samples -def query_model_def(model_type, model_def): - return None \ No newline at end of file + diff --git a/models/hyvideo/hunyuan_handler.py b/models/hyvideo/hunyuan_handler.py new file mode 100644 index 0000000..dc9ba96 --- /dev/null +++ b/models/hyvideo/hunyuan_handler.py @@ -0,0 +1,167 @@ +import torch + +def get_hunyuan_text_encoder_filename(text_encoder_quantization): + if text_encoder_quantization =="int8": + text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_quanto_int8.safetensors" + else: + text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_fp16.safetensors" + + return text_encoder_filename + +class family_handler(): + + @staticmethod + def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_steps_cache): + resolution = inputs["resolution"] + width, height = resolution.split("x") + pixels = int(width) * int(height) + + if cache_type == "mag": + skip_steps_cache.update({ + "magcache_thresh" : 0, + "magcache_K" : 2, + }) + if pixels >= 1280* 720: + skip_steps_cache.def_mag_ratios = [1.0754, 1.27807, 1.11596, 1.09504, 1.05188, 1.00844, 1.05779, 1.00657, 1.04142, 1.03101, 1.00679, 1.02556, 1.00908, 1.06949, 1.05438, 1.02214, 1.02321, 1.03019, 1.00779, 1.03381, 1.01886, 1.01161, 1.02968, 1.00544, 1.02822, 1.00689, 1.02119, 1.0105, 1.01044, 1.01572, 1.02972, 1.0094, 1.02368, 1.0226, 0.98965, 1.01588, 1.02146, 1.0018, 1.01687, 0.99436, 1.00283, 1.01139, 0.97122, 0.98251, 0.94513, 0.97656, 0.90943, 0.85703, 0.75456] + else: + skip_steps_cache.def_mag_ratios = [1.06971, 1.29073, 1.11245, 1.09596, 1.05233, 1.01415, 1.05672, 1.00848, 1.03632, 1.02974, 1.00984, 1.03028, 1.00681, 1.06614, 1.05022, 1.02592, 1.01776, 1.02985, 1.00726, 1.03727, 1.01502, 1.00992, 1.03371, 0.9976, 1.02742, 1.0093, 1.01869, 1.00815, 1.01461, 1.01152, 1.03082, 1.0061, 1.02162, 1.01999, 0.99063, 1.01186, 1.0217, 0.99947, 1.01711, 0.9904, 1.00258, 1.00878, 0.97039, 0.97686, 0.94315, 0.97728, 0.91154, 0.86139, 0.76592] + else: + 
skip_steps_cache.coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02] + + @staticmethod + def query_model_def(base_model_type, model_def): + extra_model_def = {} + + if base_model_type in ["hunyuan_avatar", "hunyuan_custom_audio"]: + fps = 25 + elif base_model_type in ["hunyuan", "hunyuan_i2v", "hunyuan_custom_edit", "hunyuan_custom"]: + fps = 24 + else: + fps = 16 + extra_model_def["fps"] = fps + extra_model_def["frames_minimum"] = 5 + extra_model_def["frames_steps"] = 4 + extra_model_def["sliding_window"] = False + extra_model_def["embedded_guidance"] = base_model_type in ["hunyuan", "hunyuan_i2v"] + extra_model_def["cfg_star"] = base_model_type in [ "hunyuan_avatar", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_custom"] + extra_model_def["tea_cache"] = True + extra_model_def["mag_cache"] = True + return extra_model_def + + @staticmethod + def query_supported_types(): + return ["hunyuan", "hunyuan_i2v", "hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_avatar"] + + @staticmethod + def query_family_maps(): + models_eqv_map = { + } + + models_comp_map = { + "hunyuan_custom": ["hunyuan_custom_edit", "hunyuan_custom_audio"], + } + + return models_eqv_map, models_comp_map + + @staticmethod + def query_model_family(): + return "hunyuan" + + @staticmethod + def query_family_infos(): + return {"hunyuan":(20, "Hunyuan Video")} + + @staticmethod + def get_rgb_factors(base_model_type ): + from shared.RGB_factors import get_rgb_factors + latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("hunyuan") + return latent_rgb_factors, latent_rgb_factors_bias + + @staticmethod + def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization): + text_encoder_filename = get_hunyuan_text_encoder_filename(text_encoder_quantization) + return { + "repoId" : "DeepBeepMeep/HunyuanVideo", + "sourceFolderList" : [ "llava-llama-3-8b", "clip_vit_large_patch14", "whisper-tiny" , "det_align", "" ], + "fileList" :[ ["config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "preprocessor_config.json"] + computeList(text_encoder_filename) , + ["config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"], + ["config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json"], + ["detface.pt"], + [ "hunyuan_video_720_quanto_int8_map.json", "hunyuan_video_custom_VAE_fp32.safetensors", "hunyuan_video_custom_VAE_config.json", "hunyuan_video_VAE_fp32.safetensors", "hunyuan_video_VAE_config.json" , "hunyuan_video_720_quanto_int8_map.json" ] + computeList(model_filename) + ] + } + + @staticmethod + def load_model(model_filename, model_type = None, base_model_type = None, model_def = None, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + from .hunyuan import HunyuanVideoSampler + from mmgp import offload + + hunyuan_model = HunyuanVideoSampler.from_pretrained( + model_filepath = model_filename, + model_type = model_type, + base_model_type = base_model_type, + text_encoder_filepath = get_hunyuan_text_encoder_filename(text_encoder_quantization), + dtype = dtype, + quantizeTransformer = quantizeTransformer, + VAE_dtype = VAE_dtype, + mixed_precision_transformer = mixed_precision_transformer, + save_quantized = 
save_quantized + ) + + pipe = { "transformer" : hunyuan_model.model, "text_encoder" : hunyuan_model.text_encoder, "text_encoder_2" : hunyuan_model.text_encoder_2, "vae" : hunyuan_model.vae } + + if hunyuan_model.wav2vec != None: + pipe["wav2vec"] = hunyuan_model.wav2vec + + + # if hunyuan_model.align_instance != None: + # pipe["align_instance"] = hunyuan_model.align_instance.facedet.model + + + from .modules.models import get_linear_split_map + + split_linear_modules_map = get_linear_split_map() + hunyuan_model.model.split_linear_modules_map = split_linear_modules_map + offload.split_linear_modules(hunyuan_model.model, split_linear_modules_map ) + + + return hunyuan_model, pipe + + @staticmethod + def update_default_settings(base_model_type, model_def, ui_defaults): + ui_defaults["embedded_guidance_scale"]= 6.0 + + if base_model_type in ["hunyuan","hunyuan_i2v"]: + ui_defaults.update({ + "guidance_scale": 7.0, + }) + + elif base_model_type in ["hunyuan_custom"]: + ui_defaults.update({ + "guidance_scale": 7.5, + "flow_shift": 13, + "resolution": "1280x720", + "video_prompt_type": "I", + }) + elif base_model_type in ["hunyuan_custom_audio"]: + ui_defaults.update({ + "guidance_scale": 7.5, + "flow_shift": 13, + "video_prompt_type": "I", + }) + elif base_model_type in ["hunyuan_custom_edit"]: + ui_defaults.update({ + "guidance_scale": 7.5, + "flow_shift": 13, + "video_prompt_type": "MVAI", + "sliding_window_size": 129, + }) + elif base_model_type in ["hunyuan_avatar"]: + ui_defaults.update({ + "guidance_scale": 7.5, + "flow_shift": 5, + "remove_background_images_ref": 0, + "skip_steps_start_step_perc": 25, + "video_length": 129, + "video_prompt_type": "I", + }) diff --git a/hyvideo/modules/__init__.py b/models/hyvideo/modules/__init__.py similarity index 100% rename from hyvideo/modules/__init__.py rename to models/hyvideo/modules/__init__.py diff --git a/hyvideo/modules/activation_layers.py b/models/hyvideo/modules/activation_layers.py similarity index 100% rename from hyvideo/modules/activation_layers.py rename to models/hyvideo/modules/activation_layers.py diff --git a/hyvideo/modules/attenion.py b/models/hyvideo/modules/attenion.py similarity index 100% rename from hyvideo/modules/attenion.py rename to models/hyvideo/modules/attenion.py diff --git a/hyvideo/modules/audio_adapters.py b/models/hyvideo/modules/audio_adapters.py similarity index 100% rename from hyvideo/modules/audio_adapters.py rename to models/hyvideo/modules/audio_adapters.py diff --git a/hyvideo/modules/embed_layers.py b/models/hyvideo/modules/embed_layers.py similarity index 100% rename from hyvideo/modules/embed_layers.py rename to models/hyvideo/modules/embed_layers.py diff --git a/hyvideo/modules/mlp_layers.py b/models/hyvideo/modules/mlp_layers.py similarity index 100% rename from hyvideo/modules/mlp_layers.py rename to models/hyvideo/modules/mlp_layers.py diff --git a/hyvideo/modules/models.py b/models/hyvideo/modules/models.py similarity index 93% rename from hyvideo/modules/models.py rename to models/hyvideo/modules/models.py index 48978a9..4cdce4a 100644 --- a/hyvideo/modules/models.py +++ b/models/hyvideo/modules/models.py @@ -18,7 +18,7 @@ from .modulate_layers import ModulateDiT, modulate, modulate_ , apply_gate, appl from .token_refiner import SingleTokenRefiner import numpy as np from mmgp import offload -from wan.modules.attention import pay_attention +from shared.attention import pay_attention from .audio_adapters import AudioProjNet2, PerceiverAttentionCA def get_linear_split_map(): @@ -794,6 +794,8 @@ 
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): block.disable_deterministic() def compute_magcache_threshold(self, start_step, num_inference_steps = 0, speed_factor =0): + skips_step_cache = self.cache + def nearest_interp(src_array, target_length): src_length = len(src_array) if target_length == 1: @@ -801,11 +803,11 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): scale = (src_length - 1) / (target_length - 1) mapped_indices = np.round(np.arange(target_length) * scale).astype(int) return src_array[mapped_indices] - - if len(self.def_mag_ratios) != num_inference_steps: - self.mag_ratios = nearest_interp(self.def_mag_ratios, num_inference_steps) + def_mag_ratios = np.array([1.0]+ skips_step_cache.def_mag_ratios) + if len(def_mag_ratios) != num_inference_steps: + skips_step_cache.mag_ratios = nearest_interp(def_mag_ratios, num_inference_steps) else: - self.mag_ratios = self.def_mag_ratios + skips_step_cache.mag_ratios = def_mag_ratios best_deltas = None best_threshold = 0.01 @@ -821,12 +823,12 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): if i<=start_step: skip = False else: - cur_mag_ratio = self.mag_ratios[i] # conditional and unconditional in one list + cur_mag_ratio = skips_step_cache.mag_ratios[i] # conditional and unconditional in one list accumulated_ratio *= cur_mag_ratio # magnitude ratio between current step and the cached step accumulated_steps += 1 # skip steps plus 1 cur_skip_err = np.abs(1-accumulated_ratio) # skip error of current steps accumulated_err += cur_skip_err # accumulated error of multiple steps - if accumulated_err best_diff: break threshold += 0.01 - self.magcache_thresh = best_threshold + skips_step_cache.magcache_thresh = best_threshold print(f"Mag Cache, best threshold found:{best_threshold:0.2f} with gain x{num_inference_steps/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}") return best_threshold @@ -969,23 +971,24 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): attn_mask = None freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None - - - if self.enable_cache: + should_calc = True + skip_steps_cache = self.cache + if skip_steps_cache is not None: + cache_type = skip_steps_cache.cache_type if x_id == 0: - self.should_calc = True - if self.enable_cache == "mag": - if step_no > self.cache_start_step: - cur_mag_ratio = self.mag_ratios[step_no] - self.accumulated_ratio = self.accumulated_ratio*cur_mag_ratio - cur_skip_err = np.abs(1-self.accumulated_ratio) - self.accumulated_err += cur_skip_err - self.accumulated_steps += 1 - if self.accumulated_err<=self.magcache_thresh and self.accumulated_steps<=self.magcache_K: - self.should_calc = False - self.cache_skipped_steps += 1 + skip_steps_cache.should_calc = True + if cache_type == "mag": + if step_no > skip_steps_cache.start_step: + cur_mag_ratio = skip_steps_cache.mag_ratios[step_no] + skip_steps_cache.accumulated_ratio = skip_steps_cache.accumulated_ratio*cur_mag_ratio + cur_skip_err = np.abs(1-skip_steps_cache.accumulated_ratio) + skip_steps_cache.accumulated_err += cur_skip_err + skip_steps_cache.accumulated_steps += 1 + if skip_steps_cache.accumulated_err<=skip_steps_cache.magcache_thresh and skip_steps_cache.accumulated_steps<=skip_steps_cache.magcache_K: + skip_steps_cache.should_calc = False + skip_steps_cache.skipped_steps += 1 else: - self.accumulated_ratio, self.accumulated_steps, self.accumulated_err = 1.0, 0, 0 + skip_steps_cache.accumulated_ratio, skip_steps_cache.accumulated_steps, 
skip_steps_cache.accumulated_err = 1.0, 0, 0 else: inp = img[0:1] vec_ = vec[0:1] @@ -994,26 +997,24 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): normed_inp = normed_inp.to(torch.bfloat16) modulated_inp = modulate( normed_inp, shift=img_mod1_shift, scale=img_mod1_scale ) del normed_inp, img_mod1_shift, img_mod1_scale - if step_no <= self.cache_start_step or step_no == self.num_steps-1: - self.accumulated_rel_l1_distance = 0 - else: - coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02] - rescale_func = np.poly1d(coefficients) - self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) - if self.accumulated_rel_l1_distance < self.rel_l1_thresh: - self.should_calc = False - self.cache_skipped_steps += 1 + if step_no <= skip_steps_cache.start_step or step_no == skip_steps_cache.num_steps-1: + skip_steps_cache.accumulated_rel_l1_distance = 0 + else: + rescale_func = np.poly1d(skip_steps_cache.coefficients) + skip_steps_cache.accumulated_rel_l1_distance += rescale_func(((modulated_inp-skip_steps_cache.previous_modulated_input).abs().mean() / skip_steps_cache.previous_modulated_input.abs().mean()).cpu().item()) + if skip_steps_cache.accumulated_rel_l1_distance < skip_steps_cache.rel_l1_thresh: + skip_steps_cache.should_calc = False + skip_steps_cache.skipped_steps += 1 else: - self.accumulated_rel_l1_distance = 0 - self.previous_modulated_input = modulated_inp - else: - self.should_calc = True + skip_steps_cache.accumulated_rel_l1_distance = 0 + skip_steps_cache.previous_modulated_input = modulated_inp + should_calc = skip_steps_cache.should_calc - if not self.should_calc: - img += self.previous_residual[x_id] + if not should_calc: + img += skip_steps_cache.previous_residual[x_id] else: - if self.enable_cache: - self.previous_residual[x_id] = None + if skip_steps_cache is not None: + skip_steps_cache.previous_residual[x_id] = None ori_img = img[0:1].clone() # --------------------- Pass through DiT blocks ------------------------ for layer_num, block in enumerate(self.double_blocks): @@ -1076,10 +1077,10 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): single_block_args = None # img = x[:, :img_seq_len, ...] - if self.enable_cache: + if skip_steps_cache is not None: if len(img) > 1: - self.previous_residual[0] = torch.empty_like(img) - for i, (x, residual) in enumerate(zip(img, self.previous_residual[0])): + skip_steps_cache.previous_residual[0] = torch.empty_like(img) + for i, (x, residual) in enumerate(zip(img, skip_steps_cache.previous_residual[0])): if i < len(img) - 1: residual[...] 
= torch.sub(x, ori_img) else: @@ -1087,8 +1088,8 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): torch.sub(x, ori_img, out=residual) x = None else: - self.previous_residual[x_id] = ori_img - torch.sub(img, ori_img, out=self.previous_residual[x_id]) + skip_steps_cache.previous_residual[x_id] = ori_img + torch.sub(img, ori_img, out=skip_steps_cache.previous_residual[x_id]) if ref_length != None: diff --git a/hyvideo/modules/modulate_layers.py b/models/hyvideo/modules/modulate_layers.py similarity index 100% rename from hyvideo/modules/modulate_layers.py rename to models/hyvideo/modules/modulate_layers.py diff --git a/hyvideo/modules/norm_layers.py b/models/hyvideo/modules/norm_layers.py similarity index 100% rename from hyvideo/modules/norm_layers.py rename to models/hyvideo/modules/norm_layers.py diff --git a/hyvideo/modules/original models.py b/models/hyvideo/modules/original models.py similarity index 100% rename from hyvideo/modules/original models.py rename to models/hyvideo/modules/original models.py diff --git a/hyvideo/modules/placement.py b/models/hyvideo/modules/placement.py similarity index 100% rename from hyvideo/modules/placement.py rename to models/hyvideo/modules/placement.py diff --git a/hyvideo/modules/posemb_layers.py b/models/hyvideo/modules/posemb_layers.py similarity index 100% rename from hyvideo/modules/posemb_layers.py rename to models/hyvideo/modules/posemb_layers.py diff --git a/hyvideo/modules/token_refiner.py b/models/hyvideo/modules/token_refiner.py similarity index 100% rename from hyvideo/modules/token_refiner.py rename to models/hyvideo/modules/token_refiner.py diff --git a/hyvideo/modules/utils.py b/models/hyvideo/modules/utils.py similarity index 100% rename from hyvideo/modules/utils.py rename to models/hyvideo/modules/utils.py diff --git a/hyvideo/prompt_rewrite.py b/models/hyvideo/prompt_rewrite.py similarity index 100% rename from hyvideo/prompt_rewrite.py rename to models/hyvideo/prompt_rewrite.py diff --git a/hyvideo/text_encoder/__init__.py b/models/hyvideo/text_encoder/__init__.py similarity index 97% rename from hyvideo/text_encoder/__init__.py rename to models/hyvideo/text_encoder/__init__.py index 1376718..9bd47d4 100644 --- a/hyvideo/text_encoder/__init__.py +++ b/models/hyvideo/text_encoder/__init__.py @@ -15,6 +15,7 @@ from transformers.utils import ModelOutput from ..constants import TEXT_ENCODER_PATH, TOKENIZER_PATH from ..constants import PRECISION_TO_TYPE +from .llava.modeling_llava import LlavaForConditionalGeneration def use_default(value, default): @@ -188,10 +189,16 @@ class TextEncoder(nn.Module): if "llm" in text_encoder_type: from mmgp import offload - forcedConfigPath= None if "i2v" in text_encoder_type else "ckpts/llava-llama-3-8b/config.json" - self.model= offload.fast_load_transformers_model(self.model_path, modelPrefix="language_model" if forcedConfigPath != None else None, forcedConfigPath=forcedConfigPath) - if forcedConfigPath != None: + # forcedConfigPath= None if "i2v" in text_encoder_type else "ckpts/llava-llama-3-8b/config.json" + # self.model= offload.fast_load_transformers_model(self.model_path, modelPrefix="language_model" if forcedConfigPath != None else None, forcedConfigPath=forcedConfigPath) + + if "i2v" in text_encoder_type: + self.model= offload.fast_load_transformers_model(self.model_path, modelClass= LlavaForConditionalGeneration) + else: + self.model= offload.fast_load_transformers_model(self.model_path, modelPrefix="language_model", forcedConfigPath = "ckpts/llava-llama-3-8b/config.json") 
self.model.final_layer_norm = self.model.model.norm + + else: self.model, self.model_path = load_text_encoder( diff --git a/models/hyvideo/text_encoder/llava/__init__.py b/models/hyvideo/text_encoder/llava/__init__.py new file mode 100644 index 0000000..e6d2f52 --- /dev/null +++ b/models/hyvideo/text_encoder/llava/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from typing import TYPE_CHECKING + +# from ...utils import _LazyModule +# from ...utils.import_utils import define_import_structure + + +# if TYPE_CHECKING: +# from .configuration_llava import * +# from .image_processing_llava_fast import * +# from .modeling_llava import * +# from .processing_llava import * +# else: +# import sys + +# _file = globals()["__file__"] + # sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/models/hyvideo/text_encoder/llava/configuration_llava.py b/models/hyvideo/text_encoder/llava/configuration_llava.py new file mode 100644 index 0000000..9c30798 --- /dev/null +++ b/models/hyvideo/text_encoder/llava/configuration_llava.py @@ -0,0 +1,137 @@ +# coding=utf-8 +# Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Llava model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging +from transformers.models.auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +class LlavaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate an + Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Llava-9B. + + e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`): + The config object or dictionary of the vision backbone. 
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): + The config object or dictionary of the text backbone. + image_token_index (`int`, *optional*, defaults to 32000): + The image token index to encode the image prompt. + projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function used by the multimodal projector. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. + vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -2): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + image_seq_length (`int`, *optional*, defaults to 576): + Sequence length of one image embedding. + multimodal_projector_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the multimodal projector. + + Example: + + ```python + >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig + + >>> # Initializing a CLIP-vision config + >>> vision_config = CLIPVisionConfig() + + >>> # Initializing a Llama config + >>> text_config = LlamaConfig() + + >>> # Initializing a Llava llava-1.5-7b style configuration + >>> configuration = LlavaConfig(vision_config, text_config) + + >>> # Initializing a model from the llava-1.5-7b style configuration + >>> model = LlavaForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "llava" + sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} + is_composition = True + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=32000, + projector_hidden_act="gelu", + vision_feature_select_strategy="default", + vision_feature_layer=-2, + image_seq_length=576, + multimodal_projector_bias=True, + **kwargs, + ): + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + self.image_seq_length = image_seq_length + + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError( + "vision_feature_select_strategy should be one of 'default', 'full'." 
+ f"Got: {vision_feature_select_strategy}" + ) + + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + + if isinstance(vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" + ) + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=336, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + + self.vision_config = vision_config + + if isinstance(text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["llama"]() + + self.text_config = text_config + self.multimodal_projector_bias = multimodal_projector_bias + + super().__init__(**kwargs) + + +__all__ = ["LlavaConfig"] diff --git a/models/hyvideo/text_encoder/llava/image_processing_llava.py b/models/hyvideo/text_encoder/llava/image_processing_llava.py new file mode 100644 index 0000000..37ef079 --- /dev/null +++ b/models/hyvideo/text_encoder/llava/image_processing_llava.py @@ -0,0 +1,436 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for LLaVa.""" + +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_kwargs, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_vision_available, logging + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + + +class LlavaImageProcessor(BaseImageProcessor): + r""" + Constructs a LLaVa image processor. + + Args: + do_pad (`bool`, *optional*, defaults to `False`): + Whether to pad the image to a square based on the longest edge. + The padding value is determined by the `image_mean` parameter. + Can be overridden by `do_pad` in the `preprocess` method. + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. 
The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`Dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. 
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_pad: bool = False, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_center_crop: bool = True, + crop_size: Dict[str, int] = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 224} + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} + crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + + self.do_pad = do_pad + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + self._valid_processor_keys = [ + "images", + "do_pad", + "do_resize", + "size", + "resample", + "do_center_crop", + "crop_size", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "do_convert_rgb", + "return_tensors", + "data_format", + "input_data_format", + ] + + def pad_to_square( + self, + image: np.ndarray, + background_color: Union[int, Tuple[int, int, int]] = 0, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.array: + """ + Pads an image to a square based on the longest edge. + + Args: + image (`np.ndarray`): + The image to pad. + background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0): + The color to use for the padding. Can be an integer for single channel or a + tuple of integers representing for multi-channel images. If passed as integer + in mutli-channel mode, it will default to `0` in subsequent channels. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + + Returns: + `np.ndarray`: The padded image. 
+ """ + height, width = get_image_size(image, input_data_format) + num_channels = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1] + + if height == width: + image = ( + to_channel_dimension_format(image, data_format, input_data_format) + if data_format is not None + else image + ) + return image + + max_dim = max(height, width) + + # Ensure background_color is the correct shape + if isinstance(background_color, int): + background_color = [background_color] + elif len(background_color) != num_channels: + raise ValueError( + f"background_color must have no more than {num_channels} elements to match the number of channels" + ) + + if input_data_format == ChannelDimension.FIRST: + result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype) + for i, color in enumerate(background_color): + result[i, :, :] = color + if width > height: + start = (max_dim - height) // 2 + result[:, start : start + height, :] = image + else: + start = (max_dim - width) // 2 + result[:, :, start : start + width] = image + else: + result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype) + for i, color in enumerate(background_color): + result[:, :, i] = color + if width > height: + start = (max_dim - height) // 2 + result[start : start + height, :, :] = image + else: + start = (max_dim - width) // 2 + result[:, start : start + width, :] = image + + image = ( + to_channel_dimension_format(result, data_format, input_data_format) if data_format is not None else result + ) + return image + + # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. 
+ """ + default_to_square = True + if "shortest_edge" in size: + size = size["shortest_edge"] + default_to_square = False + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") + + output_size = get_resize_output_image_size( + image, + size=size, + default_to_square=default_to_square, + input_data_format=input_data_format, + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + do_pad: Optional[bool] = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image to a square based on the longest edge. + The padding value is determined by the `image_mean` parameter. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. 
+ return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_pad = do_pad if do_pad is not None else self.do_pad + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + # we don't pass `do_pad` here since LLaVa uses a custom padding to a square + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." 
+ ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + processed_images = [] + for image in images: + if do_pad: + image = self.pad_to_square( + image=image, + background_color=tuple(int(x * 255) for x in self.image_mean), + input_data_format=input_data_format, + ) + + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + processed_images.append(image) + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + +__all__ = ["LlavaImageProcessor"] diff --git a/models/hyvideo/text_encoder/llava/image_processing_llava_fast.py b/models/hyvideo/text_encoder/llava/image_processing_llava_fast.py new file mode 100644 index 0000000..d85eb89 --- /dev/null +++ b/models/hyvideo/text_encoder/llava/image_processing_llava_fast.py @@ -0,0 +1,201 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for LLaVa.""" + +from typing import List, Optional, Tuple, Union + +from ...image_processing_utils import ( + BatchFeature, +) +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, + get_image_size, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + add_start_docstrings, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + is_vision_available, +) + + +if is_vision_available(): + from ...image_utils import PILImageResampling + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +class LlavaFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + do_pad: Optional[bool] + + +@add_start_docstrings( + "Constructs a fast Llava image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + """ + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image to a square based on the longest edge. 
Can be overridden by the `do_pad` parameter + """, +) +class LlavaImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + size = {"shortest_edge": 224} + default_to_square = False + crop_size = {"height": 224, "width": 224} + do_pad = False + do_resize = True + do_center_crop = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + valid_kwargs = LlavaFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[LlavaFastImageProcessorKwargs]) -> None: + super().__init__(**kwargs) + + @add_start_docstrings( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + """ + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image to a square based on the longest edge. Can be overridden by the `do_pad` parameter + """, + ) + def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaFastImageProcessorKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def pad_to_square( + self, + images: "torch.Tensor", + background_color: Union[int, Tuple[int, int, int]] = 0, + ) -> "torch.Tensor": + """ + Pads an image to a square based on the longest edge. + + Args: + images (`np.ndarray`): + The images to pad. + background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0): + The color to use for the padding. Can be an integer for single channel or a + tuple of integers representing for multi-channel images. If passed as integer + in mutli-channel mode, it will default to `0` in subsequent channels. + Returns: + `torch.Tensor`: The padded images. + """ + height, width = get_image_size(images, ChannelDimension.FIRST) + + if height == width: + return images + + num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0] + if isinstance(background_color, int): + background_color = [background_color] + [0] * (num_channels - 1) + elif len(background_color) != num_channels: + raise ValueError( + f"background_color must have no more than {num_channels} elements to match the number of channels" + ) + + max_dim = max(height, width) + paste_x_left = (max_dim - width) // 2 + paste_y_left = (max_dim - height) // 2 + paste_x_right = max_dim - width - paste_x_left + paste_y_right = max_dim - height - paste_y_left + padded_images = F.pad( + images, padding=[paste_x_left, paste_y_left, paste_x_right, paste_y_right], fill=background_color + ) + + return padded_images + + def _preprocess( + self, + images: List["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_pad: bool, + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, List[float]]], + image_std: Optional[Union[float, List[float]]], + return_tensors: Optional[Union[str, TensorType]], + ) -> BatchFeature: + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_pad: + stacked_images = self.pad_to_square( + images=stacked_images, background_color=tuple(int(x * 255) for x in self.image_mean) + ) + resized_images_grouped[shape] = stacked_images + padded_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for batched resizing + # Needed in case do_pad is False, or padding returns images with different sizes + grouped_images, 
grouped_images_index = group_images_by_shape(padded_images) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + +__all__ = ["LlavaImageProcessorFast"] diff --git a/models/hyvideo/text_encoder/llava/modeling_llava.py b/models/hyvideo/text_encoder/llava/modeling_llava.py new file mode 100644 index 0000000..f4ae058 --- /dev/null +++ b/models/hyvideo/text_encoder/llava/modeling_llava.py @@ -0,0 +1,531 @@ +# coding=utf-8 +# Copyright 2023 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Llava model.""" + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import ModelOutput +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg +from transformers.models.auto import AutoModel, AutoModelForCausalLM +from .configuration_llava import LlavaConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlavaConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "llava-hf/llava-1.5-7b-hf" + + +@dataclass +class LlavaCausalLMOutputWithPast(ModelOutput): + """ + Base class for Llava causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). 
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +class LlavaMultiModalProjector(nn.Module): + def __init__(self, config: LlavaConfig): + super().__init__() + # We have hidden_size * the number of vision feature layers + num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * num_feature_layers, + config.text_config.hidden_size, + bias=config.multimodal_projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +LLAVA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. 
+ + Parameters: + config ([`LlavaConfig`] or [`LlavaVisionConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAVA_START_DOCSTRING, +) +class LlavaPreTrainedModel(PreTrainedModel): + config_class = LlavaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlavaVisionAttention"] + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + # important: this ported version of Llava isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed - the original codebase + # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAVA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses + [`CLIPImageProcessor`] for processing images). + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. 
+""" + + +@add_start_docstrings( + """The LLAVA model which consists of a vision backbone and a language model.""", + LLAVA_START_DOCSTRING, +) +class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): + def __init__(self, config: LlavaConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.vocab_size = config.text_config.vocab_size + self.language_model = AutoModelForCausalLM.from_config(config.text_config) + + if self.language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + vision_feature_layer: Union[int, List[int]], + vision_feature_select_strategy: str, + **kwargs, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + vision_feature_layer (`Union[int, List[int]]`): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") + + kwargs = {k: v for k, v in kwargs.items() if v is not None} + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states. 
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) + + # If we have one vision feature layer, return the corresponding hidden states, + # otherwise, select the hidden states of each feature layer and concatenate them + if isinstance(vision_feature_layer, int): + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + else: + hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer] + # For default; crop CLS from each hidden state in the hidden state pool + if vision_feature_select_strategy == "default": + hs_pool = [hs[:, 1:] for hs in hs_pool] + selected_image_feature = torch.cat(hs_pool, dim=-1) + + image_features = self.multi_modal_projector(selected_image_feature) + return image_features + + + def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): + num_images, num_image_patches, embed_dim = image_features.shape + batch_size, sequence_length = input_ids.shape + left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) + # 1. Create a mask to know where special image tokens are + special_image_token_mask = input_ids == self.config.image_token_index + num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) + # Compute the maximum embed dimension + max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length + batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) + + # 2. Compute the positions where text should be written + # Calculate new positions for text tokens in merged image-text sequence. + # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. + # `torch.cumsum` computes how each image token shifts subsequent text token positions. + # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. + new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 + nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] + if left_padding: + new_token_positions += nb_image_pad[:, None] # offset for left padding + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] + + # 3. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + final_attention_mask = torch.zeros( + batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device + ) + if labels is not None: + final_labels = torch.full( + (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device + ) + # In case the Vision model or the Language model has been offloaded to CPU, we need to manually + # set the corresponding tensors into their correct target device. + target_device = inputs_embeds.device + batch_indices, non_image_indices, text_to_overwrite = ( + batch_indices.to(target_device), + non_image_indices.to(target_device), + text_to_overwrite.to(target_device), + ) + attention_mask = attention_mask.to(target_device) + + # 4. Fill the embeddings based on the mask. 
If we have ["hey" "", "how", "are"] + # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features + final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] + final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] + if labels is not None: + final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] + + # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) + image_to_overwrite = torch.full( + (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device + ) + image_to_overwrite[batch_indices, text_to_overwrite] = False + image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) + + if image_to_overwrite.sum() != image_features.shape[:-1].numel(): + raise ValueError( + f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" + f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." + ) + + final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) + final_attention_mask |= image_to_overwrite + position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) + + # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. + batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) + indices_to_mask = new_token_positions[batch_indices, pad_indices] + + final_embedding[batch_indices, indices_to_mask] = 0 + + if labels is None: + final_labels = None + + return final_embedding, final_attention_mask, final_labels, position_ids + + # @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + # @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING) + # @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ): + from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast + + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else 
self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_features = None + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + + inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels + ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) + + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + + logits = outputs[0] + + loss = None + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return LlavaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + attention_mask=None, + cache_position=None, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + + return model_inputs + + +__all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel"] diff --git a/models/hyvideo/text_encoder/llava/processing_llava.py b/models/hyvideo/text_encoder/llava/processing_llava.py new file mode 100644 index 0000000..6253e19 --- /dev/null +++ b/models/hyvideo/text_encoder/llava/processing_llava.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Llava. +""" + +from typing import List, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, get_image_size, to_numpy_array +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class LlavaProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": {}, + } + + +class LlavaProcessor(ProcessorMixin): + r""" + Constructs a LLaVa processor which wraps a LLaVa image processor and a LLaMa tokenizer into a single processor. + + [`LlavaProcessor`] offers all the functionalities of [`LlavaImageProcessor`] and [`LlamaTokenizerFast`]. See the + [`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information. + + Args: + image_processor ([`LlavaImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`LlamaTokenizerFast`], *optional*): + The tokenizer is a required input. + patch_size (`int`, *optional*): + Patch size from the vision tower. + vision_feature_select_strategy (`str`, *optional*): + The feature selection strategy used to select the vision feature from the vision backbone. + Shoudl be same as in model's config + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + image_token (`str`, *optional*, defaults to `""`): + Special token used to denote image location. + num_additional_image_tokens (`int`, *optional*, defaults to 0): + Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other + extra tokens appended, no need to set this arg. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = [ + "chat_template", + "patch_size", + "vision_feature_select_strategy", + "image_token", + "num_additional_image_tokens", + ] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor=None, + tokenizer=None, + patch_size=None, + vision_feature_select_strategy=None, + chat_template=None, + image_token="", # set the default and let users change if they have peculiar special tokens in rare cases + num_additional_image_tokens=0, + **kwargs, + ): + self.patch_size = patch_size + self.num_additional_image_tokens = num_additional_image_tokens + self.vision_feature_select_strategy = vision_feature_select_strategy + self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token + self.image_token_id = ( + tokenizer.image_token_id + if getattr(tokenizer, "image_token_id", None) + else tokenizer.convert_tokens_to_ids(self.image_token) + ) + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[LlavaProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). 
This method forwards the `text` + and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring + of the above two methods for more information. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + if images is None and text is None: + raise ValueError("You have to specify at least one of `images` or `text`.") + + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) + + output_kwargs = self._merge_kwargs( + LlavaProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) + else: + image_inputs = {} + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. 
Please provide a string, or a list of strings") + + # try to expand inputs in processing if we have the necessary parts + prompt_strings = text + if image_inputs.get("pixel_values") is not None: + # Replace the image token with the expanded image token sequence + pixel_values = image_inputs["pixel_values"] + height, width = get_image_size(to_numpy_array(pixel_values[0])) + num_image_tokens = (height // self.patch_size) * ( + width // self.patch_size + ) + self.num_additional_image_tokens + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + + prompt_strings = [] + for sample in text: + sample = sample.replace(self.image_token, self.image_token * num_image_tokens) + prompt_strings.append(sample) + + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) + return BatchFeature(data={**text_inputs, **image_inputs}) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["LlavaProcessor"] diff --git a/hyvideo/utils/__init__.py b/models/hyvideo/utils/__init__.py similarity index 100% rename from hyvideo/utils/__init__.py rename to models/hyvideo/utils/__init__.py diff --git a/hyvideo/utils/data_utils.py b/models/hyvideo/utils/data_utils.py similarity index 100% rename from hyvideo/utils/data_utils.py rename to models/hyvideo/utils/data_utils.py diff --git a/hyvideo/utils/file_utils.py b/models/hyvideo/utils/file_utils.py similarity index 100% rename from hyvideo/utils/file_utils.py rename to models/hyvideo/utils/file_utils.py diff --git a/hyvideo/utils/helpers.py b/models/hyvideo/utils/helpers.py similarity index 100% rename from hyvideo/utils/helpers.py rename to models/hyvideo/utils/helpers.py diff --git a/hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py b/models/hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py similarity index 100% rename from hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py rename to models/hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py diff --git a/hyvideo/vae/__init__.py b/models/hyvideo/vae/__init__.py similarity index 100% rename from hyvideo/vae/__init__.py rename to models/hyvideo/vae/__init__.py diff --git a/hyvideo/vae/autoencoder_kl_causal_3d.py b/models/hyvideo/vae/autoencoder_kl_causal_3d.py similarity index 100% rename from hyvideo/vae/autoencoder_kl_causal_3d.py rename to models/hyvideo/vae/autoencoder_kl_causal_3d.py diff --git a/hyvideo/vae/unet_causal_3d_blocks.py b/models/hyvideo/vae/unet_causal_3d_blocks.py similarity index 100% rename from 
hyvideo/vae/unet_causal_3d_blocks.py rename to models/hyvideo/vae/unet_causal_3d_blocks.py diff --git a/hyvideo/vae/vae.py b/models/hyvideo/vae/vae.py similarity index 100% rename from hyvideo/vae/vae.py rename to models/hyvideo/vae/vae.py diff --git a/models/ltx_video/__init__.py b/models/ltx_video/__init__.py new file mode 100644 index 0000000..3a3898e --- /dev/null +++ b/models/ltx_video/__init__.py @@ -0,0 +1,2 @@ +from .ltxv import LTXV +from . import ltxv_handler \ No newline at end of file diff --git a/ltx_video/configs/ltxv-13b-0.9.7-dev.original.yaml b/models/ltx_video/configs/ltxv-13b-0.9.7-dev.original.yaml similarity index 100% rename from ltx_video/configs/ltxv-13b-0.9.7-dev.original.yaml rename to models/ltx_video/configs/ltxv-13b-0.9.7-dev.original.yaml diff --git a/ltx_video/configs/ltxv-13b-0.9.7-dev.yaml b/models/ltx_video/configs/ltxv-13b-0.9.7-dev.yaml similarity index 100% rename from ltx_video/configs/ltxv-13b-0.9.7-dev.yaml rename to models/ltx_video/configs/ltxv-13b-0.9.7-dev.yaml diff --git a/ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml b/models/ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml similarity index 100% rename from ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml rename to models/ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml diff --git a/ltx_video/configs/ltxv-13b-0.9.8-dev.yaml b/models/ltx_video/configs/ltxv-13b-0.9.8-dev.yaml similarity index 100% rename from ltx_video/configs/ltxv-13b-0.9.8-dev.yaml rename to models/ltx_video/configs/ltxv-13b-0.9.8-dev.yaml diff --git a/ltx_video/configs/ltxv-13b-0.9.8-distilled.yaml b/models/ltx_video/configs/ltxv-13b-0.9.8-distilled.yaml similarity index 100% rename from ltx_video/configs/ltxv-13b-0.9.8-distilled.yaml rename to models/ltx_video/configs/ltxv-13b-0.9.8-distilled.yaml diff --git a/ltx_video/configs/ltxv-2b-0.9.6-dev.yaml b/models/ltx_video/configs/ltxv-2b-0.9.6-dev.yaml similarity index 100% rename from ltx_video/configs/ltxv-2b-0.9.6-dev.yaml rename to models/ltx_video/configs/ltxv-2b-0.9.6-dev.yaml diff --git a/ltx_video/configs/ltxv-2b-0.9.6-distilled.yaml b/models/ltx_video/configs/ltxv-2b-0.9.6-distilled.yaml similarity index 100% rename from ltx_video/configs/ltxv-2b-0.9.6-distilled.yaml rename to models/ltx_video/configs/ltxv-2b-0.9.6-distilled.yaml diff --git a/ltx_video/ltxv.py b/models/ltx_video/ltxv.py similarity index 98% rename from ltx_video/ltxv.py rename to models/ltx_video/ltxv.py index 34bae13..e71ac4f 100644 --- a/ltx_video/ltxv.py +++ b/models/ltx_video/ltxv.py @@ -7,7 +7,7 @@ from pathlib import Path from diffusers.utils import logging from typing import Optional, List, Union import yaml -from wan.utils.utils import calculate_new_dimensions +from shared.utils.utils import calculate_new_dimensions import imageio import json import numpy as np @@ -605,16 +605,4 @@ def load_media_file( raise Exception("video format not supported") return media_tensor -def query_model_def(model_type, model_def): - LTXV_config = model_def.get("LTXV_config", "") - distilled= "distilled" in LTXV_config - model_def_output = { - "no_guidance": True, - } - if distilled: - model_def_output.update({ - "lock_inference_steps": True, - "no_negative_prompt" : True, - }) - - return model_def_output \ No newline at end of file + diff --git a/models/ltx_video/ltxv_handler.py b/models/ltx_video/ltxv_handler.py new file mode 100644 index 0000000..cfdd069 --- /dev/null +++ b/models/ltx_video/ltxv_handler.py @@ -0,0 +1,92 @@ +import torch + + +def get_ltxv_text_encoder_filename(text_encoder_quantization): + 
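+ # Resolve the T5-XXL text encoder checkpoint path; when int8 quantization is requested,
+ # the bf16 filename is swapped for its quanto int8 variant.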
text_encoder_filename = "ckpts/T5_xxl_1.1/T5_xxl_1.1_enc_bf16.safetensors" + if text_encoder_quantization =="int8": + text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_bf16_int8") + return text_encoder_filename + +class family_handler(): + @staticmethod + def query_model_def(base_model_type, model_def): + LTXV_config = model_def.get("LTXV_config", "") + distilled= "distilled" in LTXV_config + extra_model_def = { + "no_guidance": True, + } + if distilled: + extra_model_def.update({ + "lock_inference_steps": True, + "no_negative_prompt" : True, + }) + + + extra_model_def["fps"] = 30 + extra_model_def["frames_minimum"] = 17 + extra_model_def["frames_steps"] = 8 + extra_model_def["sliding_window"] = True + + return extra_model_def + + @staticmethod + def query_supported_types(): + return ["ltxv_13B"] + + @staticmethod + def query_family_maps(): + return {}, {} + + @staticmethod + def get_rgb_factors(base_model_type ): + from shared.RGB_factors import get_rgb_factors + latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("ltxv") + return latent_rgb_factors, latent_rgb_factors_bias + + @staticmethod + def query_model_family(): + return "ltxv" + + @staticmethod + def query_family_infos(): + return {"ltxv":(10, "LTX Video")} + + @staticmethod + def get_vae_block_size(base_model_type): + return 32 + + @staticmethod + def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization): + text_encoder_filename = get_ltxv_text_encoder_filename(text_encoder_quantization) + return { + "repoId" : "DeepBeepMeep/LTX_Video", + "sourceFolderList" : ["T5_xxl_1.1", "" ], + "fileList" : [ ["added_tokens.json", "special_tokens_map.json", "spiece.model", "tokenizer_config.json"] + computeList(text_encoder_filename), ["ltxv_0.9.7_VAE.safetensors", "ltxv_0.9.7_spatial_upscaler.safetensors", "ltxv_scheduler.json"] + computeList(model_filename) ] + } + + + @staticmethod + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + from .ltxv import LTXV + + ltxv_model = LTXV( + model_filepath = model_filename, + text_encoder_filepath = get_ltxv_text_encoder_filename(text_encoder_quantization), + model_type = model_type, + base_model_type = base_model_type, + model_def = model_def, + dtype = dtype, + # quantizeTransformer = quantizeTransformer, + VAE_dtype = VAE_dtype, + mixed_precision_transformer = mixed_precision_transformer + ) + + pipeline = ltxv_model.pipeline + pipe = {"transformer" : pipeline.video_pipeline.transformer, "vae" : pipeline.vae, "text_encoder" : pipeline.video_pipeline.text_encoder, "latent_upsampler" : pipeline.latent_upsampler} + + return ltxv_model, pipe + + @staticmethod + def update_default_settings(base_model_type, model_def, ui_defaults): + pass + \ No newline at end of file diff --git a/ltx_video/__init__.py b/models/ltx_video/models/__init__.py similarity index 100% rename from ltx_video/__init__.py rename to models/ltx_video/models/__init__.py diff --git a/ltx_video/models/__init__.py b/models/ltx_video/models/autoencoders/__init__.py similarity index 100% rename from ltx_video/models/__init__.py rename to models/ltx_video/models/autoencoders/__init__.py diff --git a/ltx_video/models/autoencoders/causal_conv3d.py b/models/ltx_video/models/autoencoders/causal_conv3d.py similarity index 100% rename from ltx_video/models/autoencoders/causal_conv3d.py 
rename to models/ltx_video/models/autoencoders/causal_conv3d.py diff --git a/ltx_video/models/autoencoders/causal_video_autoencoder.py b/models/ltx_video/models/autoencoders/causal_video_autoencoder.py similarity index 99% rename from ltx_video/models/autoencoders/causal_video_autoencoder.py rename to models/ltx_video/models/autoencoders/causal_video_autoencoder.py index 5f05932..daed704 100644 --- a/ltx_video/models/autoencoders/causal_video_autoencoder.py +++ b/models/ltx_video/models/autoencoders/causal_video_autoencoder.py @@ -15,12 +15,12 @@ from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbedding from safetensors import safe_open -from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd -from ltx_video.models.autoencoders.pixel_norm import PixelNorm -from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND -from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper -from ltx_video.models.transformers.attention import Attention -from ltx_video.utils.diffusers_config_mapping import ( +from ..autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd +from ...models.autoencoders.pixel_norm import PixelNorm +from ...models.autoencoders.pixel_shuffle import PixelShuffleND +from ...models.autoencoders.vae import AutoencoderKLWrapper +from ...models.transformers.attention import Attention +from ...utils.diffusers_config_mapping import ( diffusers_and_ours_config_mapping, make_hashable_key, VAE_KEYS_RENAME_DICT, diff --git a/ltx_video/models/autoencoders/conv_nd_factory.py b/models/ltx_video/models/autoencoders/conv_nd_factory.py similarity index 94% rename from ltx_video/models/autoencoders/conv_nd_factory.py rename to models/ltx_video/models/autoencoders/conv_nd_factory.py index 718c69b..59a3fc0 100644 --- a/ltx_video/models/autoencoders/conv_nd_factory.py +++ b/models/ltx_video/models/autoencoders/conv_nd_factory.py @@ -2,8 +2,8 @@ from typing import Tuple, Union import torch -from ltx_video.models.autoencoders.dual_conv3d import DualConv3d -from ltx_video.models.autoencoders.causal_conv3d import CausalConv3d +from ..autoencoders.dual_conv3d import DualConv3d +from ..autoencoders.causal_conv3d import CausalConv3d def make_conv_nd( diff --git a/ltx_video/models/autoencoders/dual_conv3d.py b/models/ltx_video/models/autoencoders/dual_conv3d.py similarity index 100% rename from ltx_video/models/autoencoders/dual_conv3d.py rename to models/ltx_video/models/autoencoders/dual_conv3d.py diff --git a/ltx_video/models/autoencoders/latent_upsampler.py b/models/ltx_video/models/autoencoders/latent_upsampler.py similarity index 98% rename from ltx_video/models/autoencoders/latent_upsampler.py rename to models/ltx_video/models/autoencoders/latent_upsampler.py index 4a76bc2..f666d2f 100644 --- a/ltx_video/models/autoencoders/latent_upsampler.py +++ b/models/ltx_video/models/autoencoders/latent_upsampler.py @@ -9,7 +9,7 @@ from einops import rearrange from diffusers import ConfigMixin, ModelMixin from safetensors.torch import safe_open -from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND +from ...models.autoencoders.pixel_shuffle import PixelShuffleND class ResBlock(nn.Module): diff --git a/ltx_video/models/autoencoders/pixel_norm.py b/models/ltx_video/models/autoencoders/pixel_norm.py similarity index 100% rename from ltx_video/models/autoencoders/pixel_norm.py rename to models/ltx_video/models/autoencoders/pixel_norm.py diff --git a/ltx_video/models/autoencoders/pixel_shuffle.py 
b/models/ltx_video/models/autoencoders/pixel_shuffle.py similarity index 100% rename from ltx_video/models/autoencoders/pixel_shuffle.py rename to models/ltx_video/models/autoencoders/pixel_shuffle.py diff --git a/ltx_video/models/autoencoders/vae.py b/models/ltx_video/models/autoencoders/vae.py similarity index 99% rename from ltx_video/models/autoencoders/vae.py rename to models/ltx_video/models/autoencoders/vae.py index c1135ba..a0ce1c4 100644 --- a/ltx_video/models/autoencoders/vae.py +++ b/models/ltx_video/models/autoencoders/vae.py @@ -10,7 +10,7 @@ from diffusers.models.autoencoders.vae import ( DiagonalGaussianDistribution, ) from diffusers.models.modeling_outputs import AutoencoderKLOutput -from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd +from ...models.autoencoders.conv_nd_factory import make_conv_nd class AutoencoderKLWrapper(ModelMixin, ConfigMixin): diff --git a/ltx_video/models/autoencoders/vae_encode.py b/models/ltx_video/models/autoencoders/vae_encode.py similarity index 98% rename from ltx_video/models/autoencoders/vae_encode.py rename to models/ltx_video/models/autoencoders/vae_encode.py index b7d2476..4b6a5c4 100644 --- a/ltx_video/models/autoencoders/vae_encode.py +++ b/models/ltx_video/models/autoencoders/vae_encode.py @@ -5,10 +5,10 @@ from einops import rearrange from torch import Tensor -from ltx_video.models.autoencoders.causal_video_autoencoder import ( +from ...models.autoencoders.causal_video_autoencoder import ( CausalVideoAutoencoder, ) -from ltx_video.models.autoencoders.video_autoencoder import ( +from ...models.autoencoders.video_autoencoder import ( Downsample3D, VideoAutoencoder, ) diff --git a/ltx_video/models/autoencoders/video_autoencoder.py b/models/ltx_video/models/autoencoders/video_autoencoder.py similarity index 99% rename from ltx_video/models/autoencoders/video_autoencoder.py rename to models/ltx_video/models/autoencoders/video_autoencoder.py index 3c7926c..dbb2bcd 100644 --- a/ltx_video/models/autoencoders/video_autoencoder.py +++ b/models/ltx_video/models/autoencoders/video_autoencoder.py @@ -11,10 +11,10 @@ from torch.nn import functional from diffusers.utils import logging -from ltx_video.utils.torch_utils import Identity -from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd -from ltx_video.models.autoencoders.pixel_norm import PixelNorm -from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper +from ...utils.torch_utils import Identity +from ...models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd +from ...models.autoencoders.pixel_norm import PixelNorm +from ...models.autoencoders.vae import AutoencoderKLWrapper logger = logging.get_logger(__name__) diff --git a/ltx_video/models/autoencoders/__init__.py b/models/ltx_video/models/transformers/__init__.py similarity index 100% rename from ltx_video/models/autoencoders/__init__.py rename to models/ltx_video/models/transformers/__init__.py diff --git a/ltx_video/models/transformers/attention.py b/models/ltx_video/models/transformers/attention.py similarity index 99% rename from ltx_video/models/transformers/attention.py rename to models/ltx_video/models/transformers/attention.py index a7b4555..a87a8a0 100644 --- a/ltx_video/models/transformers/attention.py +++ b/models/ltx_video/models/transformers/attention.py @@ -19,15 +19,9 @@ from diffusers.utils import deprecate, logging from diffusers.utils.torch_utils import maybe_allow_in_graph from einops import rearrange from torch import nn -from 
wan.modules.attention import pay_attention -from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy +from shared.attention import pay_attention +from ...utils.skip_layer_strategy import SkipLayerStrategy -try: - from torch_xla.experimental.custom_kernel import flash_attention -except ImportError: - # workaround for automatic tests. Currently this function is manually patched - # to the torch_xla lib on setup of container - pass # code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py diff --git a/ltx_video/models/transformers/embeddings.py b/models/ltx_video/models/transformers/embeddings.py similarity index 100% rename from ltx_video/models/transformers/embeddings.py rename to models/ltx_video/models/transformers/embeddings.py diff --git a/ltx_video/models/transformers/symmetric_patchifier.py b/models/ltx_video/models/transformers/symmetric_patchifier.py similarity index 100% rename from ltx_video/models/transformers/symmetric_patchifier.py rename to models/ltx_video/models/transformers/symmetric_patchifier.py diff --git a/ltx_video/models/transformers/transformer3d.py b/models/ltx_video/models/transformers/transformer3d.py similarity index 98% rename from ltx_video/models/transformers/transformer3d.py rename to models/ltx_video/models/transformers/transformer3d.py index e182f21..c90baeb 100644 --- a/ltx_video/models/transformers/transformer3d.py +++ b/models/ltx_video/models/transformers/transformer3d.py @@ -16,10 +16,10 @@ from diffusers.utils import BaseOutput, is_torch_version from diffusers.utils import logging from torch import nn from safetensors import safe_open -from ltx_video.models.transformers.attention import BasicTransformerBlock, reshape_hidden_states, restore_hidden_states_shape -from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy +from .attention import BasicTransformerBlock, reshape_hidden_states, restore_hidden_states_shape +from ...utils.skip_layer_strategy import SkipLayerStrategy -from ltx_video.utils.diffusers_config_mapping import ( +from ...utils.diffusers_config_mapping import ( diffusers_and_ours_config_mapping, make_hashable_key, TRANSFORMER_KEYS_RENAME_DICT, diff --git a/ltx_video/models/transformers/__init__.py b/models/ltx_video/pipelines/__init__.py similarity index 100% rename from ltx_video/models/transformers/__init__.py rename to models/ltx_video/pipelines/__init__.py diff --git a/ltx_video/pipelines/crf_compressor.py b/models/ltx_video/pipelines/crf_compressor.py similarity index 100% rename from ltx_video/pipelines/crf_compressor.py rename to models/ltx_video/pipelines/crf_compressor.py diff --git a/ltx_video/pipelines/pipeline_ltx_video.py b/models/ltx_video/pipelines/pipeline_ltx_video.py similarity index 99% rename from ltx_video/pipelines/pipeline_ltx_video.py rename to models/ltx_video/pipelines/pipeline_ltx_video.py index 8bb6d27..f98eb13 100644 --- a/ltx_video/pipelines/pipeline_ltx_video.py +++ b/models/ltx_video/pipelines/pipeline_ltx_video.py @@ -24,22 +24,22 @@ from transformers import ( AutoTokenizer, ) -from ltx_video.models.autoencoders.causal_video_autoencoder import ( +from ..models.autoencoders.causal_video_autoencoder import ( CausalVideoAutoencoder, ) -from ltx_video.models.autoencoders.vae_encode import ( +from ..models.autoencoders.vae_encode import ( get_vae_size_scale_factor, latent_to_pixel_coords, vae_decode, vae_encode, ) -from ltx_video.models.transformers.symmetric_patchifier import Patchifier -from ltx_video.models.transformers.transformer3d import 
Transformer3DModel -from ltx_video.schedulers.rf import TimestepShifter -from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy -from ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt -from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler -from ltx_video.models.autoencoders.vae_encode import ( +from ..models.transformers.symmetric_patchifier import Patchifier +from ..models.transformers.transformer3d import Transformer3DModel +from ..schedulers.rf import TimestepShifter +from ..utils.skip_layer_strategy import SkipLayerStrategy +from ..utils.prompt_enhance_utils import generate_cinematic_prompt +from ..models.autoencoders.latent_upsampler import LatentUpsampler +from ..models.autoencoders.vae_encode import ( un_normalize_latents, normalize_latents, ) diff --git a/ltx_video/pipelines/__init__.py b/models/ltx_video/schedulers/__init__.py similarity index 100% rename from ltx_video/pipelines/__init__.py rename to models/ltx_video/schedulers/__init__.py diff --git a/ltx_video/schedulers/rf.py b/models/ltx_video/schedulers/rf.py similarity index 99% rename from ltx_video/schedulers/rf.py rename to models/ltx_video/schedulers/rf.py index 2cf99da..bced26a 100644 --- a/ltx_video/schedulers/rf.py +++ b/models/ltx_video/schedulers/rf.py @@ -14,9 +14,9 @@ from torch import Tensor from safetensors import safe_open -from ltx_video.utils.torch_utils import append_dims +from ..utils.torch_utils import append_dims -from ltx_video.utils.diffusers_config_mapping import ( +from ..utils.diffusers_config_mapping import ( diffusers_and_ours_config_mapping, make_hashable_key, ) diff --git a/ltx_video/schedulers/__init__.py b/models/ltx_video/utils/__init__.py similarity index 100% rename from ltx_video/schedulers/__init__.py rename to models/ltx_video/utils/__init__.py diff --git a/ltx_video/utils/diffusers_config_mapping.py b/models/ltx_video/utils/diffusers_config_mapping.py similarity index 100% rename from ltx_video/utils/diffusers_config_mapping.py rename to models/ltx_video/utils/diffusers_config_mapping.py diff --git a/ltx_video/utils/prompt_enhance_utils.py b/models/ltx_video/utils/prompt_enhance_utils.py similarity index 100% rename from ltx_video/utils/prompt_enhance_utils.py rename to models/ltx_video/utils/prompt_enhance_utils.py diff --git a/ltx_video/utils/skip_layer_strategy.py b/models/ltx_video/utils/skip_layer_strategy.py similarity index 100% rename from ltx_video/utils/skip_layer_strategy.py rename to models/ltx_video/utils/skip_layer_strategy.py diff --git a/ltx_video/utils/torch_utils.py b/models/ltx_video/utils/torch_utils.py similarity index 100% rename from ltx_video/utils/torch_utils.py rename to models/ltx_video/utils/torch_utils.py diff --git a/models/qwen/autoencoder_kl_qwenimage.py b/models/qwen/autoencoder_kl_qwenimage.py new file mode 100644 index 0000000..d144284 --- /dev/null +++ b/models/qwen/autoencoder_kl_qwenimage.py @@ -0,0 +1,1096 @@ +# Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# We gratefully acknowledge the Wan Team for their outstanding contributions. +# QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance. +# For more information about the Wan VAE, please refer to: +# - GitHub: https://github.com/Wan-Video/Wan2.1 +# - arXiv: https://arxiv.org/abs/2503.20314 + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin +from diffusers.utils import logging +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.activations import get_activation +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +CACHE_T = 2 + + +class QwenImageCausalConv3d(nn.Conv3d): + r""" + A custom 3D causal convolution layer with feature caching support. + + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + return super().forward(x) + + +class QwenImageRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. 
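+
+    The forward pass L2-normalizes the input along the channel (or last) dimension and rescales
+    it by `dim**0.5` and the learnable `gamma` (plus the optional bias), i.e. an RMS-style
+    normalization with a per-channel gain.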
+ """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class QwenImageUpsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. + + Args: + x (torch.Tensor): Input tensor to be upsampled. + + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. + """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class QwenImageResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. + """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b 
* t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class QwenImageResidualBlock(nn.Module): + r""" + A custom residual block module. + + Args: + in_dim (int): Number of input channels. + out_dim (int): Number of output channels. + dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0. + non_linearity (str, optional): Type of non-linearity to use. Default is "silu". + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = get_activation(non_linearity) + + # layers + self.norm1 = QwenImageRMS_norm(in_dim, images=False) + self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = QwenImageRMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + # Add residual connection + return x + h + + +class QwenImageAttentionBlock(nn.Module): + r""" + Causal self-attention with a single head. + + Args: + dim (int): The number of channels in the input tensor. 
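+
+    Attention is applied per frame: the time axis is folded into the batch dimension and the
+    `height * width` positions of each frame attend to one another with a single head, followed
+    by a 1x1 output projection and a residual connection.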
+ """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = QwenImageRMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + # compute query, key, value + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention(q, k, v) + + x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) + + # output projection + x = self.proj(x) + + # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w] + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + +class QwenImageMidBlock(nn.Module): + """ + Middle block for QwenImageVAE encoder and decoder. + + Args: + dim (int): Number of input/output channels. + dropout (float): Dropout rate. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + super().__init__() + self.dim = dim + + # Create the components + resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(QwenImageAttentionBlock(dim)) + resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # First residual block + x = self.resnets[0](x, feat_cache, feat_idx) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + + x = resnet(x, feat_cache, feat_idx) + + return x + + +class QwenImageEncoder3d(nn.Module): + r""" + A 3D encoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. 
+ """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(QwenImageAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(QwenImageResample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1) + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## downsamples + for layer in self.down_blocks: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class QwenImageUpBlock(nn.Module): + """ + A block that handles upsampling for the QwenImageVAE decoder. 
+ + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d') + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)]) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache, feat_idx) + else: + x = self.upsamplers[0](x) + return x + + +class QwenImageDecoder3d(nn.Module): + r""" + A 3D decoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. 
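+
+    The decoder mirrors the encoder: three upsampling stages restore the 8x spatial factor, the
+    temporal length is doubled where the corresponding entry of `temperal_upsample` is True, and
+    `conv_out` maps back to 3 image channels.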
+ """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0: + in_dim = in_dim // 2 + + # Determine if we need upsampling + upsample_mode = None + if i != len(dim_mult) - 1: + upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" + + # Create and add the upsampling block + up_block = QwenImageUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # Update scale for next iteration + if upsample_mode is not None: + scale *= 2.0 + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). 
+ """ + + _supports_gradient_checkpointing = False + + @staticmethod + def get_VAE_tile_size(vae_config, device_mem_capacity, mixed_precision): + + # VAE Tiling + if vae_config == 0: + if device_mem_capacity >= 24000: + use_vae_config = 1 + elif device_mem_capacity >= 8000: + use_vae_config = 2 + else: + use_vae_config = 3 + else: + use_vae_config = vae_config + + use_tiling = False + tile_sample_min_width = 256 + + if use_vae_config == 1: + use_tiling = False + elif use_vae_config == 2: + use_tiling = True + tile_sample_min_width = 256 + + return (use_tiling, tile_sample_min_width) + + + # fmt: off + @register_to_config + def __init__( + self, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temperal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921], + latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160], + ) -> None: + # fmt: on + super().__init__() + + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + self.encoder = QwenImageEncoder3d( + base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout + ) + self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1) + + self.decoder = QwenImageDecoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout + ) + + self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup + self._cached_conv_counts = { + "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules()) + if self.decoder is not None + else 0, + "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules()) + if self.encoder is not None + else 0, + } + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. 
+ + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def clear_cache(self): + def _count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, QwenImageCausalConv3d): + count += 1 + return count + + self._conv_num = _count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + def _encode(self, x: torch.Tensor): + _, _, num_frame, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + self.clear_cache() + iter_ = 1 + (num_frame - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + + enc = self.quant_conv(out) + self.clear_cache() + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
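+
+        Example (a minimal sketch with randomly initialized weights; the shape assumes the
+        default configuration):
+
+        ```py
+        >>> vae = AutoencoderKLQwenImage()
+        >>> x = torch.randn(1, 3, 1, 256, 256)  # (batch, channels, frames, height, width)
+        >>> z = vae.encode(x).latent_dist.sample()  # torch.Size([1, 16, 1, 32, 32])
+        ```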
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + _, _, num_frame, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + self.clear_cache() + x = self.post_quant_conv(z) + for i in range(num_frame): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + + out = torch.clamp(out, min=-1.0, max=1.0) + self.clear_cache() + if not return_dict: + return (out,) + + return DecoderOutput(sample=out) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. 
+ """ + _, _, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + self.clear_cache() + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
+ rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + self.clear_cache() + time = [] + for k in range(num_frames): + self._conv_idx = [0] + tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx) + time.append(decoded) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + """ + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py new file mode 100644 index 0000000..ad39ae5 --- /dev/null +++ b/models/qwen/pipeline_qwenimage.py @@ -0,0 +1,739 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from mmgp import offload +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch, json + +from diffusers.image_processor import VaeImageProcessor +from .transformer_qwenimage import QwenImageTransformer2DModel + +from diffusers.utils import logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, AutoTokenizer +from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage +from diffusers import FlowMatchEulerDiscreteScheduler + +XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import QwenImagePipeline + + >>> pipe = QwenImagePipeline.from_pretrained("Qwen/QwenImage-20B", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image.save("qwenimage.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. 
Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class QwenImagePipeline(): #DiffusionPipeline + r""" + The QwenImage pipeline for text-to-image generation. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + vae, + text_encoder, + tokenizer, + transformer, + ): + + self.vae=vae + self.text_encoder=text_encoder + self.tokenizer=tokenizer + self.transformer=transformer + + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 1024 + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 + self.default_sample_size = 128 + + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ).to(device) + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. 
+ """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." 
+ ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
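To make the 2x2 packing concrete (an editor's sketch, not part of the committed change; the numbers assume the usual 8x spatial VAE compression and 16 latent channels): `_pack_latents` regroups the latent grid into 2x2 patches so that each transformer token carries `C * 4` values, and `_unpack_latents` inverts the same reshape. For a 1024x1024 generation this works out as follows:

```py
import torch

B, C = 1, 16                                   # batch size, latent channels (transformer.in_channels // 4)
vae_scale_factor = 8                           # assumed 8x spatial compression
height = 2 * (1024 // (vae_scale_factor * 2))  # 128, the same arithmetic as prepare_latents below
width = 2 * (1024 // (vae_scale_factor * 2))   # 128

latents = torch.randn(B, 1, C, height, width)  # (batch, frame=1, C, H, W), as sampled in prepare_latents
packed = latents.view(B, C, height // 2, 2, width // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5)
packed = packed.reshape(B, (height // 2) * (width // 2), C * 4)
print(packed.shape)                            # torch.Size([1, 4096, 64]): 4096 image tokens of width 64
```

The resulting token count (4096 here) is the `image_seq_len` that later feeds `calculate_shift` in the denoising loop.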
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 1.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + callback=None, + pipeline=None, + loras_slists=None, + joint_pass= True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. 
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 1.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
+ of the [Imagen Paper](https://huggingface.co/papers/2205.11487). It is only used when the transformer
+ was trained with guidance embeddings; classic classifier-free guidance is controlled by
+ `true_cfg_scale` instead. A higher guidance scale encourages the model to generate images closely
+ linked to the text `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+ the `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + kwargs = {'pipeline': pipeline, 'callback': callback} + if callback != None: + callback(-1, None, True) + + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = "cuda" + # device = self._execution_device + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + dtype = torch.bfloat16 + prompt_embeds = prompt_embeds.to(dtype) + if do_true_cfg: + negative_prompt_embeds = negative_prompt_embeds.to(dtype) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size + + # 5. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + updated_num_steps= len(timesteps) + if callback != None: + from shared.utils.loras_mutipliers import update_loras_slists + update_loras_slists(self.transformer, loras_slists, updated_num_steps) + callback(-1, None, True, override_num_inference_steps = updated_num_steps) + + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + if do_true_cfg and joint_pass: + noise_pred, neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask], + encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds], + img_shapes=img_shapes, + txt_seq_lens_list=[prompt_embeds_mask.sum(dim=1).tolist(),negative_prompt_embeds_mask.sum(dim=1).tolist()], + attention_kwargs=self.attention_kwargs, + **kwargs + ) + if noise_pred == None: return None + else: + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask_list=[prompt_embeds_mask], + encoder_hidden_states_list=[prompt_embeds], + img_shapes=img_shapes, + txt_seq_lens_list=[prompt_embeds_mask.sum(dim=1).tolist()], + attention_kwargs=self.attention_kwargs, + **kwargs + )[0] + if noise_pred == None: return None + + if do_true_cfg: + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask_list=[negative_prompt_embeds_mask], + encoder_hidden_states_list=[negative_prompt_embeds], + img_shapes=img_shapes, + txt_seq_lens_list=[negative_prompt_embeds_mask.sum(dim=1).tolist()], + attention_kwargs=self.attention_kwargs, + **kwargs + )[0] + if neg_noise_pred == None: return None + + if do_true_cfg: + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + if comb_pred == None: return None + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + neg_noise_pred = None + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback is not None: + # preview = unpack_latent(img).transpose(0,1) + callback(i, None, False) + + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + + + return image diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py new file mode 100644 index 0000000..6db6a76 --- /dev/null +++ b/models/qwen/qwen_handler.py @@ -0,0 +1,86 @@ +import torch + +def get_qwen_text_encoder_filename(text_encoder_quantization): + text_encoder_filename = "ckpts/Qwen2.5-VL-7B-Instruct/Qwen2.5-VL-7B-Instruct_bf16.safetensors" + if text_encoder_quantization =="int8": + text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_bf16_int8") + return text_encoder_filename + +class family_handler(): + @staticmethod + def query_model_def(base_model_type, model_def): + model_def_output = { + "image_outputs" : True, + "sample_solvers":[ + ("Default", "default"), + ("Lightning", "lightning")] + } + + + return model_def_output + + @staticmethod + def query_supported_types(): + return ["qwen_image_20B"] + + @staticmethod + def query_family_maps(): + return {}, {} + + @staticmethod + def query_model_family(): + return "qwen" + + @staticmethod + def query_family_infos(): + return {"qwen":(40, "Qwen")} + + @staticmethod + def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization): + text_encoder_filename = get_qwen_text_encoder_filename(text_encoder_quantization) + return { + "repoId" : "DeepBeepMeep/Qwen_image", + "sourceFolderList" : ["", "Qwen2.5-VL-7B-Instruct"], + "fileList" : [ ["qwen_vae.safetensors", "qwen_vae_config.json"], ["merges.txt", "tokenizer_config.json", "config.json", "vocab.json"] + computeList(text_encoder_filename) ] + } + + @staticmethod + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + from .qwen_main import model_factory + from mmgp import offload + + pipe_processor = model_factory( + checkpoint_dir="ckpts", + model_filename=model_filename, + model_type = model_type, + model_def = model_def, + base_model_type=base_model_type, + text_encoder_filename= get_qwen_text_encoder_filename(text_encoder_quantization), + quantizeTransformer = quantizeTransformer, + dtype = dtype, + VAE_dtype = VAE_dtype, + mixed_precision_transformer = mixed_precision_transformer, + save_quantized = save_quantized + ) + + pipe = {"tokenizer" : pipe_processor.tokenizer, "transformer" : pipe_processor.transformer, "text_encoder" : pipe_processor.text_encoder, "vae" : pipe_processor.vae} + + return pipe_processor, pipe + + + @staticmethod + def fix_settings(base_model_type, settings_version, model_def, ui_defaults): + if ui_defaults.get("sample_solver", "") == "": + 
ui_defaults["sample_solver"] = "default" + + @staticmethod + def update_default_settings(base_model_type, model_def, ui_defaults): + ui_defaults.update({ + "guidance_scale": 4, + "sample_solver": "default", + }) + if model_def.get("reference_image", False): + ui_defaults.update({ + "video_prompt_type": "KI", + }) + diff --git a/models/qwen/qwen_main.py b/models/qwen/qwen_main.py new file mode 100644 index 0000000..ccaa758 --- /dev/null +++ b/models/qwen/qwen_main.py @@ -0,0 +1,160 @@ + +from mmgp import offload +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch, json, os +import math + +from diffusers.image_processor import VaeImageProcessor +from .transformer_qwenimage import QwenImageTransformer2DModel + +from diffusers.utils import logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, AutoTokenizer +from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage +from diffusers import FlowMatchEulerDiscreteScheduler +from .pipeline_qwenimage import QwenImagePipeline + +class model_factory(): + def __init__( + self, + checkpoint_dir, + model_filename = None, + model_type = None, + model_def = None, + base_model_type = None, + text_encoder_filename = None, + quantizeTransformer = False, + save_quantized = False, + dtype = torch.bfloat16, + VAE_dtype = torch.float32, + mixed_precision_transformer = False + ): + + + transformer_filename = model_filename[0] + tokenizer = AutoTokenizer.from_pretrained(os.path.join(checkpoint_dir,"Qwen2.5-VL-7B-Instruct")) + + with open("configs/qwen_image_20B.json", 'r', encoding='utf-8') as f: + transformer_config = json.load(f) + transformer_config.pop("_diffusers_version") + transformer_config.pop("_class_name") + transformer_config.pop("pooled_projection_dim") + + from accelerate import init_empty_weights + with init_empty_weights(): + transformer = QwenImageTransformer2DModel(**transformer_config) + offload.load_model_data(transformer, transformer_filename) + # transformer = offload.fast_load_transformers_model("transformer_quanto.safetensors", writable_tensors= True , modelClass=QwenImageTransformer2DModel, defaultConfigPath="transformer_config.json") + + text_encoder = offload.fast_load_transformers_model(text_encoder_filename, writable_tensors= True , modelClass=Qwen2_5_VLForConditionalGeneration, defaultConfigPath= os.path.join(checkpoint_dir, "Qwen2.5-VL-7B-Instruct", "config.json")) + # text_encoder = offload.fast_load_transformers_model(text_encoder_filename, do_quantize=True, writable_tensors= True , modelClass=Qwen2_5_VLForConditionalGeneration, defaultConfigPath="text_encoder_config.json", verboseLevel=2) + # text_encoder.to(torch.float16) + # offload.save_model(text_encoder, "text_encoder_quanto_fp16.safetensors", do_quantize= True) + + vae = offload.fast_load_transformers_model( os.path.join(checkpoint_dir,"qwen_vae.safetensors"), writable_tensors= True , modelClass=AutoencoderKLQwenImage, defaultConfigPath=os.path.join(checkpoint_dir,"qwen_vae_config.json")) + + self.pipeline = QwenImagePipeline(vae, text_encoder, tokenizer, transformer) + self.vae=vae + self.text_encoder=text_encoder + self.tokenizer=tokenizer + self.transformer=transformer + + def generate( + self, + seed: int | None = None, + input_prompt: str = "replace the logo with the text 'Black Forest Labs'", + n_prompt = None, + sampling_steps: int = 20, + input_ref_images = None, + width= 832, + height=480, + 
guide_scale: float = 4, + fit_into_canvas = None, + callback = None, + loras_slists = None, + batch_size = 1, + video_prompt_type = "", + VAE_tile_size = None, + joint_pass = True, + sample_solver='default', + **bbargs + ): + # Generate with different aspect ratios + aspect_ratios = { + "1:1": (1328, 1328), + "16:9": (1664, 928), + "9:16": (928, 1664), + "4:3": (1472, 1140), + "3:4": (1140, 1472) + } + + + if sample_solver =='lightning': + scheduler_config = { + "base_image_seq_len": 256, + "base_shift": math.log(3), # We use shift=3 in distillation + "invert_sigmas": False, + "max_image_seq_len": 8192, + "max_shift": math.log(3), # We use shift=3 in distillation + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": None, # set shift_terminal to None + "stochastic_sampling": False, + "time_shift_type": "exponential", + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False, + } + else: + scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.5, + "invert_sigmas": False, + "max_image_seq_len": 8192, + "max_shift": 0.9, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": 0.02, + "stochastic_sampling": False, + "time_shift_type": "exponential", + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False + } + + self.scheduler=FlowMatchEulerDiscreteScheduler(**scheduler_config) + self.pipeline.scheduler = self.scheduler + if VAE_tile_size is not None: + self.vae.use_tiling = VAE_tile_size[0] + self.vae.tile_latent_min_height = VAE_tile_size[1] + self.vae.tile_latent_min_width = VAE_tile_size[1] + + + self.vae.enable_slicing() + # width, height = aspect_ratios["16:9"] + + if n_prompt is None or len(n_prompt) == 0: + n_prompt= "text, watermark, copyright, blurry, low resolution" + + image = self.pipeline( + prompt=input_prompt, + negative_prompt=n_prompt, + width=width, + height=height, + num_inference_steps=sampling_steps, + num_images_per_prompt = batch_size, + true_cfg_scale=guide_scale, + callback = callback, + pipeline=self, + loras_slists=loras_slists, + joint_pass = joint_pass, + generator=torch.Generator(device="cuda").manual_seed(seed) + ) + if image is None: return None + return image.transpose(0, 1) + diff --git a/models/qwen/transformer_qwenimage.py b/models/qwen/transformer_qwenimage.py new file mode 100644 index 0000000..042032d --- /dev/null +++ b/models/qwen/transformer_qwenimage.py @@ -0,0 +1,598 @@ +# Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
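A note on the Lightning configuration above (editorial illustration, not part of the committed change): with `base_shift == max_shift == math.log(3)`, the pipeline's `calculate_shift` returns `mu = log(3)` for every resolution, so dynamic shifting collapses to a constant shift of 3, which is what the step-distilled checkpoints expect. A small sketch of the effect on the sigma schedule, assuming the exponential time shift that diffusers' `FlowMatchEulerDiscreteScheduler` applies when `use_dynamic_shifting` is enabled:

```py
import math

def shift_sigma(sigma: float, mu: float) -> float:
    # Exponential dynamic shifting: sigma' = e^mu / (e^mu + (1 / sigma - 1)),
    # which simplifies to 3 * sigma / (1 + 2 * sigma) when mu = log(3).
    return math.exp(mu) / (math.exp(mu) + (1 / sigma - 1))

mu = math.log(3)              # calculate_shift() has zero slope when base_shift == max_shift
raw = [1.0, 0.75, 0.5, 0.25]  # np.linspace(1.0, 1/4, 4), e.g. a 4-step distilled schedule
print([round(shift_sigma(s, mu), 3) for s in raw])  # [1.0, 0.9, 0.75, 0.5]
```

The default branch keeps a resolution-dependent shift (`base_shift=0.5`, `max_shift=0.9`), so larger images have their sigmas pushed toward the noisier end of the schedule.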
+ + +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import Attention +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm +from shared.attention import pay_attention + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent).to(timesteps.dtype) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def apply_rotary_emb_qwen( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
+ """ + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(1) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + + +class QwenTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep, hidden_states): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D) + + conditioning = timesteps_emb + + return conditioning + + +class QwenEmbedRope(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(1024) + neg_index = torch.arange(1024).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.rope_cache = {} + + # 是否使用 scale rope + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + assert dim % 2 == 0 + freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def forward(self, video_fhw, txt_seq_lens, device): + """ + Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: + txt_length: [bs] a list of 1 integers representing the length of the text + """ + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + frame, height, width = video_fhw + rope_key = f"{frame}_{height}_{width}" + + if rope_key not in self.rope_cache: + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + 
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + self.rope_cache[rope_key] = freqs.clone().contiguous() + vid_freqs = self.rope_cache[rope_key] + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2) + else: + max_vid_index = max(height, width) + + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + + return vid_freqs, txt_freqs + + +class QwenDoubleStreamAttnProcessor2_0: + """ + Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor + implements joint attention computation where text and image streams are processed together. + """ + + _attention_backend = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, # Image stream + encoder_hidden_states: torch.FloatTensor = None, # Text stream + encoder_hidden_states_mask: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + if encoder_hidden_states is None: + raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") + + seq_txt = encoder_hidden_states.shape[1] + + # Compute QKV for image stream (sample projections) + img_query = attn.to_q(hidden_states) + img_key = attn.to_k(hidden_states) + img_value = attn.to_v(hidden_states) + + # Compute QKV for text stream (context projections) + txt_query = attn.add_q_proj(encoder_hidden_states) + txt_key = attn.add_k_proj(encoder_hidden_states) + txt_value = attn.add_v_proj(encoder_hidden_states) + + # Reshape for multi-head attention + img_query = img_query.unflatten(-1, (attn.heads, -1)) + img_key = img_key.unflatten(-1, (attn.heads, -1)) + img_value = img_value.unflatten(-1, (attn.heads, -1)) + + txt_query = txt_query.unflatten(-1, (attn.heads, -1)) + txt_key = txt_key.unflatten(-1, (attn.heads, -1)) + txt_value = txt_value.unflatten(-1, (attn.heads, -1)) + + # Apply QK normalization + if attn.norm_q is not None: + img_query = attn.norm_q(img_query) + if attn.norm_k is not None: + img_key = attn.norm_k(img_key) + if attn.norm_added_q is not None: + txt_query = attn.norm_added_q(txt_query) + if attn.norm_added_k is not None: + txt_key = attn.norm_added_k(txt_key) + + # Apply RoPE + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False) + img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False) + txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False) + txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False) + + # Concatenate for joint attention + # Order: [text, 
image] + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) + + # Compute joint attention + dtype = joint_query.dtype + qkv_list = [joint_query, joint_key, joint_value ] + joint_query = joint_key = joint_value = None + joint_hidden_states = pay_attention(qkv_list) + + # Reshape back + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(dtype) + + # Split attention outputs back + txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part + img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part + + # Apply output projections + img_attn_output = attn.to_out[0](img_attn_output) + if len(attn.to_out) > 1: + img_attn_output = attn.to_out[1](img_attn_output) # dropout + + txt_attn_output = attn.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +class QwenImageTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + # Image processing modules + self.img_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 + ) + self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, # Enable cross attention for joint computation + added_kv_proj_dim=dim, # Enable added KV projections for text stream + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=QwenDoubleStreamAttnProcessor2_0(), + qk_norm=qk_norm, + eps=eps, + ) + self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + # Text processing modules + self.txt_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 + ) + self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + # Text doesn't need separate attention - it's handled by img_attn joint computation + self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def _modulate(self, x, mod_params): + """Apply modulation to input tensor""" + shift, scale, gate = mod_params.chunk(3, dim=-1) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_mask: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Get modulation parameters for both streams + img_mod_params = self.img_mod(temb) # [B, 6*dim] + txt_mod_params = self.txt_mod(temb) # [B, 6*dim] + + # Split modulation parameters for norm1 and norm2 + img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + + # Process image stream - norm1 + modulation + img_normed = self.img_norm1(hidden_states) + img_modulated, img_gate1 = 
self._modulate(img_normed, img_mod1) + + # Process text stream - norm1 + modulation + txt_normed = self.txt_norm1(encoder_hidden_states) + txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) + + # Use QwenAttnProcessor2_0 for joint attention computation + # This directly implements the DoubleStreamLayerMegatron logic: + # 1. Computes QKV for both streams + # 2. Applies QK normalization and RoPE + # 3. Concatenates and runs joint attention + # 4. Splits results back to separate streams + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=img_modulated, # Image stream (will be processed as "sample") + encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context") + encoder_hidden_states_mask=encoder_hidden_states_mask, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided + img_attn_output, txt_attn_output = attn_output + + # Apply attention gates and add residual (like in Megatron) + hidden_states = hidden_states + img_gate1 * img_attn_output + encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output + + # Process image stream - norm2 + MLP + img_normed2 = self.img_norm2(hidden_states) + img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) + img_mlp_output = self.img_mlp(img_modulated2) + hidden_states = hidden_states + img_gate2 * img_mlp_output + + # Process text stream - norm2 + MLP + txt_normed2 = self.txt_norm2(encoder_hidden_states) + txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) + txt_mlp_output = self.txt_mlp(txt_modulated2) + encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output + + # Clip to prevent overflow for fp16 + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class QwenImageTransformer2DModel(nn.Module): + """ + The Transformer model introduced in Qwen. + + Args: + patch_size (`int`, defaults to `2`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `64`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `60`): + The number of layers of dual stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `3584`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + guidance_embeds (`bool`, defaults to `False`): + Whether to use guidance embeddings for guidance-distilled variant of the model. + axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions to use for the rotary positional embeddings. 
+ """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["QwenImageTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + + + def preprocess_loras(self, model_type, sd): + + first = next(iter(sd), None) + if first == None: + return sd + + new_sd = {} + for k,v in sd.items(): + k = k.replace(".lora.", ".lora_") + new_sd[k] = v + sd = new_sd + + if first.startswith("transformer_blocks"): + new_sd = {} + for k,v in sd.items(): + if k.startswith("transformer_blocks"): + k = "diffusion_model." + k + new_sd[k] = v + sd = new_sd + return sd + else: + return sd + + def __init__( + self, + patch_size: int = 2, + in_channels: int = 64, + out_channels: Optional[int] = 16, + num_layers: int = 60, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 3584, + guidance_embeds: bool = False, # TODO: this should probably be removed + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.in_channels = in_channels + self.guidance_embeds = guidance_embeds + self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) + + self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim) + + self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6) + + self.img_in = nn.Linear(in_channels, self.inner_dim) + self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + QwenImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states_list = None, + encoder_hidden_states_mask_list = None, + timestep: torch.LongTensor = None, + img_shapes: Optional[List[Tuple[int, int, int]]] = None, + txt_seq_lens_list = None, + guidance: torch.Tensor = None, # TODO: this should probably be removed + attention_kwargs: Optional[Dict[str, Any]] = None, + callback= None, + pipeline =None, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + + + hidden_states = self.img_in(hidden_states) + timestep = timestep.to(hidden_states.dtype) + hidden_states_list = [hidden_states if i == 0 else hidden_states.clone() for i, _ in enumerate(encoder_hidden_states_list)] + + new_encoder_hidden_states_list = [] + for encoder_hidden_states in encoder_hidden_states_list: + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + new_encoder_hidden_states_list.append(encoder_hidden_states) + encoder_hidden_states_list = new_encoder_hidden_states_list + new_encoder_hidden_states_list = encoder_hidden_states = None + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + + image_rotary_emb_list = [ self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) for txt_seq_lens in txt_seq_lens_list] + + hidden_states = None + + for index_block, block in 
enumerate(self.transformer_blocks): + if callback != None: + callback(-1, None, False, True) + if pipeline._interrupt: + return [None] * len(hidden_states_list) + for hidden_states, encoder_hidden_states, encoder_hidden_states_mask, image_rotary_emb in zip(hidden_states_list, encoder_hidden_states_list, encoder_hidden_states_mask_list, image_rotary_emb_list): + encoder_hidden_states[...], hidden_states[...] = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + ) + + # Use only the image part (hidden_states) from the dual-stream blocks + output_list = [] + for i in range(len(hidden_states_list)): + hidden_states = self.norm_out(hidden_states_list[i], temb) + hidden_states_list[i] = None + output_list.append(self.proj_out(hidden_states)) + + return output_list diff --git a/wan/__init__.py b/models/wan/__init__.py similarity index 50% rename from wan/__init__.py rename to models/wan/__init__.py index 1688425..fe3be71 100644 --- a/wan/__init__.py +++ b/models/wan/__init__.py @@ -1,3 +1,4 @@ from . import configs, distributed, modules from .any2video import WanAny2V -from .diffusion_forcing import DTT2V \ No newline at end of file +from .diffusion_forcing import DTT2V +from . import wan_handler, df_handler diff --git a/wan/any2video.py b/models/wan/any2video.py similarity index 91% rename from wan/any2video.py rename to models/wan/any2video.py index bea70c5..c879701 100644 --- a/wan/any2video.py +++ b/models/wan/any2video.py @@ -22,14 +22,16 @@ from .distributed.fsdp import shard_model from .modules.model import WanModel from .modules.t5 import T5EncoderModel from .modules.vae import WanVAE +from .modules.vae2_2 import Wan2_2_VAE + from .modules.clip import CLIPModel -from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, +from shared.utils.fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler -from wan.modules.posemb_layers import get_rotary_pos_embed -from .utils.vace_preprocessor import VaceVideoProcessor -from wan.utils.basic_flowmatch import FlowMatchScheduler -from wan.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions +from shared.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from .modules.posemb_layers import get_rotary_pos_embed +from shared.utils.vace_preprocessor import VaceVideoProcessor +from shared.utils.basic_flowmatch import FlowMatchScheduler +from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance, match_and_blend_colors, match_and_blend_colors_with_mask from mmgp import safetensors2 @@ -85,24 +87,33 @@ class WanAny2V: dtype=config.t5_dtype, device=torch.device('cpu'), checkpoint_path=text_encoder_filename, - tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + tokenizer_path=os.path.join(checkpoint_dir, "umt5-xxl"), shard_fn= None) # base_model_type = "i2v2_2" - if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2"]: + if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"]: self.clip = CLIPModel( dtype=config.clip_dtype, device=self.device, checkpoint_path=os.path.join(checkpoint_dir , config.clip_checkpoint), - 
tokenizer_path=os.path.join(checkpoint_dir , config.clip_tokenizer)) + tokenizer_path=os.path.join(checkpoint_dir , "xlm-roberta-large")) - self.vae_stride = config.vae_stride + + if base_model_type in ["ti2v_2_2"]: + self.vae_stride = (4, 16, 16) + vae_checkpoint = "Wan2.2_VAE.safetensors" + vae = Wan2_2_VAE + else: + self.vae_stride = config.vae_stride + vae_checkpoint = "Wan2.1_VAE.safetensors" + vae = WanVAE self.patch_size = config.patch_size - self.vae = WanVAE( - vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype, - device=self.device) + self.vae = vae( + vae_pth=os.path.join(checkpoint_dir, vae_checkpoint), dtype= VAE_dtype, + device="cpu") + self.vae.device = self.device # config_filename= "configs/t2v_1.3B.json" # import json @@ -115,7 +126,11 @@ class WanAny2V: # forcedConfigPath = base_config_file = f"configs/flf2v_720p.json" # model_filename[1] = xmodel_filename - if self.transformer_switch: + source = model_def.get("source", None) + + if source is not None: + self.model = offload.fast_load_transformers_model(source, modelClass=WanModel, writable_tensors= False, forcedConfigPath= base_config_file) + elif self.transformer_switch: shared_modules= {} self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, return_shared_modules= shared_modules) self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) @@ -138,6 +153,10 @@ class WanAny2V: self.model.eval().requires_grad_(False) if self.model2 is not None: self.model2.eval().requires_grad_(False) + if not source is None: + from wgp import save_model + save_model(self.model, model_type, dtype, None) + if save_quantized: from wgp import save_quantized_model save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file) @@ -233,7 +252,7 @@ class WanAny2V: return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None, return_mask = False): - from wan.utils.utils import save_image + from shared.utils.utils import save_image ref_width, ref_height = ref_img.size if (ref_height, ref_width) == image_size and outpainting_dims == None: ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) @@ -407,7 +426,7 @@ class WanAny2V: color_correction_strength = 1, prefix_frames_count = 0, image_mode = 0, - + window_no = 0, **bbargs ): @@ -441,6 +460,7 @@ class WanAny2V: sigmas=sampling_sigmas) else: raise NotImplementedError(f"Unsupported Scheduler {sample_solver}") + original_timesteps = timesteps seed_g = torch.Generator(device=self.device) seed_g.manual_seed(seed) @@ -449,7 +469,6 @@ class WanAny2V: color_reference_frame = None if self._interrupt: return None - # Text Encoder if n_prompt == "": n_prompt = self.sample_neg_prompt @@ -476,16 +495,17 @@ class WanAny2V: vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B"] phantom = model_type in ["phantom_1.3B", "phantom_14B"] fantasy = model_type in ["fantasy"] - multitalk = model_type in ["multitalk", "vace_multitalk_14B"] + multitalk = model_type in ["multitalk", "vace_multitalk_14B", 
"i2v_2_2_multitalk"] recam = model_type in ["recam_1.3B"] - + ti2v = model_type in ["ti2v_2_2"] + start_step_no = 0 ref_images_count = 0 trim_frames = 0 extended_overlapped_latents = None - + timestep_injection = False lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1 # image2video - if model_type in ["i2v", "i2v_2_2", "fantasy", "multitalk", "flf2v_720p"]: + if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "i2v_2_2_multitalk", "flf2v_720p"]: any_end_frame = False if image_start is None: _ , preframes_count, height, width = input_video.shape @@ -583,13 +603,12 @@ class WanAny2V: if recam: # should be be in fact in input_frames since it is control video not a video to be extended target_camera = model_mode - width = input_video.shape[2] - height = input_video.shape[1] + height,width = input_video.shape[-2:] input_video = input_video.to(dtype=self.dtype , device=self.device) - source_latents = self.vae.encode([input_video])[0] #.to(dtype=self.dtype, device=self.device) + source_latents = self.vae.encode([input_video])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device) del input_video # Process target camera (recammaster) - from wan.utils.cammmaster_tools import get_camera_embedding + from shared.utils.cammmaster_tools import get_camera_embedding cam_emb = get_camera_embedding(target_camera) cam_emb = cam_emb.to(dtype=self.dtype, device=self.device) kwargs['cam_emb'] = cam_emb @@ -597,7 +616,7 @@ class WanAny2V: # Video 2 Video if denoising_strength < 1. and input_frames != None: height, width = input_frames.shape[-2:] - source_latents = self.vae.encode([input_frames])[0] + source_latents = self.vae.encode([input_frames])[0].unsqueeze(0) injection_denoising_step = 0 inject_from_start = False if input_frames != None and denoising_strength < 1 : @@ -610,7 +629,7 @@ class WanAny2V: if len(keep_frames_parsed) == 0 or image_outputs or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = [] injection_denoising_step = int(sampling_steps * (1. 
- denoising_strength) ) latent_keep_frames = [] - if source_latents.shape[1] < lat_frames or len(keep_frames_parsed) > 0: + if source_latents.shape[2] < lat_frames or len(keep_frames_parsed) > 0: inject_from_start = True if len(keep_frames_parsed) >0 : if overlapped_frames_num > 0: keep_frames_parsed = [True] * overlapped_frames_num + keep_frames_parsed @@ -619,6 +638,7 @@ class WanAny2V: latent_keep_frames.append(all(keep_frames_parsed[i:i+4])) else: timesteps = timesteps[injection_denoising_step:] + start_step_no = injection_denoising_step if hasattr(sample_scheduler, "timesteps"): sample_scheduler.timesteps = timesteps if hasattr(sample_scheduler, "sigmas"): sample_scheduler.sigmas= sample_scheduler.sigmas[injection_denoising_step:] injection_denoising_step = 0 @@ -632,6 +652,14 @@ class WanAny2V: ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0 trim_frames = input_ref_images.shape[1] + if ti2v: + if input_video is None: + height, width = (height // 32) * 32, (width // 32) * 32 + else: + height, width = input_video.shape[-2:] + source_latents = self.vae.encode([input_video], tile_size = VAE_tile_size)[0].unsqueeze(0) + timestep_injection = True + # Vace if vace : # vace context encode @@ -671,7 +699,7 @@ class WanAny2V: target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2]) if multitalk and audio_proj != None: - from wan.multitalk.multitalk import get_target_masks + from .multitalk.multitalk import get_target_masks audio_proj = [audio.to(self.dtype) for audio in audio_proj] human_no = len(audio_proj[0]) token_ref_target_masks = get_target_masks(human_no, lat_h, lat_w, height, width, face_scale = 0.05, bbox = speakers_bboxes).to(self.dtype) if human_no > 1 else None @@ -695,16 +723,17 @@ class WanAny2V: kwargs["freqs"] = freqs # Steps Skipping - cache_type = self.model.enable_cache - if cache_type != None: + skip_steps_cache = self.model.cache + if skip_steps_cache != None: + cache_type = skip_steps_cache.cache_type x_count = 3 if phantom or fantasy or multitalk else 2 - self.model.previous_residual = [None] * x_count + skip_steps_cache.previous_residual = [None] * x_count if cache_type == "tea": - self.model.compute_teacache_threshold(self.model.cache_start_step, timesteps, self.model.cache_multiplier) + self.model.compute_teacache_threshold(max(skip_steps_cache.start_step, start_step_no), original_timesteps, skip_steps_cache.multiplier) else: - self.model.compute_magcache_threshold(self.model.cache_start_step, timesteps, self.model.cache_multiplier) - self.model.accumulated_err, self.model.accumulated_steps, self.model.accumulated_ratio = [0.0] * x_count, [0] * x_count, [1.0] * x_count - self.model.one_for_all = x_count > 2 + self.model.compute_magcache_threshold(max(skip_steps_cache.start_step, start_step_no), original_timesteps, skip_steps_cache.multiplier) + skip_steps_cache.accumulated_err, skip_steps_cache.accumulated_steps, skip_steps_cache.accumulated_ratio = [0.0] * x_count, [0] * x_count, [1.0] * x_count + skip_steps_cache.one_for_all = x_count > 2 if callback != None: callback(-1, None, True) @@ -717,13 +746,14 @@ class WanAny2V: # init denoising updated_num_steps= len(timesteps) if callback != None: - from wan.utils.loras_mutipliers import update_loras_slists + from shared.utils.loras_mutipliers import update_loras_slists model_switch_step = updated_num_steps for i, t in enumerate(timesteps): if t <= switch_threshold: model_switch_step = i break 
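            # Illustrative annotation (not part of the original diff): model_switch_step is the index of
            # the first denoising step whose timestep falls at or below switch_threshold. For example,
            # with timesteps = [999, 800, 600, 400, 200] and switch_threshold = 500, the first match is
            # 400, so model_switch_step == 3. It is passed to update_loras_slists below so that LoRA
            # multipliers can be scheduled differently before and after the switch (and likewise for
            # model2 when a second transformer is loaded).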
update_loras_slists(self.model, loras_slists, updated_num_steps, model_switch_step= model_switch_step) + if self.model2 is not None: update_loras_slists(self.model2, loras_slists, updated_num_steps, model_switch_step= model_switch_step) callback(-1, None, True, override_num_inference_steps = updated_num_steps) if sample_scheduler != None: @@ -748,7 +778,13 @@ class WanAny2V: offload.set_step_no_for_lora(trans, i) timestep = torch.stack([t]) - kwargs.update({"t": timestep, "current_step": i}) + + if timestep_injection: + latents[:, :, :source_latents.shape[2]] = source_latents + timestep = torch.full((target_shape[-3],), t, dtype=torch.int64, device=latents.device) + timestep[:source_latents.shape[2]] = 0 + + kwargs.update({"t": timestep, "current_step": start_step_no + i}) kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step: @@ -756,14 +792,14 @@ class WanAny2V: noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) if inject_from_start: new_latents = latents.clone() - new_latents[:,:, :source_latents.shape[1] ] = noise[:, :, :source_latents.shape[1] ] * sigma + (1 - sigma) * source_latents.unsqueeze(0) + new_latents[:,:, :source_latents.shape[2] ] = noise[:, :, :source_latents.shape[2] ] * sigma + (1 - sigma) * source_latents for latent_no, keep_latent in enumerate(latent_keep_frames): if not keep_latent: new_latents[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1] latents = new_latents new_latents = None else: - latents = noise * sigma + (1 - sigma) * source_latents.unsqueeze(0) + latents = noise * sigma + (1 - sigma) * source_latents noise = None if extended_overlapped_latents != None: @@ -775,7 +811,7 @@ class WanAny2V: zz[0:16, ref_images_count:extended_overlapped_latents.shape[2] ] = extended_overlapped_latents[0, :, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[0, :, ref_images_count:] ) * overlap_noise_factor if target_camera != None: - latent_model_input = torch.cat([latents, source_latents.unsqueeze(0).expand(*expand_shape)], dim=2) # !!!! 
+ latent_model_input = torch.cat([latents, source_latents.expand(*expand_shape)], dim=2) else: latent_model_input = latents @@ -887,6 +923,9 @@ class WanAny2V: callback(i, latents_preview[0], False) latents_preview = None + if timestep_injection: + latents[:, :, :source_latents.shape[2]] = source_latents + if vace and ref_images_count > 0: latents = latents[:, :, ref_images_count:] if trim_frames > 0: latents= latents[:, :,:-trim_frames] if return_latent_slice != None: @@ -903,7 +942,7 @@ class WanAny2V: videos = torch.cat([video[:,:1] for video in videos], dim=1) if len(videos) > 1 else videos[0][:,:1] else: videos = videos[0] # return only first video - if color_correction_strength > 0 and prefix_frames_count > 0: + if color_correction_strength > 0 and (prefix_frames_count > 0 and window_no > 1 or prefix_frames_count > 1 and window_no == 1): if vace and False: # videos = match_and_blend_colors_with_mask(videos.unsqueeze(0), input_frames[0].unsqueeze(0), input_masks[0][:1].unsqueeze(0), color_correction_strength,copy_mode= "progressive_blend").squeeze(0) videos = match_and_blend_colors_with_mask(videos.unsqueeze(0), input_frames[0].unsqueeze(0), input_masks[0][:1].unsqueeze(0), color_correction_strength,copy_mode= "reference").squeeze(0) @@ -923,8 +962,5 @@ class WanAny2V: setattr(target, "vace", module ) delattr(model, "vace_blocks") -def query_model_def(model_type, model_def): - if "URLs2" in model_def: - return { "no_steps_skipping":True} - else: - return None \ No newline at end of file + + diff --git a/wan/camera_extrinsics.json b/models/wan/camera_extrinsics.json similarity index 100% rename from wan/camera_extrinsics.json rename to models/wan/camera_extrinsics.json diff --git a/wan/configs/__init__.py b/models/wan/configs/__init__.py similarity index 100% rename from wan/configs/__init__.py rename to models/wan/configs/__init__.py diff --git a/wan/configs/shared_config.py b/models/wan/configs/shared_config.py similarity index 100% rename from wan/configs/shared_config.py rename to models/wan/configs/shared_config.py diff --git a/wan/configs/wan_i2v_14B.py b/models/wan/configs/wan_i2v_14B.py similarity index 95% rename from wan/configs/wan_i2v_14B.py rename to models/wan/configs/wan_i2v_14B.py index 7812c92..623a51f 100644 --- a/wan/configs/wan_i2v_14B.py +++ b/models/wan/configs/wan_i2v_14B.py @@ -10,7 +10,7 @@ i2v_14B = EasyDict(__name__='Config: Wan I2V 14B') i2v_14B.update(wan_shared_cfg) i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' -i2v_14B.t5_tokenizer = 'google/umt5-xxl' +i2v_14B.t5_tokenizer = 'umt5-xxl' # clip i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14' diff --git a/wan/configs/wan_t2v_14B.py b/models/wan/configs/wan_t2v_14B.py similarity index 94% rename from wan/configs/wan_t2v_14B.py rename to models/wan/configs/wan_t2v_14B.py index 9d0ee69..f422d1f 100644 --- a/wan/configs/wan_t2v_14B.py +++ b/models/wan/configs/wan_t2v_14B.py @@ -10,7 +10,7 @@ t2v_14B.update(wan_shared_cfg) # t5 t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' -t2v_14B.t5_tokenizer = 'google/umt5-xxl' +t2v_14B.t5_tokenizer = 'umt5-xxl' # vae t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' diff --git a/wan/configs/wan_t2v_1_3B.py b/models/wan/configs/wan_t2v_1_3B.py similarity index 94% rename from wan/configs/wan_t2v_1_3B.py rename to models/wan/configs/wan_t2v_1_3B.py index ea9502b..ac23bff 100644 --- a/wan/configs/wan_t2v_1_3B.py +++ b/models/wan/configs/wan_t2v_1_3B.py @@ -10,7 +10,7 @@ t2v_1_3B.update(wan_shared_cfg) # t5 t2v_1_3B.t5_checkpoint = 
'models_t5_umt5-xxl-enc-bf16.pth' -t2v_1_3B.t5_tokenizer = 'google/umt5-xxl' +t2v_1_3B.t5_tokenizer = 'umt5-xxl' # vae t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth' diff --git a/models/wan/df_handler.py b/models/wan/df_handler.py new file mode 100644 index 0000000..bf7c266 --- /dev/null +++ b/models/wan/df_handler.py @@ -0,0 +1,97 @@ +import torch + +class family_handler(): + + @staticmethod + def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_steps_cache): + if base_model_type == "sky_df_1.3B": + coefficients= [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01] + else: + coefficients= [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] + + skip_steps_cache.coefficients = coefficients + + @staticmethod + def query_model_def(base_model_type, model_def): + extra_model_def = {} + if base_model_type in ["sky_df_14B"]: + fps = 24 + else: + fps = 16 + extra_model_def["fps"] =fps + extra_model_def["frames_minimum"] = 17 + extra_model_def["frames_steps"] = 20 + extra_model_def["sliding_window"] = True + extra_model_def["skip_layer_guidance"] = True + extra_model_def["tea_cache"] = True + return extra_model_def + + @staticmethod + def query_supported_types(): + return ["sky_df_1.3B", "sky_df_14B"] + + + @staticmethod + def query_family_maps(): + models_eqv_map = { + "sky_df_1.3B" : "sky_df_14B", + } + + models_comp_map = { + "sky_df_14B": ["sky_df_1.3B"], + } + return models_eqv_map, models_comp_map + + + + @staticmethod + def query_model_family(): + return "wan" + + @staticmethod + def query_family_infos(): + return {} + + + + @staticmethod + def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization): + from .wan_handler import family_handler + return family_handler.query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization) + + @staticmethod + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): + from .configs import WAN_CONFIGS + from .wan_handler import family_handler + cfg = WAN_CONFIGS['t2v-14B'] + from . 
import DTT2V + wan_model = DTT2V( + config=cfg, + checkpoint_dir="ckpts", + model_filename=model_filename, + model_type = model_type, + model_def = model_def, + base_model_type=base_model_type, + text_encoder_filename= family_handler.get_wan_text_encoder_filename(text_encoder_quantization), + quantizeTransformer = quantizeTransformer, + dtype = dtype, + VAE_dtype = VAE_dtype, + mixed_precision_transformer = mixed_precision_transformer, + save_quantized = save_quantized + ) + + pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model } + return wan_model, pipe + + @staticmethod + def update_default_settings(base_model_type, model_def, ui_defaults): + ui_defaults.update({ + "guidance_scale": 6.0, + "flow_shift": 8, + "sliding_window_discard_last_frames" : 0, + "resolution": "1280x720" if "720" in base_model_type else "960x544", + "sliding_window_size" : 121 if "720" in base_model_type else 97, + "RIFLEx_setting": 2, + "guidance_scale": 6, + "flow_shift": 8, + }) \ No newline at end of file diff --git a/models/wan/diffusion_forcing copy.py b/models/wan/diffusion_forcing copy.py new file mode 100644 index 0000000..753fd45 --- /dev/null +++ b/models/wan/diffusion_forcing copy.py @@ -0,0 +1,479 @@ +import math +import os +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union +import logging +import numpy as np +import torch +from diffusers.image_processor import PipelineImageInput +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from tqdm import tqdm +from .modules.model import WanModel +from .modules.t5 import T5EncoderModel +from .modules.vae import WanVAE +from wan.modules.posemb_layers import get_rotary_pos_embed +from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, retrieve_timesteps) +from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +class DTT2V: + + + def __init__( + self, + config, + checkpoint_dir, + rank=0, + model_filename = None, + text_encoder_filename = None, + quantizeTransformer = False, + dtype = torch.bfloat16, + ): + self.device = torch.device(f"cuda") + self.config = config + self.rank = rank + self.dtype = dtype + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu'), + checkpoint_path=text_encoder_filename, + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + shard_fn= None) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + + + self.vae = WanVAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device) + + logging.info(f"Creating WanModel from {model_filename}") + from mmgp import offload + + self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False, forcedConfigPath="config.json") + # offload.load_model_data(self.model, "recam.ckpt") + # self.model.cpu() + # offload.save_model(self.model, "recam.safetensors") + if self.dtype == torch.float16 and not "fp16" in model_filename: + self.model.to(self.dtype) + # offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True) + if self.dtype == torch.float16: + self.vae.model.to(self.dtype) + self.model.eval().requires_grad_(False) + + self.scheduler = FlowUniPCMultistepScheduler() + + 
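The `do_classifier_free_guidance` property defined just below only checks `self._guidance_scale > 1`; when it is False the pipeline skips the negative-prompt encoding and the second transformer pass in each denoising step. The denoising loop further down combines the two predictions with the standard classifier-free guidance update. A minimal, self-contained sketch of that combination (the `apply_cfg` helper is illustrative only and not part of this file):

    import torch

    def apply_cfg(noise_pred_cond: torch.Tensor,
                  noise_pred_uncond: torch.Tensor,
                  guidance_scale: float) -> torch.Tensor:
        # Classifier-free guidance: start from the unconditional prediction and
        # extrapolate toward the conditional one. A scale of 1.0 reduces to the
        # conditional prediction, which is why scale <= 1 disables guidance here.
        return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

With the default guidance_scale of 5.0 used by generate(), the difference between the conditional and unconditional predictions is scaled five-fold, at the cost of roughly one extra forward pass per step (or a joint dual pass when joint_pass is set).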
@property + def do_classifier_free_guidance(self) -> bool: + return self._guidance_scale > 1 + + def encode_image( + self, image: PipelineImageInput, height: int, width: int, num_frames: int, tile_size = 0, causal_block_size = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # prefix_video + prefix_video = np.array(image.resize((width, height))).transpose(2, 0, 1) + prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1) + if prefix_video.dtype == torch.uint8: + prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0 + prefix_video = prefix_video.to(self.device) + prefix_video = [self.vae.encode(prefix_video.unsqueeze(0), tile_size = tile_size)[0]] # [(c, f, h, w)] + if prefix_video[0].shape[1] % causal_block_size != 0: + truncate_len = prefix_video[0].shape[1] % causal_block_size + print("the length of prefix video is truncated for the casual block size alignment.") + prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] + predix_video_latent_length = prefix_video[0].shape[1] + return prefix_video, predix_video_latent_length + + def prepare_latents( + self, + shape: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + ) -> torch.Tensor: + return randn_tensor(shape, generator, device=device, dtype=dtype) + + def generate_timestep_matrix( + self, + num_frames, + step_template, + base_num_frames, + ar_step=5, + num_pre_ready=0, + casual_block_size=1, + shrink_interval_with_mask=False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: + step_matrix, step_index = [], [] + update_mask, valid_interval = [], [] + num_iterations = len(step_template) + 1 + num_frames_block = num_frames // casual_block_size + base_num_frames_block = base_num_frames // casual_block_size + if base_num_frames_block < num_frames_block: + infer_step_num = len(step_template) + gen_block = base_num_frames_block + min_ar_step = infer_step_num / gen_block + assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting" + # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block) + step_template = torch.cat( + [ + torch.tensor([999], dtype=torch.int64, device=step_template.device), + step_template.long(), + torch.tensor([0], dtype=torch.int64, device=step_template.device), + ] + ) # to handle the counter in row works starting from 1 + pre_row = torch.zeros(num_frames_block, dtype=torch.long) + if num_pre_ready > 0: + pre_row[: num_pre_ready // casual_block_size] = num_iterations + + while torch.all(pre_row >= (num_iterations - 1)) == False: + new_row = torch.zeros(num_frames_block, dtype=torch.long) + for i in range(num_frames_block): + if i == 0 or pre_row[i - 1] >= ( + num_iterations - 1 + ): # the first frame or the last frame is completely denoised + new_row[i] = pre_row[i] + 1 + else: + new_row[i] = new_row[i - 1] - ar_step + new_row = new_row.clamp(0, num_iterations) + + update_mask.append( + (new_row != pre_row) & (new_row != num_iterations) + ) # False: no need to update, True: need to update + step_index.append(new_row) + step_matrix.append(step_template[new_row]) + pre_row = new_row + + # for long video we split into several sequences, base_num_frames is set to the model max length (for training) + terminal_flag = base_num_frames_block + if shrink_interval_with_mask: + idx_sequence = 
torch.arange(num_frames_block, dtype=torch.int64) + update_mask = update_mask[0] + update_mask_idx = idx_sequence[update_mask] + last_update_idx = update_mask_idx[-1].item() + terminal_flag = last_update_idx + 1 + # for i in range(0, len(update_mask)): + for curr_mask in update_mask: + if terminal_flag < num_frames_block and curr_mask[terminal_flag]: + terminal_flag += 1 + valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag)) + + step_update_mask = torch.stack(update_mask, dim=0) + step_index = torch.stack(step_index, dim=0) + step_matrix = torch.stack(step_matrix, dim=0) + + if casual_block_size > 1: + step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval] + + return step_matrix, step_index, step_update_mask, valid_interval + + @torch.no_grad() + def generate( + self, + prompt: Union[str, List[str]], + negative_prompt: Union[str, List[str]] = "", + image: PipelineImageInput = None, + height: int = 480, + width: int = 832, + num_frames: int = 97, + num_inference_steps: int = 50, + shift: float = 1.0, + guidance_scale: float = 5.0, + seed: float = 0.0, + overlap_history: int = 17, + addnoise_condition: int = 0, + base_num_frames: int = 97, + ar_step: int = 5, + causal_block_size: int = 1, + causal_attention: bool = False, + fps: int = 24, + VAE_tile_size = 0, + joint_pass = False, + callback = None, + ): + generator = torch.Generator(device=self.device) + generator.manual_seed(seed) + # if base_num_frames > base_num_frames: + # causal_block_size = 0 + self._guidance_scale = guidance_scale + + i2v_extra_kwrags = {} + prefix_video = None + predix_video_latent_length = 0 + if image: + frame_width, frame_height = image.size + scale = min(height / frame_height, width / frame_width) + height = (int(frame_height * scale) // 16) * 16 + width = (int(frame_width * scale) // 16) * 16 + + prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames, tile_size=VAE_tile_size, causal_block_size=causal_block_size) + + latent_length = (num_frames - 1) // 4 + 1 + latent_height = height // 8 + latent_width = width // 8 + + prompt_embeds = self.text_encoder([prompt], self.device) + prompt_embeds = [u.to(self.dtype).to(self.device) for u in prompt_embeds] + if self.do_classifier_free_guidance: + negative_prompt_embeds = self.text_encoder([negative_prompt], self.device) + negative_prompt_embeds = [u.to(self.dtype).to(self.device) for u in negative_prompt_embeds] + + + + self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift) + init_timesteps = self.scheduler.timesteps + fps_embeds = [fps] * prompt_embeds[0].shape[0] + fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] + transformer_dtype = self.dtype + # with torch.cuda.amp.autocast(dtype=self.dtype), torch.no_grad(): + if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: + # short video generation + latent_shape = [16, latent_length, latent_height, latent_width] + latents = self.prepare_latents( + latent_shape, dtype=torch.float32, device=self.device, generator=generator + ) + latents = [latents] + if prefix_video is not None: + latents[0][:, :predix_video_latent_length] = 
prefix_video[0].to(torch.float32) + base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + latent_length, init_timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size + ) + sample_schedulers = [] + for _ in range(latent_length): + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, shift=1, use_dynamic_shifting=False + ) + sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift) + sample_schedulers.append(sample_scheduler) + sample_schedulers_counter = [0] * latent_length + + if callback != None: + callback(-1, None, True) + + freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False) + for i, timestep_i in enumerate(tqdm(step_matrix)): + update_mask_i = step_update_mask[i] + valid_interval_i = valid_interval[i] + valid_interval_start, valid_interval_end = valid_interval_i + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] + if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + timestep_for_noised_condition = addnoise_condition + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor) + + torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length]) + * noise_factor + ) + timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + kwrags = { + "x" : torch.stack([latent_model_input[0]]), + "t" : timestep, + "freqs" :freqs, + "fps" : fps_embeds, + # "causal_block_size" : causal_block_size, + "callback" : callback, + "pipeline" : self + } + kwrags.update(i2v_extra_kwrags) + + + if not self.do_classifier_free_guidance: + noise_pred = self.model( + context=prompt_embeds, + **kwrags, + )[0] + if self._interrupt: + return None + noise_pred= noise_pred.to(torch.float32) + else: + if joint_pass: + noise_pred_cond, noise_pred_uncond = self.model( + context=prompt_embeds, + context2=negative_prompt_embeds, + **kwrags, + ) + if self._interrupt: + return None + else: + noise_pred_cond = self.model( + context=prompt_embeds, + **kwrags, + )[0] + if self._interrupt: + return None + noise_pred_uncond = self.model( + context=negative_prompt_embeds, + **kwrags, + )[0] + if self._interrupt: + return None + noise_pred_cond= noise_pred_cond.to(torch.float32) + noise_pred_uncond= noise_pred_uncond.to(torch.float32) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + del noise_pred_cond, noise_pred_uncond + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[0][:, idx] = sample_schedulers[idx].step( + noise_pred[:, idx - valid_interval_start], + timestep_i[idx], + latents[0][:, idx], + return_dict=False, + generator=generator, + )[0] + sample_schedulers_counter[idx] += 1 + if callback is not None: + callback(i, latents[0], False) + + x0 = latents[0].unsqueeze(0) + videos = self.vae.decode(x0, tile_size= VAE_tile_size) + videos = (videos / 2 + 0.5).clamp(0, 1) + videos = [video for video in videos] + videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] + videos = [video.cpu().numpy().astype(np.uint8) for video in videos] + 
return videos + else: + # long video generation + base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length + overlap_history_frames = (overlap_history - 1) // 4 + 1 + n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 + print(f"n_iter:{n_iter}") + output_video = None + for i in range(n_iter): + if output_video is not None: # i !=0 + prefix_video = output_video[:, -overlap_history:].to(self.device) + prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] + if prefix_video[0].shape[1] % causal_block_size != 0: + truncate_len = prefix_video[0].shape[1] % causal_block_size + print("the length of prefix video is truncated for the casual block size alignment.") + prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] + predix_video_latent_length = prefix_video[0].shape[1] + finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames + left_frame_num = latent_length - finished_frame_num + base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) + else: # i == 0 + base_num_frames_iter = base_num_frames + latent_shape = [16, base_num_frames_iter, latent_height, latent_width] + latents = self.prepare_latents( + latent_shape, dtype=torch.float32, device=self.device, generator=generator + ) + latents = [latents] + if prefix_video is not None: + latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32) + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + base_num_frames_iter, + init_timesteps, + base_num_frames_iter, + ar_step, + predix_video_latent_length, + causal_block_size, + ) + sample_schedulers = [] + for _ in range(base_num_frames_iter): + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, shift=1, use_dynamic_shifting=False + ) + sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift) + sample_schedulers.append(sample_scheduler) + sample_schedulers_counter = [0] * base_num_frames_iter + if callback != None: + callback(-1, None, True) + + freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False) + for i, timestep_i in enumerate(tqdm(step_matrix)): + update_mask_i = step_update_mask[i] + valid_interval_i = valid_interval[i] + valid_interval_start, valid_interval_end = valid_interval_i + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] + if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + timestep_for_noised_condition = addnoise_condition + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] + * (1.0 - noise_factor) + + torch.randn_like( + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] + ) + * noise_factor + ) + timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + kwrags = { + "x" : torch.stack([latent_model_input[0]]), + "t" : timestep, + "freqs" :freqs, + "fps" : fps_embeds, + "causal_block_size" : causal_block_size, + "causal_attention" : causal_attention, + "callback" : callback, + "pipeline" : self + } + kwrags.update(i2v_extra_kwrags) + + if not self.do_classifier_free_guidance: + noise_pred = self.model( + 
context=prompt_embeds, + **kwrags, + )[0] + if self._interrupt: + return None + noise_pred= noise_pred.to(torch.float32) + else: + if joint_pass: + noise_pred_cond, noise_pred_uncond = self.model( + context=prompt_embeds, + context2=negative_prompt_embeds, + **kwrags, + ) + if self._interrupt: + return None + else: + noise_pred_cond = self.model( + context=prompt_embeds, + **kwrags, + )[0] + if self._interrupt: + return None + noise_pred_uncond = self.model( + context=negative_prompt_embeds, + )[0] + if self._interrupt: + return None + noise_pred_cond= noise_pred_cond.to(torch.float32) + noise_pred_uncond= noise_pred_uncond.to(torch.float32) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + del noise_pred_cond, noise_pred_uncond + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[0][:, idx] = sample_schedulers[idx].step( + noise_pred[:, idx - valid_interval_start], + timestep_i[idx], + latents[0][:, idx], + return_dict=False, + generator=generator, + )[0] + sample_schedulers_counter[idx] += 1 + if callback is not None: + callback(i, latents[0].squeeze(0), False) + + x0 = latents[0].unsqueeze(0) + videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]] + if output_video is None: + output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w + else: + output_video = torch.cat( + [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1 + ) # c, f, h, w + return output_video diff --git a/wan/diffusion_forcing.py b/models/wan/diffusion_forcing.py similarity index 92% rename from wan/diffusion_forcing.py rename to models/wan/diffusion_forcing.py index 6960bda..5d5e42f 100644 --- a/wan/diffusion_forcing.py +++ b/models/wan/diffusion_forcing.py @@ -14,12 +14,12 @@ from tqdm import tqdm from .modules.model import WanModel from .modules.t5 import T5EncoderModel from .modules.vae import WanVAE -from wan.modules.posemb_layers import get_rotary_pos_embed -from wan.utils.utils import calculate_new_dimensions -from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, +from .modules.posemb_layers import get_rotary_pos_embed +from shared.utils.utils import calculate_new_dimensions +from shared.utils.fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler -from wan.utils.loras_mutipliers import update_loras_slists +from shared.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from shared.utils.loras_mutipliers import update_loras_slists class DTT2V: @@ -313,21 +313,24 @@ class DTT2V: if callback != None: update_loras_slists(self.model, loras_slists, updated_num_steps) callback(-1, None, True, override_num_inference_steps = updated_num_steps) - if self.model.enable_cache == "tea": - x_count = 2 if self.do_classifier_free_guidance else 1 - self.model.previous_residual = [None] * x_count - time_steps_comb = [] - self.model.num_steps = updated_num_steps - for i, timestep_i in enumerate(step_matrix): - valid_interval_start, valid_interval_end = valid_interval[i] - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - if overlap_noise > 0 and valid_interval_start < predix_video_latent_length: - timestep[:, valid_interval_start:predix_video_latent_length] = overlap_noise - time_steps_comb.append(timestep) - self.model.compute_teacache_threshold(self.model.cache_start_step, time_steps_comb, self.model.cache_multiplier) - del time_steps_comb - else: - 
self.model.enable_cache = None + skip_steps_cache = self.model.cache + if skip_steps_cache != None: + skip_steps_cache.num_steps = updated_num_steps + if skip_steps_cache.cache_type == "tea": + x_count = 2 if self.do_classifier_free_guidance else 1 + skip_steps_cache.previous_residual = [None] * x_count + time_steps_comb = [] + skip_steps_cache.steps = updated_num_steps + for i, timestep_i in enumerate(step_matrix): + valid_interval_start, valid_interval_end = valid_interval[i] + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + if overlap_noise > 0 and valid_interval_start < predix_video_latent_length: + timestep[:, valid_interval_start:predix_video_latent_length] = overlap_noise + time_steps_comb.append(timestep) + self.model.compute_teacache_threshold(skip_steps_cache.start_step, time_steps_comb, skip_steps_cache.multiplier) + del time_steps_comb + else: + self.model.cache = None from mmgp import offload freqs = get_rotary_pos_embed(latents.shape[2 :], enable_RIFLEx= False) kwrags = { @@ -431,5 +434,3 @@ class DTT2V: return videos -def query_model_def(model_type, model_def): - return None \ No newline at end of file diff --git a/ltx_video/utils/__init__.py b/models/wan/distributed/__init__.py similarity index 100% rename from ltx_video/utils/__init__.py rename to models/wan/distributed/__init__.py diff --git a/wan/distributed/fsdp.py b/models/wan/distributed/fsdp.py similarity index 100% rename from wan/distributed/fsdp.py rename to models/wan/distributed/fsdp.py diff --git a/wan/distributed/xdit_context_parallel.py b/models/wan/distributed/xdit_context_parallel.py similarity index 100% rename from wan/distributed/xdit_context_parallel.py rename to models/wan/distributed/xdit_context_parallel.py diff --git a/wan/fantasytalking/infer.py b/models/wan/fantasytalking/infer.py similarity index 100% rename from wan/fantasytalking/infer.py rename to models/wan/fantasytalking/infer.py diff --git a/wan/fantasytalking/model.py b/models/wan/fantasytalking/model.py similarity index 99% rename from wan/fantasytalking/model.py rename to models/wan/fantasytalking/model.py index 5ec3655..d0eb74d 100644 --- a/wan/fantasytalking/model.py +++ b/models/wan/fantasytalking/model.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from wan.modules.attention import pay_attention +from shared.attention import pay_attention class AudioProjModel(nn.Module): diff --git a/wan/fantasytalking/utils.py b/models/wan/fantasytalking/utils.py similarity index 100% rename from wan/fantasytalking/utils.py rename to models/wan/fantasytalking/utils.py diff --git a/wan/modules/__init__.py b/models/wan/modules/__init__.py similarity index 81% rename from wan/modules/__init__.py rename to models/wan/modules/__init__.py index 38c29ce..56aea65 100644 --- a/wan/modules/__init__.py +++ b/models/wan/modules/__init__.py @@ -1,8 +1,9 @@ -from .attention import pay_attention +from shared.attention import pay_attention from .model import WanModel from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model from .tokenizers import HuggingfaceTokenizer from .vae import WanVAE +from .vae2_2 import Wan2_2_VAE __all__ = [ 'WanVAE', diff --git a/wan/modules/clip.py b/models/wan/modules/clip.py similarity index 99% rename from wan/modules/clip.py rename to models/wan/modules/clip.py index da91a00..fc29893 100644 --- a/wan/modules/clip.py +++ b/models/wan/modules/clip.py @@ -8,7 +8,7 @@ import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as T -from 
.attention import pay_attention +from shared.attention import pay_attention from .tokenizers import HuggingfaceTokenizer from .xlm_roberta import XLMRoberta diff --git a/wan/modules/model.py b/models/wan/modules/model.py similarity index 91% rename from wan/modules/model.py rename to models/wan/modules/model.py index 19eb5c3..d36efc2 100644 --- a/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -12,9 +12,9 @@ from diffusers.models.modeling_utils import ModelMixin import numpy as np from typing import Union,Optional from mmgp import offload -from .attention import pay_attention +from shared.attention import pay_attention from torch.backends.cuda import sdp_kernel -from wan.multitalk.multitalk_utils import get_attn_map_with_target +from ..multitalk.multitalk_utils import get_attn_map_with_target __all__ = ['WanModel'] @@ -150,7 +150,7 @@ class WanLayerNorm(nn.LayerNorm): return x # return super().forward(x).type_as(x) -from wan.modules.posemb_layers import apply_rotary_emb +from .posemb_layers import apply_rotary_emb class WanSelfAttention(nn.Module): @@ -429,7 +429,7 @@ class WanAttentionBlock(nn.Module): self.block_id = block_id if output_dim > 0: - from wan.multitalk.attention import SingleStreamMutiAttention + from ..multitalk.attention import SingleStreamMutiAttention # init audio module self.audio_cross_attn = SingleStreamMutiAttention( dim=dim, @@ -762,15 +762,17 @@ class WanModel(ModelMixin, ConfigMixin): offload.shared_state["_chipmunk_layers"] = None def preprocess_loras(self, model_type, sd): - # new_sd = {} - # for k,v in sd.items(): - # if not k.endswith(".modulation.diff"): - # new_sd[ k] = v - # sd = new_sd + first = next(iter(sd), None) if first == None: return sd - + + # if first.startswith("blocks."): + # new_sd = {} + # for k,v in sd.items(): + # new_sd["diffusion_model." 
+ k] = v + # sd = new_sd + if first.startswith("lora_unet_"): new_sd = {} print("Converting Lora Safetensors format to Lora Diffusers format") @@ -846,7 +848,6 @@ class WanModel(ModelMixin, ConfigMixin): super().__init__() - assert model_type in ['t2v', 'i2v', 'i2v2_2'] self.model_type = model_type self.patch_size = patch_size @@ -893,7 +894,7 @@ class WanModel(ModelMixin, ConfigMixin): # blocks if vace_layers == None: - cross_attn_type = 't2v_cross_attn' if model_type in ['t2v','i2v2_2'] else 'i2v_cross_attn' + cross_attn_type = 't2v_cross_attn' if model_type in ['t2v','i2v2_2', 'ti2v2_2'] else 'i2v_cross_attn' self.blocks = nn.ModuleList([ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, block_no =i, output_dim=multitalk_output_dim, norm_input_visual=norm_input_visual) @@ -962,7 +963,7 @@ class WanModel(ModelMixin, ConfigMixin): block.projector.bias = nn.Parameter(torch.zeros(dim)) if fantasytalking_dim > 0: - from wan.fantasytalking.model import WanCrossAttentionProcessor + from ..fantasytalking.model import WanCrossAttentionProcessor for block in self.blocks: block.cross_attn.processor = WanCrossAttentionProcessor(fantasytalking_dim, dim) @@ -1009,6 +1010,7 @@ class WanModel(ModelMixin, ConfigMixin): self._lock_dtype = dtype def compute_magcache_threshold(self, start_step, timesteps = None, speed_factor =0): + skips_step_cache = self.cache def nearest_interp(src_array, target_length): src_length = len(src_array) if target_length == 1: return np.array([src_array[-1]]) @@ -1016,13 +1018,14 @@ class WanModel(ModelMixin, ConfigMixin): mapped_indices = np.round(np.arange(target_length) * scale).astype(int) return src_array[mapped_indices] num_inference_steps = len(timesteps) - if len(self.def_mag_ratios) != num_inference_steps*2: - mag_ratio_con = nearest_interp(self.def_mag_ratios[0::2], num_inference_steps) - mag_ratio_ucon = nearest_interp(self.def_mag_ratios[1::2], num_inference_steps) + def_mag_ratios = np.array([1.0]*2+ skips_step_cache.def_mag_ratios) + if len(def_mag_ratios) != num_inference_steps*2: + mag_ratio_con = nearest_interp(def_mag_ratios[0::2], num_inference_steps) + mag_ratio_ucon = nearest_interp(def_mag_ratios[1::2], num_inference_steps) interpolated_mag_ratios = np.concatenate([mag_ratio_con.reshape(-1, 1), mag_ratio_ucon.reshape(-1, 1)], axis=1).reshape(-1) - self.mag_ratios = interpolated_mag_ratios + skips_step_cache.mag_ratios = interpolated_mag_ratios else: - self.mag_ratios = self.def_mag_ratios + skips_step_cache.mag_ratios = def_mag_ratios best_deltas = None @@ -1043,12 +1046,12 @@ class WanModel(ModelMixin, ConfigMixin): else: x_should_calc = [] for cur_x_id in range(x_id_max): - cur_mag_ratio = self.mag_ratios[i * 2 + cur_x_id] # conditional and unconditional in one list + cur_mag_ratio = skips_step_cache.mag_ratios[i * 2 + cur_x_id] # conditional and unconditional in one list accumulated_ratio[cur_x_id] *= cur_mag_ratio # magnitude ratio between current step and the cached step accumulated_steps[cur_x_id] += 1 # skip steps plus 1 cur_skip_err = np.abs(1-accumulated_ratio[cur_x_id]) # skip error of current steps accumulated_err[cur_x_id] += cur_skip_err # accumulated error of multiple steps - if accumulated_err[cur_x_id] best_diff: break threshold += 0.01 - self.magcache_thresh = best_threshold + skips_step_cache.magcache_thresh = best_threshold print(f"Mag Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of 
x{speed_factor}") return best_threshold def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0): + skips_step_cache = self.cache modulation_dtype = self.time_projection[1].weight.dtype - rescale_func = np.poly1d(self.coefficients) + rescale_func = np.poly1d(skips_step_cache.coefficients) e_list = [] for t in timesteps: t = torch.stack([t]) @@ -1111,7 +1115,7 @@ class WanModel(ModelMixin, ConfigMixin): elif diff > best_diff: break threshold += 0.01 - self.rel_l1_thresh = best_threshold + skips_step_cache.rel_l1_thresh = best_threshold print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}") # print(f"deltas:{best_deltas}") return best_threshold @@ -1281,72 +1285,73 @@ class WanModel(ModelMixin, ConfigMixin): del c should_calc = True x_should_calc = None - if self.enable_cache != None: - if self.enable_cache == "mag": - if current_step <= self.cache_start_step: + skips_steps_cache = self.cache + if skips_steps_cache != None: + if skips_steps_cache.cache_type == "mag": + if current_step <= skips_steps_cache.start_step: should_calc = True - elif self.one_for_all and x_id != 0: # not joint pass, not main pas, one for all + elif skips_steps_cache.one_for_all and x_id != 0: # not joint pass, not main pas, one for all assert len(x_list) == 1 - should_calc = self.should_calc + should_calc = skips_steps_cache.should_calc else: x_should_calc = [] - for i in range(1 if self.one_for_all else len(x_list)): + for i in range(1 if skips_steps_cache.one_for_all else len(x_list)): cur_x_id = i if joint_pass else x_id - cur_mag_ratio = self.mag_ratios[current_step * 2 + cur_x_id] # conditional and unconditional in one list - self.accumulated_ratio[cur_x_id] *= cur_mag_ratio # magnitude ratio between current step and the cached step - self.accumulated_steps[cur_x_id] += 1 # skip steps plus 1 - cur_skip_err = np.abs(1-self.accumulated_ratio[cur_x_id]) # skip error of current steps - self.accumulated_err[cur_x_id] += cur_skip_err # accumulated error of multiple steps - if self.accumulated_err[cur_x_id] 0: - return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] + return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] else: - return [ self.model.encode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] + return [ self.model.encode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] def decode(self, zs, tile_size, any_end_frame = False): + scale = [u.to(device = self.device) for u in self.scale] if tile_size > 0: - return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] + return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] else: - return [ self.model.decode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] + return [ self.model.decode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] diff --git 
a/models/wan/modules/vae2_2.py b/models/wan/modules/vae2_2.py new file mode 100644 index 0000000..c1a88f5 --- /dev/null +++ b/models/wan/modules/vae2_2.py @@ -0,0 +1,1211 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +__all__ = [ + "Wan2_2_VAE", +] + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + 2 * self.padding[0], + 0, + ) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + cache_x = None + x = F.pad(x, padding) + try: + out = super().forward(x) + return out + except RuntimeError as e: + if "miopenStatus" in str(e): + print("⚠️ MIOpen fallback: AMD gets upset when trying to work with large areas, and so CPU will be " + "used for this decoding (which is very slow). Consider using tiled VAE Decoding.") + x_cpu = x.float().cpu() + weight_cpu = self.weight.float().cpu() + bias_cpu = self.bias.float().cpu() if self.bias is not None else None + print(f"[Fallback] x shape: {x_cpu.shape}, weight shape: {weight_cpu.shape}") + out = F.conv3d(x_cpu, weight_cpu, bias_cpu, + self.stride, (0, 0, 0), # avoid double padding here + self.dilation, self.groups) + out = out.to(x.device) + if x.dtype in (torch.float16, torch.bfloat16): + out = out.half() + if x.dtype != out.dtype: + out = out.to(x.dtype) + return out + raise +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return (F.normalize(x, dim=(1 if self.channel_first else -1)) * + self.scale * self.gamma + self.bias) + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. 
+ """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ( + "none", + "upsample2d", + "upsample3d", + "downsample2d", + "downsample3d", + ) + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + # nn.Conv2d(dim, dim//2, 3, padding=1) + ) + self.time_conv = CausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + elif mode == "downsample2d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and + feat_cache[idx] != "Rep"): + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and + feat_cache[idx] == "Rep"): + cache_x = torch.cat( + [ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2, + ) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.resample(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight.detach().clone() + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5 + conv.weight = nn.Parameter(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data.detach().clone() + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight = nn.Parameter(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + 
# layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), + nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), + nn.SiLU(), + nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1), + ) + self.shortcut = ( + CausalConv3d(in_dim, out_dim, 1) + if in_dim != out_dim else nn.Identity()) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.norm(x) + # compute query, key, value + q, k, v = ( + self.to_qkv(x).reshape(b * t, 1, c * 3, + -1).permute(0, 1, 3, + 2).contiguous().chunk(3, dim=-1)) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, "(b t) c h w-> b c t h w", t=t) + return x + identity + + +def patchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + r=patch_size, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size, + ) + return x + + +class AvgDown3D(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W 
// self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1:, :, :] + return x + + +class Down_ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + out_dim, + dropout, + mult, + temperal_downsample=False, + down_flag=False): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + downsamples = [] + for _ in range(mult): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + downsamples.append(Resample(out_dim, mode=mode)) + + self.downsamples = nn.Sequential(*downsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for module in self.downsamples: + x = module(x, feat_cache, feat_idx) + + return x + self.avg_shortcut(x_copy) + + +class Up_ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + out_dim, + dropout, + mult, + temperal_upsample=False, + up_flag=False): + super().__init__() + # Shortcut path with upsample + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2 if up_flag else 1, + ) + else: + self.avg_shortcut = None + + # Main path with residual blocks and upsample + upsamples = [] + for _ in range(mult): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final upsample block + if up_flag: + mode = "upsample3d" if temperal_upsample else "upsample2d" + upsamples.append(Resample(out_dim, mode=mode)) + + self.upsamples = nn.Sequential(*upsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + x_main = x.clone() + for module in self.upsamples: + x_main = module(x_main, feat_cache, feat_idx) + if self.avg_shortcut is not None: + x_shortcut = self.avg_shortcut(x, first_chunk) + return x_main + x_shortcut + else: + return x_main + + +class Encoder3d(nn.Module): + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 
= CausalConv3d(12, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_down_flag = ( + temperal_downsample[i] + if i < len(temperal_downsample) else False) + downsamples.append( + Down_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks, + temperal_downsample=t_down_flag, + down_flag=i != len(dim_mult) - 1, + )) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1), + ) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + + return x + + +class Decoder3d(nn.Module): + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout), + ) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_up_flag = temperal_upsample[i] if i < len( + temperal_upsample) else False + upsamples.append( + Up_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks + 1, + temperal_upsample=t_up_flag, + up_flag=i != len(dim_mult) - 1, + )) + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, 12, 3, padding=1), + ) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + 
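+            # cache_x keeps the last CACHE_T frames of this chunk so the next chunk's causal
+            # convolutions still see the temporal context from the previous call.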
if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, first_chunk) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE_(nn.Module): + _offload_hooks = ['encode', 'decode'] + def __init__( + self, + dim=160, + dec_dim=256, + z_dim=16, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d( + dim, + z_dim * 2, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_downsample, + dropout, + ) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d( + dec_dim, + z_dim, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_upsample, + dropout, + ) + + def forward(self, x, scale=[0, 1]): + mu = self.encode(x, scale) + x_recon = self.decode(mu, scale) + return x_recon, mu + + def encode(self, x, scale = None, any_end_frame = False): + self.clear_cache() + x = patchify(x, patch_size=2) + ## cache + t = x.shape[2] + if any_end_frame: + iter_ = 2 + (t - 2) // 4 + else: + iter_ = 1 + (t - 1) // 4 + ## 对encode输入的x,按时间拆分为1、4、4、4.... 
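+        # i.e. split the encoder input x along the time axis into chunks of 1, 4, 4, 4, ... frames;
+        # e.g. a 17-frame clip is encoded as frames [0:1], [1:5], [5:9], [9:13], [13:17].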
+ out_list = [] + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out_list.append(self.encoder( + x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx)) + elif any_end_frame and i== iter_ -1: + out_list.append(self.encoder( + x[:, :, -1:, :, :], + feat_cache= None, + feat_idx=self._enc_conv_idx)) + else: + out_list.append(self.encoder( + x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx)) + + self.clear_cache() + out = torch.cat(out_list, 2) + out_list = None + + mu, log_var = self.conv1(out).chunk(2, dim=1) + if scale != None: + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + return mu + + + def decode(self, z, scale=None, any_end_frame = False): + self.clear_cache() + # z: [b,c,t,h,w] + if scale != None: + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + out_list = [] + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out_list.append(self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + first_chunk = True) + ) + elif any_end_frame and i==iter_-1: + out_list.append(self.decoder( + x[:, :, -1:, :, :], + feat_cache=None , + feat_idx=self._conv_idx)) + else: + out_list.append(self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx)) + self.clear_cache() + out = torch.cat(out_list, 2) + out = unpatchify(out, patch_size=2) + return out + + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def spatial_tiled_decode(self, z, scale, tile_size, any_end_frame= False): + tile_sample_min_size = tile_size + tile_latent_min_size = int(tile_sample_min_size / 16) + tile_overlap_factor = 0.25 + + # z: [b,c,t,h,w] + + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + + + overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor)) #8 0.75 + blend_extent = int(tile_sample_min_size * tile_overlap_factor) #256 0.25 + row_limit = tile_sample_min_size - blend_extent + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
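+        # Note: despite its name, overlap_size is the stride between tile origins in latent
+        # space (75% of a tile edge), while blend_extent is the pixel-space width of the
+        # cross-fade applied by blend_v / blend_h before the decoded tiles are stitched.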
+ rows = [] + for i in range(0, z.shape[-2], overlap_size): + row = [] + for j in range(0, z.shape[-1], overlap_size): + tile = z[:, :, :, i: i + tile_latent_min_size, j: j + tile_latent_min_size] + decoded = self.decode(tile, any_end_frame= any_end_frame) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + return torch.cat(result_rows, dim=-2) + + + def spatial_tiled_encode(self, x, scale, tile_size, any_end_frame = False) : + tile_sample_min_size = tile_size + tile_latent_min_size = int(tile_sample_min_size / 16) + tile_overlap_factor = 0.25 + + overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor)) + blend_extent = int(tile_latent_min_size * tile_overlap_factor) + row_limit = tile_latent_min_size - blend_extent + + # Split video into tiles and encode them separately. + rows = [] + for i in range(0, x.shape[-2], overlap_size): + row = [] + for j in range(0, x.shape[-1], overlap_size): + tile = x[:, :, :, i: i + tile_sample_min_size, j: j + tile_sample_min_size] + tile = self.encode(tile, any_end_frame= any_end_frame) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + mu = torch.cat(result_rows, dim=-2) + + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + + return mu + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs): + # params + cfg = dict( + dim=dim, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, True], + dropout=0.0, + ) + cfg.update(**kwargs) + + # init model + with torch.device("meta"): + model = WanVAE_(**cfg) + + from mmgp import offload + # load checkpoint + logging.info(f"loading {pretrained_path}") + # model.load_state_dict( + # torch.load(pretrained_path, map_location=device), assign=True) + # offload.save_model(model, "Wan_vae_2_2.safetensors") + # model.to(torch.bfloat16) + # offload.save_model(model, "Wan_vae_2_2_bf16.safetensors") + 
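+    # Weights are loaded with mmgp's offload helper from the .safetensors file that sits next
+    # to the given .pth path; the commented-out torch.load / offload.save_model lines above
+    # look like the one-off conversion path that produced that file.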
offload.load_model_data(model, pretrained_path.replace(".pth", ".safetensors"), writable_tensors= False) + + return model + + +class Wan2_2_VAE: + + def __init__( + self, + z_dim=48, + c_dim=160, + vae_pth=None, + dim_mult=[1, 2, 4, 4], + temperal_downsample=[False, True, True], + dtype=torch.float, + device="cuda", + ): + + self.dtype = dtype + self.device = device + + mean = torch.tensor( + [ + -0.2289, + -0.0052, + -0.1323, + -0.2339, + -0.2799, + 0.0174, + 0.1838, + 0.1557, + -0.1382, + 0.0542, + 0.2813, + 0.0891, + 0.1570, + -0.0098, + 0.0375, + -0.1825, + -0.2246, + -0.1207, + -0.0698, + 0.5109, + 0.2665, + -0.2108, + -0.2158, + 0.2502, + -0.2055, + -0.0322, + 0.1109, + 0.1567, + -0.0729, + 0.0899, + -0.2799, + -0.1230, + -0.0313, + -0.1649, + 0.0117, + 0.0723, + -0.2839, + -0.2083, + -0.0520, + 0.3748, + 0.0152, + 0.1957, + 0.1433, + -0.2944, + 0.3573, + -0.0548, + -0.1681, + -0.0667, + ], + dtype=dtype, + device=device, + ) + std = torch.tensor( + [ + 0.4765, + 1.0364, + 0.4514, + 1.1677, + 0.5313, + 0.4990, + 0.4818, + 0.5013, + 0.8158, + 1.0344, + 0.5894, + 1.0901, + 0.6885, + 0.6165, + 0.8454, + 0.4978, + 0.5759, + 0.3523, + 0.7135, + 0.6804, + 0.5833, + 1.4146, + 0.8986, + 0.5659, + 0.7069, + 0.5338, + 0.4889, + 0.4917, + 0.4069, + 0.4999, + 0.6866, + 0.4093, + 0.5709, + 0.6065, + 0.6415, + 0.4944, + 0.5726, + 1.2042, + 0.5458, + 1.6887, + 0.3971, + 1.0600, + 0.3943, + 0.5537, + 0.5444, + 0.4089, + 0.7468, + 0.7744, + ], + dtype=dtype, + device=device, + ) + self.scale = [mean, 1.0 / std] + + # init model + self.model = ( + _video_vae( + pretrained_path=vae_pth, + z_dim=z_dim, + dim=c_dim, + dim_mult=dim_mult, + temperal_downsample=temperal_downsample, + ).eval().requires_grad_(False).to(device)) + + self.model._model_dtype = dtype + + + @staticmethod + def get_VAE_tile_size(vae_config, device_mem_capacity, mixed_precision): + # VAE Tiling + if vae_config == 0: + if mixed_precision: + device_mem_capacity = device_mem_capacity / 2 + if device_mem_capacity >= 24000: + use_vae_config = 1 + elif device_mem_capacity >= 8000: + use_vae_config = 2 + else: + use_vae_config = 3 + else: + use_vae_config = vae_config + + if use_vae_config == 1: + VAE_tile_size = 0 + elif use_vae_config == 2: + VAE_tile_size = 256 + else: + VAE_tile_size = 128 + + return VAE_tile_size + + def encode(self, videos, tile_size = 256, any_end_frame = False): + """ + videos: A list of videos each with shape [C, T, H, W]. 
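+
+        Returns a list of latent tensors (one per input video), normalised with the
+        mean / (1/std) pair stored in self.scale and cast back to float32.
+
+        Illustrative usage (the checkpoint path is a placeholder):
+            vae = Wan2_2_VAE(vae_pth="ckpts/Wan_vae_2_2_bf16.safetensors", device="cuda")
+            latents = vae.encode([video])   # video: [3, T, H, W] tensor in [-1, 1]
+            frames = vae.decode(latents)    # list of [3, t, h, w] tensors clamped to [-1, 1]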
+ """ + scale = [u.to(device = self.device) for u in self.scale] + + if tile_size > 0 and False : + return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] + else: + return [ self.model.encode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] + + + def decode(self, zs, tile_size = 256, any_end_frame = False): + scale = [u.to(device = self.device) for u in self.scale] + if tile_size > 0 : + return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] + else: + return [ self.model.decode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] + diff --git a/wan/modules/xlm_roberta.py b/models/wan/modules/xlm_roberta.py similarity index 100% rename from wan/modules/xlm_roberta.py rename to models/wan/modules/xlm_roberta.py diff --git a/wan/multitalk/attention.py b/models/wan/multitalk/attention.py similarity index 99% rename from wan/multitalk/attention.py rename to models/wan/multitalk/attention.py index 12fb317..27d488f 100644 --- a/wan/multitalk/attention.py +++ b/models/wan/multitalk/attention.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn from einops import rearrange, repeat from .multitalk_utils import RotaryPositionalEmbedding1D, normalize_and_scale, split_token_counts_and_frame_ids -from wan.modules.attention import pay_attention +from shared.attention import pay_attention # import xformers.ops diff --git a/wan/multitalk/kokoro/__init__.py b/models/wan/multitalk/kokoro/__init__.py similarity index 100% rename from wan/multitalk/kokoro/__init__.py rename to models/wan/multitalk/kokoro/__init__.py diff --git a/wan/multitalk/kokoro/__main__.py b/models/wan/multitalk/kokoro/__main__.py similarity index 100% rename from wan/multitalk/kokoro/__main__.py rename to models/wan/multitalk/kokoro/__main__.py diff --git a/wan/multitalk/kokoro/custom_stft.py b/models/wan/multitalk/kokoro/custom_stft.py similarity index 100% rename from wan/multitalk/kokoro/custom_stft.py rename to models/wan/multitalk/kokoro/custom_stft.py diff --git a/wan/multitalk/kokoro/istftnet.py b/models/wan/multitalk/kokoro/istftnet.py similarity index 100% rename from wan/multitalk/kokoro/istftnet.py rename to models/wan/multitalk/kokoro/istftnet.py diff --git a/wan/multitalk/kokoro/model.py b/models/wan/multitalk/kokoro/model.py similarity index 100% rename from wan/multitalk/kokoro/model.py rename to models/wan/multitalk/kokoro/model.py diff --git a/wan/multitalk/kokoro/modules.py b/models/wan/multitalk/kokoro/modules.py similarity index 100% rename from wan/multitalk/kokoro/modules.py rename to models/wan/multitalk/kokoro/modules.py diff --git a/wan/multitalk/kokoro/pipeline.py b/models/wan/multitalk/kokoro/pipeline.py similarity index 100% rename from wan/multitalk/kokoro/pipeline.py rename to models/wan/multitalk/kokoro/pipeline.py diff --git a/wan/multitalk/multitalk.py b/models/wan/multitalk/multitalk.py similarity index 98% rename from wan/multitalk/multitalk.py rename to models/wan/multitalk/multitalk.py index 56ba16b..f4538ed 100644 --- a/wan/multitalk/multitalk.py +++ b/models/wan/multitalk/multitalk.py @@ -7,10 +7,7 @@ import subprocess import torchvision.transforms as transforms import torch.nn.functional as F import torch.nn as nn -import wan -from wan.configs import SIZE_CONFIGS, 
SUPPORTED_SIZES, WAN_CONFIGS -from wan.utils.utils import cache_image, cache_video, str2bool -# from wan.utils.multitalk_utils import save_video_ffmpeg +# from shared.utils.multitalk_utils import save_video_ffmpeg # from .kokoro import KPipeline from transformers import Wav2Vec2FeatureExtractor from .wav2vec2 import Wav2Vec2Model diff --git a/wan/multitalk/multitalk_model.py b/models/wan/multitalk/multitalk_model.py similarity index 100% rename from wan/multitalk/multitalk_model.py rename to models/wan/multitalk/multitalk_model.py diff --git a/wan/multitalk/multitalk_utils.py b/models/wan/multitalk/multitalk_utils.py similarity index 100% rename from wan/multitalk/multitalk_utils.py rename to models/wan/multitalk/multitalk_utils.py diff --git a/wan/multitalk/torch_utils.py b/models/wan/multitalk/torch_utils.py similarity index 100% rename from wan/multitalk/torch_utils.py rename to models/wan/multitalk/torch_utils.py diff --git a/wan/multitalk/wav2vec2.py b/models/wan/multitalk/wav2vec2.py similarity index 98% rename from wan/multitalk/wav2vec2.py rename to models/wan/multitalk/wav2vec2.py index 5ec9c2b..9ab590c 100644 --- a/wan/multitalk/wav2vec2.py +++ b/models/wan/multitalk/wav2vec2.py @@ -20,7 +20,7 @@ class Wav2Vec2Model(Wav2Vec2Model): output_hidden_states=None, return_dict=None, ): - self.config.output_attentions = True + # self.config.output_attentions = True output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states diff --git a/models/wan/text2video fuse attempt.py b/models/wan/text2video fuse attempt.py new file mode 100644 index 0000000..8af9458 --- /dev/null +++ b/models/wan/text2video fuse attempt.py @@ -0,0 +1,698 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
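+# Experimental "fuse attempt" variant of the text2video pipeline (see the file name);
+# large parts of the original denoising loop are kept further down as commented-out reference code.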
+import gc +import logging +import math +import os +import random +import sys +import types +from contextlib import contextmanager +from functools import partial +from mmgp import offload +import torch +import torch.nn as nn +import torch.cuda.amp as amp +import torch.distributed as dist +from tqdm import tqdm +from PIL import Image +import torchvision.transforms.functional as TF +import torch.nn.functional as F +from .distributed.fsdp import shard_model +from .modules.model import WanModel +from .modules.t5 import T5EncoderModel +from .modules.vae import WanVAE +from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, retrieve_timesteps) +from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from wan.modules.posemb_layers import get_rotary_pos_embed +from .utils.vace_preprocessor import VaceVideoProcessor + + +def optimized_scale(positive_flat, negative_flat): + + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 + + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + return st_star + + +class WanT2V: + + def __init__( + self, + config, + checkpoint_dir, + rank=0, + model_filename = None, + text_encoder_filename = None, + quantizeTransformer = False, + dtype = torch.bfloat16 + ): + self.device = torch.device(f"cuda") + self.config = config + self.rank = rank + self.dtype = dtype + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu'), + checkpoint_path=text_encoder_filename, + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + shard_fn= None) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + + + self.vae = WanVAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device) + + logging.info(f"Creating WanModel from {model_filename}") + from mmgp import offload + + self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) + # offload.load_model_data(self.model, "recam.ckpt") + # self.model.cpu() + # offload.save_model(self.model, "recam.safetensors") + if self.dtype == torch.float16 and not "fp16" in model_filename: + self.model.to(self.dtype) + # offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True) + if self.dtype == torch.float16: + self.vae.model.to(self.dtype) + self.model.eval().requires_grad_(False) + + + self.sample_neg_prompt = config.sample_neg_prompt + + if "Vace" in model_filename: + self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]), + min_area=480*832, + max_area=480*832, + min_fps=config.sample_fps, + max_fps=config.sample_fps, + zero_start=True, + seq_len=32760, + keep_last=True) + + self.adapt_vace_model() + + self.scheduler = FlowUniPCMultistepScheduler() + + def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0): + if ref_images is None: + ref_images = [None] * len(frames) + else: + assert len(frames) == len(ref_images) + + if masks is None: + latents = self.vae.encode(frames, tile_size = tile_size) + else: + inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] + reactive = [i * m + 0 * (1 - m) for i, m in 
zip(frames, masks)] + inactive = self.vae.encode(inactive, tile_size = tile_size) + reactive = self.vae.encode(reactive, tile_size = tile_size) + latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + + cat_latents = [] + for latent, refs in zip(latents, ref_images): + if refs is not None: + if masks is None: + ref_latent = self.vae.encode(refs, tile_size = tile_size) + else: + ref_latent = self.vae.encode(refs, tile_size = tile_size) + ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + assert all([x.shape[1] == 1 for x in ref_latent]) + latent = torch.cat([*ref_latent, latent], dim=1) + cat_latents.append(latent) + return cat_latents + + def vace_encode_masks(self, masks, ref_images=None): + if ref_images is None: + ref_images = [None] * len(masks) + else: + assert len(masks) == len(ref_images) + + result_masks = [] + for mask, refs in zip(masks, ref_images): + c, depth, height, width = mask.shape + new_depth = int((depth + 3) // self.vae_stride[0]) + height = 2 * (int(height) // (self.vae_stride[1] * 2)) + width = 2 * (int(width) // (self.vae_stride[2] * 2)) + + # reshape + mask = mask[0, :, :, :] + mask = mask.view( + depth, height, self.vae_stride[1], width, self.vae_stride[1] + ) # depth, height, 8, width, 8 + mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width + mask = mask.reshape( + self.vae_stride[1] * self.vae_stride[2], depth, height, width + ) # 8*8, depth, height, width + + # interpolation + mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) + + if refs is not None: + length = len(refs) + mask_pad = torch.zeros_like(mask[:, :length, :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + result_masks.append(mask) + return result_masks + + def vace_latent(self, z, m): + return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] + + def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, pre_src_video = None): + image_sizes = [] + trim_video = len(keep_frames) + + for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)): + prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1] + num_frames = total_frames - prepend_count + if sub_src_mask is not None and sub_src_video is not None: + src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame) + # src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255]) + # src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255]) + src_video[i] = src_video[i].to(device) + src_mask[i] = src_mask[i].to(device) + if prepend_count > 0: + src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) + src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) + src_video_shape = src_video[i].shape + if src_video_shape[1] != total_frames: + src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) + src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) + src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1) + 
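+                # (the clamp above collapses the mask to a single channel and remaps it from [-1, 1] to [0, 1])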
image_sizes.append(src_video[i].shape[2:]) + elif sub_src_video is None: + if prepend_count > 0: + src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1) + src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1) + else: + src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(image_size) + else: + src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame) + src_video[i] = src_video[i].to(device) + src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device) + if prepend_count > 0: + src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) + src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) + src_video_shape = src_video[i].shape + if src_video_shape[1] != total_frames: + src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) + src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) + image_sizes.append(src_video[i].shape[2:]) + for k, keep in enumerate(keep_frames): + if not keep: + src_video[i][:, k:k+1] = 0 + src_mask[i][:, k:k+1] = 1 + + for i, ref_images in enumerate(src_ref_images): + if ref_images is not None: + image_size = image_sizes[i] + for j, ref_img in enumerate(ref_images): + if ref_img is not None: + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + if ref_img.shape[-2:] != image_size: + canvas_height, canvas_width = image_size + ref_height, ref_width = ref_img.shape[-2:] + white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image + ref_img = white_canvas + src_ref_images[i][j] = ref_img.to(device) + return src_video, src_mask, src_ref_images + + def decode_latent(self, zs, ref_images=None, tile_size= 0 ): + if ref_images is None: + ref_images = [None] * len(zs) + else: + assert len(zs) == len(ref_images) + + trimed_zs = [] + for z, refs in zip(zs, ref_images): + if refs is not None: + z = z[:, len(refs):, :, :] + trimed_zs.append(z) + + return self.vae.decode(trimed_zs, tile_size= tile_size) + + def generate_timestep_matrix( + self, + num_frames, + step_template, + base_num_frames, + ar_step=5, + num_pre_ready=0, + casual_block_size=1, + shrink_interval_with_mask=False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: + step_matrix, step_index = [], [] + update_mask, valid_interval = [], [] + num_iterations = len(step_template) + 1 + num_frames_block = num_frames // casual_block_size + base_num_frames_block = base_num_frames // casual_block_size + if base_num_frames_block < num_frames_block: + 
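+            # the requested clip is longer than the model's frame window; the assert below
+            # requires ar_step >= (number of sampling steps) / (blocks in the window) so the
+            # staggered per-block schedule can complete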
infer_step_num = len(step_template) + gen_block = base_num_frames_block + min_ar_step = infer_step_num / gen_block + assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting" + # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block) + step_template = torch.cat( + [ + torch.tensor([999], dtype=torch.int64, device=step_template.device), + step_template.long(), + torch.tensor([0], dtype=torch.int64, device=step_template.device), + ] + ) # to handle the counter in row works starting from 1 + pre_row = torch.zeros(num_frames_block, dtype=torch.long) + if num_pre_ready > 0: + pre_row[: num_pre_ready // casual_block_size] = num_iterations + + while torch.all(pre_row >= (num_iterations - 1)) == False: + new_row = torch.zeros(num_frames_block, dtype=torch.long) + for i in range(num_frames_block): + if i == 0 or pre_row[i - 1] >= ( + num_iterations - 1 + ): # the first frame or the last frame is completely denoised + new_row[i] = pre_row[i] + 1 + else: + new_row[i] = new_row[i - 1] - ar_step + new_row = new_row.clamp(0, num_iterations) + + update_mask.append( + (new_row != pre_row) & (new_row != num_iterations) + ) # False: no need to update, True: need to update + step_index.append(new_row) + step_matrix.append(step_template[new_row]) + pre_row = new_row + + # for long video we split into several sequences, base_num_frames is set to the model max length (for training) + terminal_flag = base_num_frames_block + if shrink_interval_with_mask: + idx_sequence = torch.arange(num_frames_block, dtype=torch.int64) + update_mask = update_mask[0] + update_mask_idx = idx_sequence[update_mask] + last_update_idx = update_mask_idx[-1].item() + terminal_flag = last_update_idx + 1 + # for i in range(0, len(update_mask)): + for curr_mask in update_mask: + if terminal_flag < num_frames_block and curr_mask[terminal_flag]: + terminal_flag += 1 + valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag)) + + step_update_mask = torch.stack(update_mask, dim=0) + step_index = torch.stack(step_index, dim=0) + step_matrix = torch.stack(step_matrix, dim=0) + + if casual_block_size > 1: + step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval] + + return step_matrix, step_index, step_update_mask, valid_interval + + def generate(self, + input_prompt, + input_frames= None, + input_masks = None, + input_ref_images = None, + source_video=None, + target_camera=None, + context_scale=1.0, + size=(1280, 720), + frame_num=81, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True, + callback = None, + enable_RIFLEx = None, + VAE_tile_size = 0, + joint_pass = False, + slg_layers = None, + slg_start = 0.0, + slg_end = 1.0, + cfg_star_switch = True, + cfg_zero_step = 5, + ): + r""" + Generates video frames from text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation + size (tupele[`int`], *optional*, defaults to (1280,720)): + Controls video resolution, (width,height). 
+            frame_num (`int`, *optional*, defaults to 81):
+                How many frames to sample from a video. The number should be 4n+1
+            shift (`float`, *optional*, defaults to 5.0):
+                Noise schedule shift parameter. Affects temporal dynamics
+            sample_solver (`str`, *optional*, defaults to 'unipc'):
+                Solver used to sample the video.
+            sampling_steps (`int`, *optional*, defaults to 50):
+                Number of diffusion sampling steps. Higher values improve quality but slow generation
+            guide_scale (`float`, *optional*, defaults to 5.0):
+                Classifier-free guidance scale. Controls prompt adherence vs. creativity
+            n_prompt (`str`, *optional*, defaults to ""):
+                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
+            seed (`int`, *optional*, defaults to -1):
+                Random seed for noise generation. If -1, use a random seed.
+            offload_model (`bool`, *optional*, defaults to True):
+                If True, offloads models to CPU during generation to save VRAM
+
+        Returns:
+            torch.Tensor:
+                Generated video frames tensor. Dimensions: (C, N, H, W) where:
+                - C: Color channels (3 for RGB)
+                - N: Number of frames (81)
+                - H: Frame height (from size)
+                - W: Frame width (from size)
+        """
+        # preprocess
+
+        if n_prompt == "":
+            n_prompt = self.sample_neg_prompt
+        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
+        seed_g = torch.Generator(device=self.device)
+        seed_g.manual_seed(seed)
+
+        frame_num = max(17, frame_num) # must match causal_block_size for value of 5
+        frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 )
+        num_frames = frame_num
+        addnoise_condition = 20
+        causal_attention = True
+        fps = 16
+        ar_step = 5
+
+        context = self.text_encoder([input_prompt], self.device)
+        context_null = self.text_encoder([n_prompt], self.device)
+        if target_camera != None:
+            size = (source_video.shape[2], source_video.shape[1])
+            source_video = source_video.to(dtype=self.dtype , device=self.device)
+            source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
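+            # (the permute / rescale above converts T,H,W,C frames in [0, 255]
+            #  into a C,T,H,W tensor in [-1, 1] before VAE encoding)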
+ source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device) + del source_video + # Process target camera (recammaster) + from wan.utils.cammmaster_tools import get_camera_embedding + cam_emb = get_camera_embedding(target_camera) + cam_emb = cam_emb.to(dtype=self.dtype, device=self.device) + + if input_frames != None: + # vace context encode + input_frames = [u.to(self.device) for u in input_frames] + input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] + input_masks = [u.to(self.device) for u in input_masks] + + z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size) + m0 = self.vace_encode_masks(input_masks, input_ref_images) + z = self.vace_latent(z0, m0) + + target_shape = list(z0[0].shape) + target_shape[0] = int(target_shape[0] / 2) + else: + F = frame_num + target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, + size[1] // self.vae_stride[1], + size[0] // self.vae_stride[2]) + + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1]) + + context = [u.to(self.dtype) for u in context] + context_null = [u.to(self.dtype) for u in context_null] + + noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ] + + # evaluation mode + + # if sample_solver == 'unipc': + # sample_scheduler = FlowUniPCMultistepScheduler( + # num_train_timesteps=self.num_train_timesteps, + # shift=1, + # use_dynamic_shifting=False) + # sample_scheduler.set_timesteps( + # sampling_steps, device=self.device, shift=shift) + # timesteps = sample_scheduler.timesteps + # elif sample_solver == 'dpm++': + # sample_scheduler = FlowDPMSolverMultistepScheduler( + # num_train_timesteps=self.num_train_timesteps, + # shift=1, + # use_dynamic_shifting=False) + # sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + # timesteps, _ = retrieve_timesteps( + # sample_scheduler, + # device=self.device, + # sigmas=sampling_sigmas) + # else: + # raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noise + del noise + batch_size =len(latents) + if target_camera != None: + shape = list(latents[0].shape[1:]) + shape[0] *= 2 + freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False) + else: + freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx) + # arg_c = {'context': context, 'freqs': freqs, 'pipeline': self, 'callback': callback} + # arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback} + # arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback} + + i2v_extra_kwrags = {} + + if target_camera != None: + recam_dict = {'cam_emb': cam_emb} + i2v_extra_kwrags.update(recam_dict) + + if input_frames != None: + vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale} + i2v_extra_kwrags.update(vace_dict) + + + latent_length = (num_frames - 1) // 4 + 1 + latent_height = height // 8 + latent_width = width // 8 + if ar_step == 0: + causal_block_size = 1 + fps_embeds = [fps] #* prompt_embeds[0].shape[0] + fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] + + self.scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift) + init_timesteps = self.scheduler.timesteps + base_num_frames_iter = latent_length + latent_shape = [16, base_num_frames_iter, latent_height, latent_width] + + prefix_video = None + 
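+        # no continuation ("prefix") video is provided in this path, so the prefix-related
+        # branches below (initial latents and the added-noise conditioning) are effectively inert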
predix_video_latent_length = 0 + + if prefix_video is not None: + latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32) + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + base_num_frames_iter, + init_timesteps, + base_num_frames_iter, + ar_step, + predix_video_latent_length, + causal_block_size, + ) + sample_schedulers = [] + for _ in range(base_num_frames_iter): + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, shift=1, use_dynamic_shifting=False + ) + sample_scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift) + sample_schedulers.append(sample_scheduler) + sample_schedulers_counter = [0] * base_num_frames_iter + + updated_num_steps= len(step_matrix) + + if callback != None: + callback(-1, None, True, override_num_inference_steps = updated_num_steps) + if self.model.enable_teacache: + self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier) + # if callback != None: + # callback(-1, None, True) + + for i, timestep_i in enumerate(tqdm(step_matrix)): + update_mask_i = step_update_mask[i] + valid_interval_i = valid_interval[i] + valid_interval_start, valid_interval_end = valid_interval_i + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] + if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + timestep_for_noised_condition = addnoise_condition + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] + * (1.0 - noise_factor) + + torch.randn_like( + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] + ) + * noise_factor + ) + timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + kwrags = { + "x" : torch.stack([latent_model_input[0]]), + "t" : timestep, + "freqs" :freqs, + "fps" : fps_embeds, + "causal_block_size" : causal_block_size, + "causal_attention" : causal_attention, + "callback" : callback, + "pipeline" : self, + "current_step" : i, + } + kwrags.update(i2v_extra_kwrags) + + if not self.do_classifier_free_guidance: + noise_pred = self.model( + context=context, + **kwrags, + )[0] + if self._interrupt: + return None + noise_pred= noise_pred.to(torch.float32) + else: + if joint_pass: + noise_pred_cond, noise_pred_uncond = self.model( + context=context, + context2=context_null, + **kwrags, + ) + if self._interrupt: + return None + else: + noise_pred_cond = self.model( + context=context, + **kwrags, + )[0] + if self._interrupt: + return None + noise_pred_uncond = self.model( + context=context_null, + )[0] + if self._interrupt: + return None + noise_pred_cond= noise_pred_cond.to(torch.float32) + noise_pred_uncond= noise_pred_uncond.to(torch.float32) + noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond) + del noise_pred_cond, noise_pred_uncond + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[0][:, idx] = sample_schedulers[idx].step( + noise_pred[:, idx - valid_interval_start], + timestep_i[idx], + latents[0][:, idx], + return_dict=False, + generator=seed_g, + )[0] + sample_schedulers_counter[idx] += 1 + if callback is not None: + callback(i, latents[0].squeeze(0), False) + + # for i, t in 
enumerate(tqdm(timesteps)): + # if target_camera != None: + # latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )] + # else: + # latent_model_input = latents + # slg_layers_local = None + # if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps): + # slg_layers_local = slg_layers + # timestep = [t] + # offload.set_step_no_for_lora(self.model, i) + # timestep = torch.stack(timestep) + + # if joint_pass: + # noise_pred_cond, noise_pred_uncond = self.model( + # latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both) + # if self._interrupt: + # return None + # else: + # noise_pred_cond = self.model( + # latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0] + # if self._interrupt: + # return None + # noise_pred_uncond = self.model( + # latent_model_input, t=timestep,current_step=i, is_uncond = True, slg_layers=slg_layers_local, **arg_null)[0] + # if self._interrupt: + # return None + + # # del latent_model_input + + # # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/ + # noise_pred_text = noise_pred_cond + # if cfg_star_switch: + # positive_flat = noise_pred_text.view(batch_size, -1) + # negative_flat = noise_pred_uncond.view(batch_size, -1) + + # alpha = optimized_scale(positive_flat,negative_flat) + # alpha = alpha.view(batch_size, 1, 1, 1) + + # if (i <= cfg_zero_step): + # noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred... + # else: + # noise_pred_uncond *= alpha + # noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond) + # del noise_pred_uncond + + # temp_x0 = sample_scheduler.step( + # noise_pred[:, :target_shape[1]].unsqueeze(0), + # t, + # latents[0].unsqueeze(0), + # return_dict=False, + # generator=seed_g)[0] + # latents = [temp_x0.squeeze(0)] + # del temp_x0 + + # if callback is not None: + # callback(i, latents[0], False) + + x0 = latents + + if input_frames == None: + videos = self.vae.decode(x0, VAE_tile_size) + else: + videos = self.decode_latent(x0, input_ref_images, VAE_tile_size) + + del latents + del sample_scheduler + + return videos[0] if self.rank == 0 else None + + def adapt_vace_model(self): + model = self.model + modules_dict= { k: m for k, m in model.named_modules()} + for model_layer, vace_layer in model.vace_layers_mapping.items(): + module = modules_dict[f"vace_blocks.{vace_layer}"] + target = modules_dict[f"blocks.{model_layer}"] + setattr(target, "vace", module ) + delattr(model, "vace_blocks") + + \ No newline at end of file diff --git a/wan/trajectory_editor/app.py b/models/wan/trajectory_editor/app.py similarity index 100% rename from wan/trajectory_editor/app.py rename to models/wan/trajectory_editor/app.py diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py new file mode 100644 index 0000000..6d222a2 --- /dev/null +++ b/models/wan/wan_handler.py @@ -0,0 +1,261 @@ +import torch +import numpy as np + +def test_class_i2v(base_model_type): + return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "i2v_2_2_multitalk" ] #"hunyuan_i2v", + +def test_class_1_3B(base_model_type): + return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"] + +class family_handler(): + + @staticmethod + def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_steps_cache): + i2v = test_class_i2v(base_model_type) + + resolution = inputs["resolution"] + width, height = 
resolution.split("x") + pixels = int(width) * int(height) + + if cache_type == "mag": + skip_steps_cache.update({ + "magcache_thresh" : 0, + "magcache_K" : 2, + }) + if base_model_type in ["t2v"] and "URLs2" in model_def: + def_mag_ratios = [1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181] + elif base_model_type in ["i2v_2_2"]: + def_mag_ratios = [0.99191, 0.99144, 0.99356, 0.99337, 0.99326, 0.99285, 0.99251, 0.99264, 0.99393, 0.99366, 0.9943, 0.9943, 0.99276, 0.99288, 0.99389, 0.99393, 0.99274, 0.99289, 0.99316, 0.9931, 0.99379, 0.99377, 0.99268, 0.99271, 0.99222, 0.99227, 0.99175, 0.9916, 0.91076, 0.91046, 0.98931, 0.98933, 0.99087, 0.99088, 0.98852, 0.98855, 0.98895, 0.98896, 0.98806, 0.98808, 0.9871, 0.98711, 0.98613, 0.98618, 0.98434, 0.98435, 0.983, 0.98307, 0.98185, 0.98187, 0.98131, 0.98131, 0.9783, 0.97835, 0.97619, 0.9762, 0.97264, 0.9727, 0.97088, 0.97098, 0.96568, 0.9658, 0.96045, 0.96055, 0.95322, 0.95335, 0.94579, 0.94594, 0.93297, 0.93311, 0.91699, 0.9172, 0.89174, 0.89202, 0.8541, 0.85446, 0.79823, 0.79902] + elif base_model_type in ["ti2v_2_2"]: + if inputs.get("image_start", None) is not None and inputs.get("video_source", None) is not None : # t2v + def_mag_ratios = [0.99505, 0.99389, 0.99441, 0.9957, 0.99558, 0.99551, 0.99499, 0.9945, 0.99534, 0.99548, 0.99468, 0.9946, 0.99463, 0.99458, 0.9946, 0.99453, 0.99408, 0.99404, 0.9945, 0.99441, 0.99409, 0.99398, 0.99403, 0.99397, 0.99382, 0.99377, 0.99349, 0.99343, 0.99377, 0.99378, 0.9933, 0.99328, 0.99303, 0.99301, 0.99217, 0.99216, 0.992, 0.99201, 0.99201, 0.99202, 0.99133, 0.99132, 0.99112, 0.9911, 0.99155, 0.99155, 0.98958, 0.98957, 0.98959, 0.98958, 0.98838, 0.98835, 0.98826, 0.98825, 0.9883, 0.98828, 0.98711, 0.98709, 0.98562, 0.98561, 0.98511, 0.9851, 0.98414, 0.98412, 0.98284, 0.98282, 0.98104, 0.98101, 0.97981, 0.97979, 0.97849, 0.97849, 0.97557, 0.97554, 0.97398, 0.97395, 0.97171, 0.97166, 0.96917, 0.96913, 0.96511, 0.96507, 0.96263, 0.96257, 0.95839, 0.95835, 0.95483, 0.95475, 0.94942, 0.94936, 0.9468, 0.94678, 0.94583, 0.94594, 0.94843, 0.94872, 0.96949, 0.97015] + else: # i2v + def_mag_ratios = [0.99512, 0.99559, 0.99559, 0.99561, 0.99595, 0.99577, 0.99512, 0.99512, 0.99546, 0.99534, 0.99543, 0.99531, 0.99496, 0.99491, 0.99504, 0.99499, 0.99444, 0.99449, 0.99481, 0.99481, 0.99435, 0.99435, 0.9943, 0.99431, 0.99411, 0.99406, 0.99373, 0.99376, 0.99413, 0.99405, 0.99363, 0.99359, 0.99335, 0.99331, 0.99244, 0.99243, 0.99229, 0.99229, 0.99239, 0.99236, 0.99163, 0.9916, 0.99149, 0.99151, 0.99191, 0.99192, 0.9898, 0.98981, 0.9899, 0.98987, 0.98849, 0.98849, 0.98846, 0.98846, 0.98861, 0.98861, 0.9874, 0.98738, 0.98588, 0.98589, 0.98539, 0.98534, 0.98444, 0.98439, 0.9831, 0.98309, 0.98119, 0.98118, 0.98001, 0.98, 0.97862, 0.97859, 0.97555, 0.97558, 0.97392, 0.97388, 0.97152, 0.97145, 0.96871, 0.9687, 0.96435, 0.96434, 0.96129, 0.96127, 0.95639, 0.95638, 0.95176, 0.95175, 0.94446, 0.94452, 
0.93972, 0.93974, 0.93575, 0.9359, 0.93537, 0.93552, 0.96655, 0.96616] + elif test_class_1_3B(base_model_type): #text 1.3B + def_mag_ratios = [1.0124, 1.02213, 1.00166, 1.0041, 0.99791, 1.00061, 0.99682, 0.99762, 0.99634, 0.99685, 0.99567, 0.99586, 0.99416, 0.99422, 0.99578, 0.99575, 0.9957, 0.99563, 0.99511, 0.99506, 0.99535, 0.99531, 0.99552, 0.99549, 0.99541, 0.99539, 0.9954, 0.99536, 0.99489, 0.99485, 0.99518, 0.99514, 0.99484, 0.99478, 0.99481, 0.99479, 0.99415, 0.99413, 0.99419, 0.99416, 0.99396, 0.99393, 0.99388, 0.99386, 0.99349, 0.99349, 0.99309, 0.99304, 0.9927, 0.9927, 0.99228, 0.99226, 0.99171, 0.9917, 0.99137, 0.99135, 0.99068, 0.99063, 0.99005, 0.99003, 0.98944, 0.98942, 0.98849, 0.98849, 0.98758, 0.98757, 0.98644, 0.98643, 0.98504, 0.98503, 0.9836, 0.98359, 0.98202, 0.98201, 0.97977, 0.97978, 0.97717, 0.97718, 0.9741, 0.97411, 0.97003, 0.97002, 0.96538, 0.96541, 0.9593, 0.95933, 0.95086, 0.95089, 0.94013, 0.94019, 0.92402, 0.92414, 0.90241, 0.9026, 0.86821, 0.86868, 0.81838, 0.81939]#**(0.5)# In our papaer, we utilize the sqrt to smooth the ratio, which has little impact on the performance and can be deleted. + elif i2v: + if pixels >= 1280*720: + def_mag_ratios = [0.99428, 0.99498, 0.98588, 0.98621, 0.98273, 0.98281, 0.99018, 0.99023, 0.98911, 0.98917, 0.98646, 0.98652, 0.99454, 0.99456, 0.9891, 0.98909, 0.99124, 0.99127, 0.99102, 0.99103, 0.99215, 0.99212, 0.99515, 0.99515, 0.99576, 0.99572, 0.99068, 0.99072, 0.99097, 0.99097, 0.99166, 0.99169, 0.99041, 0.99042, 0.99201, 0.99198, 0.99101, 0.99101, 0.98599, 0.98603, 0.98845, 0.98844, 0.98848, 0.98851, 0.98862, 0.98857, 0.98718, 0.98719, 0.98497, 0.98497, 0.98264, 0.98263, 0.98389, 0.98393, 0.97938, 0.9794, 0.97535, 0.97536, 0.97498, 0.97499, 0.973, 0.97301, 0.96827, 0.96828, 0.96261, 0.96263, 0.95335, 0.9534, 0.94649, 0.94655, 0.93397, 0.93414, 0.91636, 0.9165, 0.89088, 0.89109, 0.8679, 0.86768] + else: + def_mag_ratios = [0.98783, 0.98993, 0.97559, 0.97593, 0.98311, 0.98319, 0.98202, 0.98225, 0.9888, 0.98878, 0.98762, 0.98759, 0.98957, 0.98971, 0.99052, 0.99043, 0.99383, 0.99384, 0.98857, 0.9886, 0.99065, 0.99068, 0.98845, 0.98847, 0.99057, 0.99057, 0.98957, 0.98961, 0.98601, 0.9861, 0.98823, 0.98823, 0.98756, 0.98759, 0.98808, 0.98814, 0.98721, 0.98724, 0.98571, 0.98572, 0.98543, 0.98544, 0.98157, 0.98165, 0.98411, 0.98413, 0.97952, 0.97953, 0.98149, 0.9815, 0.9774, 0.97742, 0.97825, 0.97826, 0.97355, 0.97361, 0.97085, 0.97087, 0.97056, 0.97055, 0.96588, 0.96587, 0.96113, 0.96124, 0.9567, 0.95681, 0.94961, 0.94969, 0.93973, 0.93988, 0.93217, 0.93224, 0.91878, 0.91896, 0.90955, 0.90954, 0.92617, 0.92616] + else: # text 14B + def_mag_ratios = [1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189] + 
skip_steps_cache.def_mag_ratios = def_mag_ratios + else: + if i2v: + if pixels >= 1280*720: + coefficients= [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] + else: + coefficients= [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01] + else: + if test_class_1_3B(base_model_type): + coefficients= [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01] + else: + coefficients= [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] + skip_steps_cache.coefficients = coefficients + + @staticmethod + def get_wan_text_encoder_filename(text_encoder_quantization): + text_encoder_filename = "ckpts/umt5-xxl/models_t5_umt5-xxl-enc-bf16.safetensors" + if text_encoder_quantization =="int8": + text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_int8") + return text_encoder_filename + + + + @staticmethod + def query_modules_files(): + return { + "vace_14B" : ["ckpts/wan2.1_Vace_14B_module_mbf16.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mfp16_int8.safetensors"], + "vace_1.3B" : ["ckpts/wan2.1_Vace_1_3B_module.safetensors"], + "fantasy": ["ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"], + "multitalk": ["ckpts/wan2.1_multitalk_14B_mbf16.safetensors", "ckpts/wan2.1_multitalk_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_multitalk_14B_quanto_mfp16_int8.safetensors"] +} + + @staticmethod + def query_model_def(base_model_type, model_def): + extra_model_def = {} + if "URLs2" in model_def: + extra_model_def["no_steps_skipping"] = True + i2v = test_class_i2v(base_model_type) + extra_model_def["i2v_class"] = i2v + extra_model_def["multitalk_class"] = base_model_type in ["multitalk", "vace_multitalk_14B", "i2v_2_2_multitalk"] + vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B"] + extra_model_def["vace_class"] = vace_class + + if base_model_type in ["multitalk", "vace_multitalk_14B", "i2v_2_2_multitalk"]: + fps = 25 + elif base_model_type in ["fantasy"]: + fps = 23 + elif base_model_type in ["ti2v_2_2"]: + fps = 24 + else: + fps = 16 + extra_model_def["fps"] = fps + + if vace_class: + frames_minimum, frames_steps = 17, 4 + else: + frames_minimum, frames_steps = 5, 4 + extra_model_def.update({ + "frames_minimum" : frames_minimum, + "frames_steps" : frames_steps, + "sliding_window" : base_model_type in ["multitalk", "t2v", "fantasy"] or test_class_i2v(base_model_type) or vace_class, #"ti2v_2_2", + "guidance_max_phases" : 2, + "skip_layer_guidance" : True, + "cfg_zero" : True, + "cfg_star" : True, + "adaptive_projected_guidance" : True, + "tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or "URLs2" in model_def), + "mag_cache" : True, + "sample_solvers":[ + ("unipc", "unipc"), + ("euler", "euler"), + ("dpm++", "dpm++"), + ("flowmatch causvid", "causvid"), ] + }) + + return extra_model_def + + @staticmethod + def query_supported_types(): + return ["multitalk", "fantasy", "vace_14B", "vace_multitalk_14B", + "t2v_1.3B", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B", + "recam_1.3B", + "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"] + + + @staticmethod + def query_family_maps(): + + models_eqv_map = { + "flf2v_720p" : "i2v", + "t2v_1.3B" : "t2v", + } + + models_comp_map = { + "vace_14B" : [ "vace_multitalk_14B"], + "t2v" : [ "vace_14B", "vace_1.3B", "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B"], + "i2v" : [ "fantasy",
"multitalk", "flf2v_720p" ], + "i2v_2_2" : ["i2v_2_2_multitalk"], + "fantasy": ["multitalk"], + } + return models_eqv_map, models_comp_map + + @staticmethod + def query_model_family(): + return "wan" + + @staticmethod + def query_family_infos(): + return {"wan":(0, "Wan2.1"), "wan2_2":(1, "Wan2.2") } + + @staticmethod + def get_vae_block_size(base_model_type): + return 32 if base_model_type == "ti2v_2_2" else 16 + + @staticmethod + def get_rgb_factors(base_model_type ): + from shared.RGB_factors import get_rgb_factors + if base_model_type == "ti2v_2_2": return None, None + latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan") + return latent_rgb_factors, latent_rgb_factors_bias + + @staticmethod + def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization): + text_encoder_filename = family_handler.get_wan_text_encoder_filename(text_encoder_quantization) + + download_def = [{ + "repoId" : "DeepBeepMeep/Wan2.1", + "sourceFolderList" : ["xlm-roberta-large", "umt5-xxl", "" ], + "fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(model_filename) ] + }] + + if base_model_type == "ti2v_2_2": + download_def += [ { + "repoId" : "DeepBeepMeep/Wan2.2", + "sourceFolderList" : [""], + "fileList" : [ [ "Wan2.2_VAE.safetensors" ] ] + }] + + return download_def + + + @staticmethod + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): + from .configs import WAN_CONFIGS + + if test_class_i2v(base_model_type): + cfg = WAN_CONFIGS['i2v-14B'] + else: + cfg = WAN_CONFIGS['t2v-14B'] + # cfg = WAN_CONFIGS['t2v-1.3B'] + from . 
import WanAny2V + wan_model = WanAny2V( + config=cfg, + checkpoint_dir="ckpts", + model_filename=model_filename, + model_type = model_type, + model_def = model_def, + base_model_type=base_model_type, + text_encoder_filename= family_handler.get_wan_text_encoder_filename(text_encoder_quantization), + quantizeTransformer = quantizeTransformer, + dtype = dtype, + VAE_dtype = VAE_dtype, + mixed_precision_transformer = mixed_precision_transformer, + save_quantized = save_quantized + ) + + pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model } + if hasattr(wan_model,"model2") and wan_model.model2 is not None: + pipe["transformer2"] = wan_model.model2 + if hasattr(wan_model, "clip"): + pipe["text_encoder_2"] = wan_model.clip.model + return wan_model, pipe + + @staticmethod + def fix_settings(base_model_type, settings_version, model_def, ui_defaults): + if ui_defaults.get("sample_solver", "") == "": + ui_defaults["sample_solver"] = "unipc" + + @staticmethod + def update_default_settings(base_model_type, model_def, ui_defaults): + ui_defaults.update({ + "sample_solver": "unipc", + }) + if base_model_type in ["fantasy"]: + ui_defaults.update({ + "audio_guidance_scale": 5.0, + "sliding_window_size": 1, + }) + + elif base_model_type in ["multitalk"]: + ui_defaults.update({ + "guidance_scale": 5.0, + "flow_shift": 7, # 11 for 720p + "audio_guidance_scale": 4, + "sliding_window_discard_last_frames" : 4, + "sample_solver" : "euler", + "adaptive_switch" : 1, + }) + + elif base_model_type in ["phantom_1.3B", "phantom_14B"]: + ui_defaults.update({ + "guidance_scale": 7.5, + "flow_shift": 5, + "remove_background_images_ref": 1, + "video_prompt_type": "I", + # "resolution": "1280x720" + }) + + elif base_model_type in ["vace_14B", "vace_multitalk_14B"]: + ui_defaults.update({ + "sliding_window_discard_last_frames": 0, + }) + + elif base_model_type in ["ti2v_2_2"]: + ui_defaults.update({ + "image_prompt_type": "T", + }) + + \ No newline at end of file diff --git a/postprocessing/mmaudio/data/av_utils.py b/postprocessing/mmaudio/data/av_utils.py index 19776dc..7866d6c 100644 --- a/postprocessing/mmaudio/data/av_utils.py +++ b/postprocessing/mmaudio/data/av_utils.py @@ -131,7 +131,7 @@ from pathlib import Path import torch def remux_with_audio(video_path: Path, output_path: Path, audio: torch.Tensor, sampling_rate: int): - from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files + from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f: temp_path = Path(f.name) diff --git a/preprocessing/matanyone/app.py b/preprocessing/matanyone/app.py index 67813ee..151d9be 100644 --- a/preprocessing/matanyone/app.py +++ b/preprocessing/matanyone/app.py @@ -21,6 +21,7 @@ from segment_anything.modeling.image_encoder import window_partition, window_unp from .utils.get_default_model import get_matanyone_model from .matanyone.inference.inference_core import InferenceCore from .matanyone_wrapper import matanyone +from shared.utils.audio_video import save_video, save_image arg_device = "cuda" arg_sam_model_type="vit_h" @@ -377,14 +378,14 @@ def show_mask(video_state, interactive_state, mask_dropdown): return select_frame -def save_video(frames, output_path, fps): +# def save_video(frames, output_path, fps): - writer = imageio.get_writer( output_path, fps=fps, codec='libx264', quality=8) - for frame 
in frames: - writer.append_data(frame) - writer.close() +# writer = imageio.get_writer( output_path, fps=fps, codec='libx264', quality=8) +# for frame in frames: +# writer.append_data(frame) +# writer.close() - return output_path +# return output_path def mask_to_xyxy_box(mask): rows, cols = np.where(mask == 255) @@ -535,20 +536,20 @@ def video_matting(video_state,video_input, end_slider, matting_type, interactive file_name= video_state["video_name"] file_name = ".".join(file_name.split(".")[:-1]) - from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files + from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files source_audio_tracks, audio_metadata = extract_audio_tracks(video_input) output_fg_path = f"./mask_outputs/{file_name}_fg.mp4" output_fg_temp_path = f"./mask_outputs/{file_name}_fg_tmp.mp4" if len(source_audio_tracks) == 0: - foreground_output = save_video(foreground, output_path=output_fg_path , fps=fps) + foreground_output = save_video(foreground,output_fg_path , fps=fps, codec_type= video_output_codec) else: - foreground_output_tmp = save_video(foreground, output_path=output_fg_temp_path , fps=fps) + foreground_output_tmp = save_video(foreground, output_fg_temp_path , fps=fps, codec_type= video_output_codec) combine_video_with_audio_tracks(output_fg_temp_path, source_audio_tracks, output_fg_path, audio_metadata=audio_metadata) cleanup_temp_audio_files(source_audio_tracks) os.remove(foreground_output_tmp) foreground_output = output_fg_path - alpha_output = save_video(alpha, output_path="./mask_outputs/{}_alpha.mp4".format(file_name), fps=fps) + alpha_output = save_video(alpha, "./mask_outputs/{}_alpha.mp4".format(file_name), fps=fps, codec_type= video_output_codec) return foreground_output, alpha_output, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True) @@ -745,8 +746,12 @@ def teleport_to_video_tab(tab_state): return gr.Tabs(selected="video_gen") -def display(tabs, tab_state, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs): +def display(tabs, tab_state, server_config, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs): # my_tab.select(fn=load_unload_models, inputs=[], outputs=[]) + global image_output_codec, video_output_codec + + image_output_codec = server_config.get("image_output_codec", None) + video_output_codec = server_config.get("video_output_codec", None) media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/" diff --git a/requirements.txt b/requirements.txt index 44bda57..a2e1fde 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,51 +1,61 @@ -torch>=2.4.0 -torchvision>=0.19.0 -opencv-python>=4.9.0.80 -diffusers>=0.31.0 -transformers==4.51.3 -#transformers==4.46.3 # was needed by llamallava used by i2v hunyuan before patch +# Core AI stack +diffusers==0.34.0 +transformers==4.53.1 tokenizers>=0.20.3 accelerate>=1.1.1 tqdm imageio -easydict -ftfy -dashscope imageio-ffmpeg -# flash_attn -gradio==5.23.0 -numpy>=1.23.5,<2 einops -moviepy==1.0.3 -mmgp==3.5.6 -peft==0.15.0 -mutagen -pydantic==2.10.6 -decord -onnxruntime-gpu -rembg[gpu]==2.0.65 -matplotlib -timm -segment-anything -omegaconf -hydra-core -librosa==0.11.0 -loguru sentencepiece +open_clip_torch>=2.29.0 + +# Video & media +moviepy==1.0.3 av -opencv-python +ffmpeg-python pygame>=2.1.0 sounddevice>=0.4.0 -# rembg==2.0.65 -torchdiffeq >= 0.2.5 -tensordict >= 
0.6.1 -open_clip_torch >= 2.29.0 -pyloudnorm -misaki soundfile -ffmpeg-python -pyannote.audio +mutagen +pyloudnorm +librosa==0.11.0 + +# UI & interaction +gradio==5.23.0 +dashscope +loguru + +# Vision & segmentation +opencv-python>=4.9.0.80 +segment-anything +rembg[gpu]==2.0.65 +onnxruntime-gpu +decord +timm + +# Config & orchestration +omegaconf +hydra-core +easydict +pydantic==2.10.6 + +# Math & modeling +torchdiffeq>=0.2.5 +tensordict>=0.6.1 +mmgp==3.5.10 +peft==0.15.0 +matplotlib + +# Utilities +ftfy +piexif pynvml -huggingface_hub[hf_xet] +misaki + +# Optional / commented out +# transformers==4.46.3 # for llamallava pre-patch +# rembg==2.0.65 # non-GPU fallback +# huggingface_hub[hf_xet] # slows down everything # num2words -# spacy \ No newline at end of file +# spacy diff --git a/shared/RGB_factors.py b/shared/RGB_factors.py new file mode 100644 index 0000000..5ec1e59 --- /dev/null +++ b/shared/RGB_factors.py @@ -0,0 +1,213 @@ +# thanks Comfyui for the rgb factors +def get_rgb_factors(model_family): + if model_family == "wan": + latent_channels = 16 + latent_dimensions = 3 + latent_rgb_factors = [ + [-0.1299, -0.1692, 0.2932], + [ 0.0671, 0.0406, 0.0442], + [ 0.3568, 0.2548, 0.1747], + [ 0.0372, 0.2344, 0.1420], + [ 0.0313, 0.0189, -0.0328], + [ 0.0296, -0.0956, -0.0665], + [-0.3477, -0.4059, -0.2925], + [ 0.0166, 0.1902, 0.1975], + [-0.0412, 0.0267, -0.1364], + [-0.1293, 0.0740, 0.1636], + [ 0.0680, 0.3019, 0.1128], + [ 0.0032, 0.0581, 0.0639], + [-0.1251, 0.0927, 0.1699], + [ 0.0060, -0.0633, 0.0005], + [ 0.3477, 0.2275, 0.2950], + [ 0.1984, 0.0913, 0.1861] + ] + + latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360] + + # latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761] + elif model_family =="flux": + scale_factor = 0.3611 + shift_factor = 0.1159 + latent_rgb_factors =[ + [-0.0346, 0.0244, 0.0681], + [ 0.0034, 0.0210, 0.0687], + [ 0.0275, -0.0668, -0.0433], + [-0.0174, 0.0160, 0.0617], + [ 0.0859, 0.0721, 0.0329], + [ 0.0004, 0.0383, 0.0115], + [ 0.0405, 0.0861, 0.0915], + [-0.0236, -0.0185, -0.0259], + [-0.0245, 0.0250, 0.1180], + [ 0.1008, 0.0755, -0.0421], + [-0.0515, 0.0201, 0.0011], + [ 0.0428, -0.0012, -0.0036], + [ 0.0817, 0.0765, 0.0749], + [-0.1264, -0.0522, -0.1103], + [-0.0280, -0.0881, -0.0499], + [-0.1262, -0.0982, -0.0778] + ] + latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851] + + elif model_family == "ltxv": + latent_channels = 128 + latent_dimensions = 3 + + latent_rgb_factors = [ + [ 1.1202e-02, -6.3815e-04, -1.0021e-02], + [ 8.6031e-02, 6.5813e-02, 9.5409e-04], + [-1.2576e-02, -7.5734e-03, -4.0528e-03], + [ 9.4063e-03, -2.1688e-03, 2.6093e-03], + [ 3.7636e-03, 1.2765e-02, 9.1548e-03], + [ 2.1024e-02, -5.2973e-03, 3.4373e-03], + [-8.8896e-03, -1.9703e-02, -1.8761e-02], + [-1.3160e-02, -1.0523e-02, 1.9709e-03], + [-1.5152e-03, -6.9891e-03, -7.5810e-03], + [-1.7247e-03, 4.6560e-04, -3.3839e-03], + [ 1.3617e-02, 4.7077e-03, -2.0045e-03], + [ 1.0256e-02, 7.7318e-03, 1.3948e-02], + [-1.6108e-02, -6.2151e-03, 1.1561e-03], + [ 7.3407e-03, 1.5628e-02, 4.4865e-04], + [ 9.5357e-04, -2.9518e-03, -1.4760e-02], + [ 1.9143e-02, 1.0868e-02, 1.2264e-02], + [ 4.4575e-03, 3.6682e-05, -6.8508e-03], + [-4.5681e-04, 3.2570e-03, 7.7929e-03], + [ 3.3902e-02, 3.3405e-02, 3.7454e-02], + [-2.3001e-02, -2.4877e-03, -3.1033e-03], + [ 5.0265e-02, 3.8841e-02, 3.3539e-02], + [-4.1018e-03, -1.1095e-03, 1.5859e-03], + [-1.2689e-01, -1.3107e-01, -2.1005e-01], + [ 2.6276e-02, 1.4189e-02, -3.5963e-03], + [-4.8679e-03, 8.8486e-03, 7.8029e-03], + [-1.6610e-03, -4.8597e-03, -5.2060e-03], + 
[-2.1010e-03, 2.3610e-03, 9.3796e-03], + [-2.2482e-02, -2.1305e-02, -1.5087e-02], + [-1.5753e-02, -1.0646e-02, -6.5083e-03], + [-4.6975e-03, 5.0288e-03, -6.7390e-03], + [ 1.1951e-02, 2.0712e-02, 1.6191e-02], + [-6.3704e-03, -8.4827e-03, -9.5483e-03], + [ 7.2610e-03, -9.9326e-03, -2.2978e-02], + [-9.1904e-04, 6.2882e-03, 9.5720e-03], + [-3.7178e-02, -3.7123e-02, -5.6713e-02], + [-1.3373e-01, -1.0720e-01, -5.3801e-02], + [-5.3702e-03, 8.1256e-03, 8.8397e-03], + [-1.5247e-01, -2.1437e-01, -2.1843e-01], + [ 3.1441e-02, 7.0335e-03, -9.7541e-03], + [ 2.1528e-03, -8.9817e-03, -2.1023e-02], + [ 3.8461e-03, -5.8957e-03, -1.5014e-02], + [-4.3470e-03, -1.2940e-02, -1.5972e-02], + [-5.4781e-03, -1.0842e-02, -3.0204e-03], + [-6.5347e-03, 3.0806e-03, -1.0163e-02], + [-5.0414e-03, -7.1503e-03, -8.9686e-04], + [-8.5851e-03, -2.4351e-03, 1.0674e-03], + [-9.0016e-03, -9.6493e-03, 1.5692e-03], + [ 5.0914e-03, 1.2099e-02, 1.9968e-02], + [ 1.3758e-02, 1.1669e-02, 8.1958e-03], + [-1.0518e-02, -1.1575e-02, -4.1307e-03], + [-2.8410e-02, -3.1266e-02, -2.2149e-02], + [ 2.9336e-03, 3.6511e-02, 1.8717e-02], + [-1.6703e-02, -1.6696e-02, -4.4529e-03], + [ 4.8818e-02, 4.0063e-02, 8.7410e-03], + [-1.5066e-02, -5.7328e-04, 2.9785e-03], + [-1.7613e-02, -8.1034e-03, 1.3086e-02], + [-9.2633e-03, 1.0803e-02, -6.3489e-03], + [ 3.0851e-03, 4.7750e-04, 1.2347e-02], + [-2.2785e-02, -2.3043e-02, -2.6005e-02], + [-2.4787e-02, -1.5389e-02, -2.2104e-02], + [-2.3572e-02, 1.0544e-03, 1.2361e-02], + [-7.8915e-03, -1.2271e-03, -6.0968e-03], + [-1.1478e-02, -1.2543e-03, 6.2679e-03], + [-5.4229e-02, 2.6644e-02, 6.3394e-03], + [ 4.4216e-03, -7.3338e-03, -1.0464e-02], + [-4.5013e-03, 1.6082e-03, 1.4420e-02], + [ 1.3673e-02, 8.8877e-03, 4.1253e-03], + [-1.0145e-02, 9.0072e-03, 1.5695e-02], + [-5.6234e-03, 1.1847e-03, 8.1261e-03], + [-3.7171e-03, -5.3538e-03, 1.2590e-03], + [ 2.9476e-02, 2.1424e-02, 3.0424e-02], + [-3.4925e-02, -2.4340e-02, -2.5316e-02], + [-3.4127e-02, -2.2406e-02, -1.0589e-02], + [-1.7342e-02, -1.3249e-02, -1.0719e-02], + [-2.1478e-03, -8.6051e-03, -2.9878e-03], + [ 1.2089e-03, -4.2391e-03, -6.8569e-03], + [ 9.0411e-04, -6.6886e-03, -6.7547e-05], + [ 1.6048e-02, -1.0057e-02, -2.8929e-02], + [ 1.2290e-03, 1.0163e-02, 1.8861e-02], + [ 1.7264e-02, 2.7257e-04, 1.3785e-02], + [-1.3482e-02, -3.6427e-03, 6.7481e-04], + [ 4.6782e-03, -5.2423e-03, 2.4467e-03], + [-5.9113e-03, -6.2244e-03, -1.8162e-03], + [ 1.5496e-02, 1.4582e-02, 1.9514e-03], + [ 7.4958e-03, 1.5886e-03, -8.2305e-03], + [ 1.9086e-02, 1.6360e-03, -3.9674e-03], + [-5.7021e-03, -2.7307e-03, -4.1066e-03], + [ 1.7450e-03, 1.4602e-02, 2.5794e-02], + [-8.2788e-04, 2.2902e-03, 4.5161e-03], + [ 1.1632e-02, 8.9193e-03, -7.2813e-03], + [ 7.5721e-03, 2.6784e-03, 1.1393e-02], + [ 5.1939e-03, 3.6903e-03, 1.4049e-02], + [-1.8383e-02, -2.2529e-02, -2.4477e-02], + [ 5.8842e-04, -5.7874e-03, -1.4770e-02], + [-1.6125e-02, -8.6101e-03, -1.4533e-02], + [ 2.0540e-02, 2.0729e-02, 6.4338e-03], + [ 3.3587e-03, -1.1226e-02, -1.6444e-02], + [-1.4742e-03, -1.0489e-02, 1.7097e-03], + [ 2.8130e-02, 2.3546e-02, 3.2791e-02], + [-1.8532e-02, -1.2842e-02, -8.7756e-03], + [-8.0533e-03, -1.0771e-02, -1.7536e-02], + [-3.9009e-03, 1.6150e-02, 3.3359e-02], + [-7.4554e-03, -1.4154e-02, -6.1910e-03], + [ 3.4734e-03, -1.1370e-02, -1.0581e-02], + [ 1.1476e-02, 3.9281e-03, 2.8231e-03], + [ 7.1639e-03, -1.4741e-03, -3.8066e-03], + [ 2.2250e-03, -8.7552e-03, -9.5719e-03], + [ 2.4146e-02, 2.1696e-02, 2.8056e-02], + [-5.4365e-03, -2.4291e-02, -1.7802e-02], + [ 7.4263e-03, 1.0510e-02, 1.2705e-02], + [ 6.2669e-03, 
6.2658e-03, 1.9211e-02], + [ 1.6378e-02, 9.4933e-03, 6.6971e-03], + [ 1.7173e-02, 2.3601e-02, 2.3296e-02], + [-1.4568e-02, -9.8279e-03, -1.1556e-02], + [ 1.4431e-02, 1.4430e-02, 6.6362e-03], + [-6.8230e-03, 1.8863e-02, 1.4555e-02], + [ 6.1156e-03, 3.4700e-03, -2.6662e-03], + [-2.6983e-03, -5.9402e-03, -9.2276e-03], + [ 1.0235e-02, 7.4173e-03, -7.6243e-03], + [-1.3255e-02, 1.9322e-02, -9.2153e-04], + [ 2.4222e-03, -4.8039e-03, -1.5759e-02], + [ 2.6244e-02, 2.5951e-02, 2.0249e-02], + [ 1.5711e-02, 1.8498e-02, 2.7407e-03], + [-2.1714e-03, 4.7214e-03, -2.2443e-02], + [-7.4747e-03, 7.4166e-03, 1.4430e-02], + [-8.3906e-03, -7.9776e-03, 9.7927e-03], + [ 3.8321e-02, 9.6622e-03, -1.9268e-02], + [-1.4605e-02, -6.7032e-03, 3.9675e-03] + ] + latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512] + + elif model_family == "hunyuan": + latent_channels = 16 + latent_dimensions = 3 + scale_factor = 0.476986 + latent_rgb_factors = [ + [-0.0395, -0.0331, 0.0445], + [ 0.0696, 0.0795, 0.0518], + [ 0.0135, -0.0945, -0.0282], + [ 0.0108, -0.0250, -0.0765], + [-0.0209, 0.0032, 0.0224], + [-0.0804, -0.0254, -0.0639], + [-0.0991, 0.0271, -0.0669], + [-0.0646, -0.0422, -0.0400], + [-0.0696, -0.0595, -0.0894], + [-0.0799, -0.0208, -0.0375], + [ 0.1166, 0.1627, 0.0962], + [ 0.1165, 0.0432, 0.0407], + [-0.2315, -0.1920, -0.1355], + [-0.0270, 0.0401, -0.0821], + [-0.0616, -0.0997, -0.0727], + [ 0.0249, -0.0469, -0.1703] + ] + + latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761] + else: + latent_rgb_factors_bias = latent_rgb_factors = None + return latent_rgb_factors, latent_rgb_factors_bias \ No newline at end of file diff --git a/wan/modules/attention.py b/shared/attention.py similarity index 100% rename from wan/modules/attention.py rename to shared/attention.py diff --git a/shared/match_archi.py b/shared/match_archi.py new file mode 100644 index 0000000..7d535d5 --- /dev/null +++ b/shared/match_archi.py @@ -0,0 +1,64 @@ +import re + +def match_nvidia_architecture(conditions_dict, architecture): + """ + Match Nvidia architecture against condition dictionary. 
+ + Args: + conditions_dict: dict with condition strings as keys, parameters as values + architecture: int representing architecture (e.g., 89 for Ada Lovelace) + + Returns: + list of matched parameters + + Condition syntax: + - Operators: '<', '>', '<=', '>=', '=' (or no operator for equality) + - OR: '+' between conditions (e.g., '<=50+>89') + - AND: '&' between conditions (e.g., '>=70&<90') + - Examples: + * '<89': architectures below Ada (89) + * '>=75': architectures 75 and above + * '89': exactly Ada architecture + * '<=50+>89': Maxwell (50) and below OR above Ada + * '>=70&<90': Ampere range (70-89) + """ + + def eval_condition(cond, arch): + """Evaluate single condition against architecture""" + cond = cond.strip() + if not cond: + return False + + # Parse operator and value using regex + match = re.match(r'(>=|<=|>|<|=?)(\d+)', cond) + if not match: + return False + + op, val = match.groups() + val = int(val) + + # Handle operators + if op in ('', '='): + return arch == val + elif op == '>=': + return arch >= val + elif op == '<=': + return arch <= val + elif op == '>': + return arch > val + elif op == '<': + return arch < val + return False + + def matches_condition(condition_str, arch): + """Check if architecture matches full condition string""" + # Split by '+' for OR conditions, then by '&' for AND conditions + return any( + all(eval_condition(and_cond, arch) for and_cond in or_cond.split('&')) + for or_cond in condition_str.split('+') + if or_cond.strip() + ) + + # Return all parameters where conditions match + return [params for condition, params in conditions_dict.items() + if matches_condition(condition, architecture)] \ No newline at end of file diff --git a/wan/modules/sage2_core.py b/shared/sage2_core.py similarity index 100% rename from wan/modules/sage2_core.py rename to shared/sage2_core.py diff --git a/wan/utils/__init__.py b/shared/utils/__init__.py similarity index 100% rename from wan/utils/__init__.py rename to shared/utils/__init__.py diff --git a/shared/utils/audio_video.py b/shared/utils/audio_video.py new file mode 100644 index 0000000..b24530d --- /dev/null +++ b/shared/utils/audio_video.py @@ -0,0 +1,421 @@ +import subprocess +import tempfile, os +import ffmpeg +import torchvision.transforms.functional as TF +import torch.nn.functional as F +import cv2 +import tempfile +import imageio +import binascii +import torchvision +import torch +from PIL import Image +import os.path as osp +import json + +def rand_name(length=8, suffix=''): + name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') + if suffix: + if not suffix.startswith('.'): + suffix = '.' + suffix + name += suffix + return name + + + +def extract_audio_tracks(source_video, verbose=False, query_only=False): + """ + Extract all audio tracks from a source video into temporary AAC files. + + Returns: + Tuple: + - List of temp file paths for extracted audio tracks + - List of corresponding metadata dicts: + {'codec', 'sample_rate', 'channels', 'duration', 'language'} + where 'duration' is set to container duration (for consistency). 
+ """ + probe = ffmpeg.probe(source_video) + audio_streams = [s for s in probe['streams'] if s['codec_type'] == 'audio'] + container_duration = float(probe['format'].get('duration', 0.0)) + + if not audio_streams: + if query_only: return 0 + if verbose: print(f"No audio track found in {source_video}") + return [], [] + + if query_only: + return len(audio_streams) + + if verbose: + print(f"Found {len(audio_streams)} audio track(s), container duration = {container_duration:.3f}s") + + file_paths = [] + metadata = [] + + for i, stream in enumerate(audio_streams): + fd, temp_path = tempfile.mkstemp(suffix=f'_track{i}.aac', prefix='audio_') + os.close(fd) + + file_paths.append(temp_path) + metadata.append({ + 'codec': stream.get('codec_name'), + 'sample_rate': int(stream.get('sample_rate', 0)), + 'channels': int(stream.get('channels', 0)), + 'duration': container_duration, + 'language': stream.get('tags', {}).get('language', None) + }) + + ffmpeg.input(source_video).output( + temp_path, + **{f'map': f'0:a:{i}', 'acodec': 'aac', 'b:a': '128k'} + ).overwrite_output().run(quiet=not verbose) + + return file_paths, metadata + + + +def combine_and_concatenate_video_with_audio_tracks( + save_path_tmp, video_path, + source_audio_tracks, new_audio_tracks, + source_audio_duration, audio_sampling_rate, + new_audio_from_start=False, + source_audio_metadata=None, + audio_bitrate='128k', + audio_codec='aac', + verbose = False +): + inputs, filters, maps, idx = ['-i', video_path], [], ['-map', '0:v'], 1 + metadata_args = [] + sources = source_audio_tracks or [] + news = new_audio_tracks or [] + + duplicate_source = len(sources) == 1 and len(news) > 1 + N = len(news) if source_audio_duration == 0 else max(len(sources), len(news)) or 1 + + for i in range(N): + s = (sources[i] if i < len(sources) + else sources[0] if duplicate_source else None) + n = news[i] if len(news) == N else (news[0] if news else None) + + if source_audio_duration == 0: + if n: + inputs += ['-i', n] + filters.append(f'[{idx}:a]apad=pad_dur=100[aout{i}]') + idx += 1 + else: + filters.append(f'anullsrc=r={audio_sampling_rate}:cl=mono,apad=pad_dur=100[aout{i}]') + else: + if s: + inputs += ['-i', s] + meta = source_audio_metadata[i] if source_audio_metadata and i < len(source_audio_metadata) else {} + needs_filter = ( + meta.get('codec') != audio_codec or + meta.get('sample_rate') != audio_sampling_rate or + meta.get('channels') != 1 or + meta.get('duration', 0) < source_audio_duration + ) + if needs_filter: + filters.append( + f'[{idx}:a]aresample={audio_sampling_rate},aformat=channel_layouts=mono,' + f'apad=pad_dur={source_audio_duration},atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}]') + else: + filters.append( + f'[{idx}:a]apad=pad_dur={source_audio_duration},atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}]') + if lang := meta.get('language'): + metadata_args += ['-metadata:s:a:' + str(i), f'language={lang}'] + idx += 1 + else: + filters.append( + f'anullsrc=r={audio_sampling_rate}:cl=mono,atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}]') + + if n: + inputs += ['-i', n] + start = '0' if new_audio_from_start else source_audio_duration + filters.append( + f'[{idx}:a]aresample={audio_sampling_rate},aformat=channel_layouts=mono,' + f'atrim=start={start},asetpts=PTS-STARTPTS[n{i}]') + filters.append(f'[s{i}][n{i}]concat=n=2:v=0:a=1[aout{i}]') + idx += 1 + else: + filters.append(f'[s{i}]apad=pad_dur=100[aout{i}]') + + maps += ['-map', f'[aout{i}]'] + + cmd = ['ffmpeg', '-y', *inputs, + '-filter_complex', 
';'.join(filters), # ✅ Only change made + *maps, *metadata_args, + '-c:v', 'copy', + '-c:a', audio_codec, + '-b:a', audio_bitrate, + '-ar', str(audio_sampling_rate), + '-ac', '1', + '-shortest', save_path_tmp] + + if verbose: + print(f"ffmpeg command: {cmd}") + try: + subprocess.run(cmd, check=True, capture_output=True, text=True) + except subprocess.CalledProcessError as e: + raise Exception(f"FFmpeg error: {e.stderr}") + + +def combine_video_with_audio_tracks(target_video, audio_tracks, output_video, + audio_metadata=None, verbose=False): + if not audio_tracks: + if verbose: print("No audio tracks to combine."); return False + + dur = float(next(s for s in ffmpeg.probe(target_video)['streams'] + if s['codec_type'] == 'video')['duration']) + if verbose: print(f"Video duration: {dur:.3f}s") + + cmd = ['ffmpeg', '-y', '-i', target_video] + for path in audio_tracks: + cmd += ['-i', path] + + cmd += ['-map', '0:v'] + for i in range(len(audio_tracks)): + cmd += ['-map', f'{i+1}:a'] + + for i, meta in enumerate(audio_metadata or []): + if (lang := meta.get('language')): + cmd += ['-metadata:s:a:' + str(i), f'language={lang}'] + + cmd += ['-c:v', 'copy', '-c:a', 'copy', '-t', str(dur), output_video] + + result = subprocess.run(cmd, capture_output=not verbose, text=True) + if result.returncode != 0: + raise Exception(f"FFmpeg error:\n{result.stderr}") + if verbose: + print(f"Created {output_video} with {len(audio_tracks)} audio track(s)") + return True + + +def cleanup_temp_audio_files(audio_tracks, verbose=False): + """ + Clean up temporary audio files. + + Args: + audio_tracks: List of audio file paths to delete + verbose: Enable verbose output (default: False) + + Returns: + Number of files successfully deleted + """ + deleted_count = 0 + + for audio_path in audio_tracks: + try: + if os.path.exists(audio_path): + os.unlink(audio_path) + deleted_count += 1 + if verbose: + print(f"Cleaned up {audio_path}") + except PermissionError: + print(f"Warning: Could not delete {audio_path} (file may be in use)") + except Exception as e: + print(f"Warning: Error deleting {audio_path}: {e}") + + if verbose and deleted_count > 0: + print(f"Successfully deleted {deleted_count} temporary audio file(s)") + + return deleted_count + + +def save_video(tensor, + save_file=None, + fps=30, + codec_type='libx264_8', + container='mp4', + nrow=8, + normalize=True, + value_range=(-1, 1), + retry=5): + """Save tensor as video with configurable codec and container options.""" + + suffix = f'.{container}' + cache_file = osp.join('/tmp', rand_name(suffix=suffix)) if save_file is None else save_file + if not cache_file.endswith(suffix): + cache_file = osp.splitext(cache_file)[0] + suffix + + # Configure codec parameters + codec_params = _get_codec_params(codec_type, container) + + # Process and save + error = None + for _ in range(retry): + try: + if torch.is_tensor(tensor): + # Preprocess tensor + tensor = tensor.clamp(min(value_range), max(value_range)) + tensor = torch.stack([ + torchvision.utils.make_grid(u, nrow=nrow, normalize=normalize, value_range=value_range) + for u in tensor.unbind(2) + ], dim=1).permute(1, 2, 3, 0) + tensor = (tensor * 255).type(torch.uint8).cpu() + arrays = tensor.numpy() + else: + arrays = tensor + + # Write video (silence ffmpeg logs) + writer = imageio.get_writer(cache_file, fps=fps, ffmpeg_log_level='error', **codec_params) + for frame in arrays: + writer.append_data(frame) + + writer.close() + return cache_file + + except Exception as e: + error = e + print(f"error saving {save_file}: {e}") 
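
A minimal usage sketch (editor's illustration, not part of the patch) of the save_video codec options defined above; the frames list, its resolution and the output file names are placeholder assumptions:

import numpy as np
from shared.utils.audio_video import save_video

# Dummy 16-frame clip of HxWx3 uint8 arrays (the non-tensor code path above)
frames = [np.zeros((480, 832, 3), dtype=np.uint8) for _ in range(16)]

# Default 8-bit H.264 in an .mp4 container
save_video(frames, "sample.mp4", fps=16, codec_type="libx264_8", container="mp4")

# Lossless output: FFV1 when the container is mkv, CRF-0 H.264 (yuv444p) when it is mp4
save_video(frames, "sample_lossless.mkv", fps=16, codec_type="libx264_lossless", container="mkv")

Note that container only selects the file suffix and, for libx264_lossless, which lossless backend is used; the quality trade-off itself is controlled entirely by codec_type.
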
+ + +def _get_codec_params(codec_type, container): + """Get codec parameters based on codec type and container.""" + if codec_type == 'libx264_8': + return {'codec': 'libx264', 'quality': 8, 'pixelformat': 'yuv420p'} + elif codec_type == 'libx264_10': + return {'codec': 'libx264', 'quality': 10, 'pixelformat': 'yuv420p'} + elif codec_type == 'libx265_28': + return {'codec': 'libx265', 'pixelformat': 'yuv420p', 'output_params': ['-crf', '28', '-x265-params', 'log-level=none','-hide_banner', '-nostats']} + elif codec_type == 'libx265_8': + return {'codec': 'libx265', 'pixelformat': 'yuv420p', 'output_params': ['-crf', '8', '-x265-params', 'log-level=none','-hide_banner', '-nostats']} + elif codec_type == 'libx264_lossless': + if container == 'mkv': + return {'codec': 'ffv1', 'pixelformat': 'rgb24'} + else: # mp4 + return {'codec': 'libx264', 'output_params': ['-crf', '0'], 'pixelformat': 'yuv444p'} + else: # libx264 + return {'codec': 'libx264', 'pixelformat': 'yuv420p'} + + + + +def save_image(tensor, + save_file, + nrow=8, + normalize=True, + value_range=(-1, 1), + quality='jpeg_95', # 'jpeg_95', 'jpeg_85', 'jpeg_70', 'jpeg_50', 'webp_95', 'webp_85', 'webp_70', 'webp_50', 'png', 'webp_lossless' + retry=5): + """Save tensor as image with configurable format and quality.""" + + # Get format and quality settings + format_info = _get_format_info(quality) + + # Rename file extension to match requested format + save_file = osp.splitext(save_file)[0] + format_info['ext'] + + # Save image + error = None + for _ in range(retry): + try: + tensor = tensor.clamp(min(value_range), max(value_range)) + + if format_info['use_pil']: + # Use PIL for WebP and advanced options + grid = torchvision.utils.make_grid(tensor, nrow=nrow, normalize=normalize, value_range=value_range) + # Convert to PIL Image + grid = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() + img = Image.fromarray(grid) + img.save(save_file, **format_info['params']) + else: + # Use torchvision for JPEG and PNG + torchvision.utils.save_image( + tensor, save_file, nrow=nrow, normalize=normalize, + value_range=value_range, **format_info['params'] + ) + break + except Exception as e: + error = e + continue + else: + print(f'cache_image failed, error: {error}', flush=True) + + return save_file + + +def _get_format_info(quality): + """Get format extension and parameters.""" + formats = { + # JPEG with PIL (so 'quality' works) + 'jpeg_95': {'ext': '.jpg', 'params': {'quality': 95}, 'use_pil': True}, + 'jpeg_85': {'ext': '.jpg', 'params': {'quality': 85}, 'use_pil': True}, + 'jpeg_70': {'ext': '.jpg', 'params': {'quality': 70}, 'use_pil': True}, + 'jpeg_50': {'ext': '.jpg', 'params': {'quality': 50}, 'use_pil': True}, + + # PNG with torchvision + 'png': {'ext': '.png', 'params': {}, 'use_pil': False}, + + # WebP with PIL (for quality control) + 'webp_95': {'ext': '.webp', 'params': {'quality': 95}, 'use_pil': True}, + 'webp_85': {'ext': '.webp', 'params': {'quality': 85}, 'use_pil': True}, + 'webp_70': {'ext': '.webp', 'params': {'quality': 70}, 'use_pil': True}, + 'webp_50': {'ext': '.webp', 'params': {'quality': 50}, 'use_pil': True}, + 'webp_lossless': {'ext': '.webp', 'params': {'lossless': True}, 'use_pil': True}, + } + return formats.get(quality, formats['jpeg_95']) + + +from PIL import Image, PngImagePlugin + +def _enc_uc(s): + try: return b"ASCII\0\0\0" + s.encode("ascii") + except UnicodeEncodeError: return b"UNICODE\0" + s.encode("utf-16le") + +def _dec_uc(b): + if not isinstance(b, (bytes, 
bytearray)): + try: b = bytes(b) + except Exception: return None + if b.startswith(b"ASCII\0\0\0"): return b[8:].decode("ascii", "ignore") + if b.startswith(b"UNICODE\0"): return b[8:].decode("utf-16le", "ignore") + return b.decode("utf-8", "ignore") + +def save_image_metadata(image_path, metadata_dict, **save_kwargs): + try: + j = json.dumps(metadata_dict, ensure_ascii=False) + ext = os.path.splitext(image_path)[1].lower() + with Image.open(image_path) as im: + if ext == ".png": + pi = PngImagePlugin.PngInfo(); pi.add_text("comment", j) + im.save(image_path, pnginfo=pi, **save_kwargs); return True + if ext in (".jpg", ".jpeg"): + im.save(image_path, comment=j.encode("utf-8"), **save_kwargs); return True + if ext == ".webp": + import piexif + exif = {"0th":{}, "Exif":{piexif.ExifIFD.UserComment:_enc_uc(j)}, "GPS":{}, "1st":{}, "thumbnail":None} + im.save(image_path, format="WEBP", exif=piexif.dump(exif), **save_kwargs); return True + raise ValueError("Unsupported format") + except Exception as e: + print(f"Error saving metadata: {e}"); return False + +def read_image_metadata(image_path): + try: + ext = os.path.splitext(image_path)[1].lower() + with Image.open(image_path) as im: + if ext == ".png": + val = (getattr(im, "text", {}) or {}).get("comment") or im.info.get("comment") + return json.loads(val) if val else None + if ext in (".jpg", ".jpeg"): + val = im.info.get("comment") + if isinstance(val, (bytes, bytearray)): val = val.decode("utf-8", "ignore") + if val: + try: return json.loads(val) + except Exception: pass + exif = getattr(im, "getexif", lambda: None)() + if exif: + uc = exif.get(37510) # UserComment + s = _dec_uc(uc) if uc else None + if s: + try: return json.loads(s) + except Exception: pass + return None + if ext == ".webp": + exif_bytes = Image.open(image_path).info.get("exif") + if not exif_bytes: return None + import piexif + uc = piexif.load(exif_bytes).get("Exif", {}).get(piexif.ExifIFD.UserComment) + s = _dec_uc(uc) if uc else None + return json.loads(s) if s else None + return None + except Exception as e: + print(f"Error reading metadata: {e}"); return None \ No newline at end of file diff --git a/wan/utils/basic_flowmatch.py b/shared/utils/basic_flowmatch.py similarity index 100% rename from wan/utils/basic_flowmatch.py rename to shared/utils/basic_flowmatch.py diff --git a/wan/utils/cammmaster_tools.py b/shared/utils/cammmaster_tools.py similarity index 97% rename from wan/utils/cammmaster_tools.py rename to shared/utils/cammmaster_tools.py index 6e255a0..b93ebba 100644 --- a/wan/utils/cammmaster_tools.py +++ b/shared/utils/cammmaster_tools.py @@ -40,7 +40,7 @@ def get_relative_pose(cam_params): def get_camera_embedding(cam_type, num_frames=81): # load camera - tgt_camera_path = "wan/camera_extrinsics.json" + tgt_camera_path = "models/wan/camera_extrinsics.json" with open(tgt_camera_path, 'r') as file: cam_data = json.load(file) diff --git a/wan/utils/fm_solvers.py b/shared/utils/fm_solvers.py similarity index 100% rename from wan/utils/fm_solvers.py rename to shared/utils/fm_solvers.py diff --git a/wan/utils/fm_solvers_unipc.py b/shared/utils/fm_solvers_unipc.py similarity index 100% rename from wan/utils/fm_solvers_unipc.py rename to shared/utils/fm_solvers_unipc.py diff --git a/wan/utils/loras_mutipliers.py b/shared/utils/loras_mutipliers.py similarity index 99% rename from wan/utils/loras_mutipliers.py rename to shared/utils/loras_mutipliers.py index 8698898..6d2acca 100644 --- a/wan/utils/loras_mutipliers.py +++ b/shared/utils/loras_mutipliers.py @@ -44,7 
+44,7 @@ def parse_loras_multipliers(loras_multipliers, nb_loras, num_inference_steps, me slists_dict["phase2"] = phase2 = [1.] * nb_loras if isinstance(loras_multipliers, list) or len(loras_multipliers) > 0: - list_mult_choices_list = preparse_loras_multipliers(loras_multipliers) + list_mult_choices_list = preparse_loras_multipliers(loras_multipliers)[:nb_loras] for i, mult in enumerate(list_mult_choices_list): current_phase = phase1 if isinstance(mult, str): diff --git a/wan/utils/motion.py b/shared/utils/motion.py similarity index 100% rename from wan/utils/motion.py rename to shared/utils/motion.py diff --git a/wan/utils/notification_sound.py b/shared/utils/notification_sound.py similarity index 99% rename from wan/utils/notification_sound.py rename to shared/utils/notification_sound.py index c9a42a6..64ffd8f 100644 --- a/wan/utils/notification_sound.py +++ b/shared/utils/notification_sound.py @@ -8,7 +8,8 @@ import sys import threading import time import numpy as np - +import os +os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1" def generate_notification_beep(volume=50, sample_rate=44100): """Generate pleasant C major chord notification sound""" diff --git a/wan/utils/prompt_extend.py b/shared/utils/prompt_extend.py similarity index 100% rename from wan/utils/prompt_extend.py rename to shared/utils/prompt_extend.py diff --git a/wan/utils/prompt_parser.py b/shared/utils/prompt_parser.py similarity index 100% rename from wan/utils/prompt_parser.py rename to shared/utils/prompt_parser.py diff --git a/wan/utils/qwen_vl_utils.py b/shared/utils/qwen_vl_utils.py similarity index 100% rename from wan/utils/qwen_vl_utils.py rename to shared/utils/qwen_vl_utils.py diff --git a/wan/utils/stats.py b/shared/utils/stats.py similarity index 100% rename from wan/utils/stats.py rename to shared/utils/stats.py diff --git a/wan/utils/thread_utils.py b/shared/utils/thread_utils.py similarity index 100% rename from wan/utils/thread_utils.py rename to shared/utils/thread_utils.py diff --git a/wan/utils/utils.py b/shared/utils/utils.py similarity index 58% rename from wan/utils/utils.py rename to shared/utils/utils.py index 041fdea..9f347e8 100644 --- a/wan/utils/utils.py +++ b/shared/utils/utils.py @@ -1,6 +1,5 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import argparse -import binascii import os import os.path as osp import torchvision.transforms.functional as TF @@ -10,7 +9,6 @@ import tempfile import imageio import torch import decord -import torchvision from PIL import Image import numpy as np from rembg import remove, new_session @@ -21,8 +19,6 @@ import tempfile import subprocess import json -__all__ = ['cache_video', 'cache_image', 'str2bool'] - from PIL import Image @@ -218,84 +214,6 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg return output_list -def rand_name(length=8, suffix=''): - name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') - if suffix: - if not suffix.startswith('.'): - suffix = '.' 
+ suffix - name += suffix - return name - - -def cache_video(tensor, - save_file=None, - fps=30, - suffix='.mp4', - nrow=8, - normalize=True, - value_range=(-1, 1), - retry=5): - # cache file - cache_file = osp.join('/tmp', rand_name( - suffix=suffix)) if save_file is None else save_file - - # save to cache - error = None - for _ in range(retry): - try: - # preprocess - tensor = tensor.clamp(min(value_range), max(value_range)) - tensor = torch.stack([ - torchvision.utils.make_grid( - u, nrow=nrow, normalize=normalize, value_range=value_range) - for u in tensor.unbind(2) - ], - dim=1).permute(1, 2, 3, 0) - tensor = (tensor * 255).type(torch.uint8).cpu() - - # write video - writer = imageio.get_writer( - cache_file, fps=fps, codec='libx264', quality=8) - for frame in tensor.numpy(): - writer.append_data(frame) - writer.close() - return cache_file - except Exception as e: - error = e - continue - else: - print(f'cache_video failed, error: {error}', flush=True) - return None - - -def cache_image(tensor, - save_file, - nrow=8, - normalize=True, - value_range=(-1, 1), - retry=5): - # cache file - suffix = osp.splitext(save_file)[1] - if suffix.lower() not in [ - '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp' - ]: - suffix = '.png' - - # save to cache - error = None - for _ in range(retry): - try: - tensor = tensor.clamp(min(value_range), max(value_range)) - torchvision.utils.save_image( - tensor, - save_file, - nrow=nrow, - normalize=normalize, - value_range=value_range) - return save_file - except Exception as e: - error = e - continue def str2bool(v): @@ -435,212 +353,3 @@ def create_progress_hook(filename): return hook -import tempfile, os -import ffmpeg - -def extract_audio_tracks(source_video, verbose=False, query_only=False): - """ - Extract all audio tracks from a source video into temporary AAC files. - - Returns: - Tuple: - - List of temp file paths for extracted audio tracks - - List of corresponding metadata dicts: - {'codec', 'sample_rate', 'channels', 'duration', 'language'} - where 'duration' is set to container duration (for consistency). 
- """ - probe = ffmpeg.probe(source_video) - audio_streams = [s for s in probe['streams'] if s['codec_type'] == 'audio'] - container_duration = float(probe['format'].get('duration', 0.0)) - - if not audio_streams: - if query_only: return 0 - if verbose: print(f"No audio track found in {source_video}") - return [], [] - - if query_only: - return len(audio_streams) - - if verbose: - print(f"Found {len(audio_streams)} audio track(s), container duration = {container_duration:.3f}s") - - file_paths = [] - metadata = [] - - for i, stream in enumerate(audio_streams): - fd, temp_path = tempfile.mkstemp(suffix=f'_track{i}.aac', prefix='audio_') - os.close(fd) - - file_paths.append(temp_path) - metadata.append({ - 'codec': stream.get('codec_name'), - 'sample_rate': int(stream.get('sample_rate', 0)), - 'channels': int(stream.get('channels', 0)), - 'duration': container_duration, - 'language': stream.get('tags', {}).get('language', None) - }) - - ffmpeg.input(source_video).output( - temp_path, - **{f'map': f'0:a:{i}', 'acodec': 'aac', 'b:a': '128k'} - ).overwrite_output().run(quiet=not verbose) - - return file_paths, metadata - - -import subprocess - -import subprocess - -def combine_and_concatenate_video_with_audio_tracks( - save_path_tmp, video_path, - source_audio_tracks, new_audio_tracks, - source_audio_duration, audio_sampling_rate, - new_audio_from_start=False, - source_audio_metadata=None, - audio_bitrate='128k', - audio_codec='aac', - verbose = False -): - inputs, filters, maps, idx = ['-i', video_path], [], ['-map', '0:v'], 1 - metadata_args = [] - sources = source_audio_tracks or [] - news = new_audio_tracks or [] - - duplicate_source = len(sources) == 1 and len(news) > 1 - N = len(news) if source_audio_duration == 0 else max(len(sources), len(news)) or 1 - - for i in range(N): - s = (sources[i] if i < len(sources) - else sources[0] if duplicate_source else None) - n = news[i] if len(news) == N else (news[0] if news else None) - - if source_audio_duration == 0: - if n: - inputs += ['-i', n] - filters.append(f'[{idx}:a]apad=pad_dur=100[aout{i}]') - idx += 1 - else: - filters.append(f'anullsrc=r={audio_sampling_rate}:cl=mono,apad=pad_dur=100[aout{i}]') - else: - if s: - inputs += ['-i', s] - meta = source_audio_metadata[i] if source_audio_metadata and i < len(source_audio_metadata) else {} - needs_filter = ( - meta.get('codec') != audio_codec or - meta.get('sample_rate') != audio_sampling_rate or - meta.get('channels') != 1 or - meta.get('duration', 0) < source_audio_duration - ) - if needs_filter: - filters.append( - f'[{idx}:a]aresample={audio_sampling_rate},aformat=channel_layouts=mono,' - f'apad=pad_dur={source_audio_duration},atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}]') - else: - filters.append( - f'[{idx}:a]apad=pad_dur={source_audio_duration},atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}]') - if lang := meta.get('language'): - metadata_args += ['-metadata:s:a:' + str(i), f'language={lang}'] - idx += 1 - else: - filters.append( - f'anullsrc=r={audio_sampling_rate}:cl=mono,atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}]') - - if n: - inputs += ['-i', n] - start = '0' if new_audio_from_start else source_audio_duration - filters.append( - f'[{idx}:a]aresample={audio_sampling_rate},aformat=channel_layouts=mono,' - f'atrim=start={start},asetpts=PTS-STARTPTS[n{i}]') - filters.append(f'[s{i}][n{i}]concat=n=2:v=0:a=1[aout{i}]') - idx += 1 - else: - filters.append(f'[s{i}]apad=pad_dur=100[aout{i}]') - - maps += ['-map', f'[aout{i}]'] - - cmd = ['ffmpeg', 
'-y', *inputs, - '-filter_complex', ';'.join(filters), # ✅ Only change made - *maps, *metadata_args, - '-c:v', 'copy', - '-c:a', audio_codec, - '-b:a', audio_bitrate, - '-ar', str(audio_sampling_rate), - '-ac', '1', - '-shortest', save_path_tmp] - - if verbose: - print(f"ffmpeg command: {cmd}") - try: - subprocess.run(cmd, check=True, capture_output=True, text=True) - except subprocess.CalledProcessError as e: - raise Exception(f"FFmpeg error: {e.stderr}") - - -import ffmpeg - - -import subprocess -import ffmpeg - -def combine_video_with_audio_tracks(target_video, audio_tracks, output_video, - audio_metadata=None, verbose=False): - if not audio_tracks: - if verbose: print("No audio tracks to combine."); return False - - dur = float(next(s for s in ffmpeg.probe(target_video)['streams'] - if s['codec_type'] == 'video')['duration']) - if verbose: print(f"Video duration: {dur:.3f}s") - - cmd = ['ffmpeg', '-y', '-i', target_video] - for path in audio_tracks: - cmd += ['-i', path] - - cmd += ['-map', '0:v'] - for i in range(len(audio_tracks)): - cmd += ['-map', f'{i+1}:a'] - - for i, meta in enumerate(audio_metadata or []): - if (lang := meta.get('language')): - cmd += ['-metadata:s:a:' + str(i), f'language={lang}'] - - cmd += ['-c:v', 'copy', '-c:a', 'copy', '-t', str(dur), output_video] - - result = subprocess.run(cmd, capture_output=not verbose, text=True) - if result.returncode != 0: - raise Exception(f"FFmpeg error:\n{result.stderr}") - if verbose: - print(f"Created {output_video} with {len(audio_tracks)} audio track(s)") - return True - - -def cleanup_temp_audio_files(audio_tracks, verbose=False): - """ - Clean up temporary audio files. - - Args: - audio_tracks: List of audio file paths to delete - verbose: Enable verbose output (default: False) - - Returns: - Number of files successfully deleted - """ - deleted_count = 0 - - for audio_path in audio_tracks: - try: - if os.path.exists(audio_path): - os.unlink(audio_path) - deleted_count += 1 - if verbose: - print(f"Cleaned up {audio_path}") - except PermissionError: - print(f"Warning: Could not delete {audio_path} (file may be in use)") - except Exception as e: - print(f"Warning: Error deleting {audio_path}: {e}") - - if verbose and deleted_count > 0: - print(f"Successfully deleted {deleted_count} temporary audio file(s)") - - return deleted_count - diff --git a/wan/utils/vace_preprocessor.py b/shared/utils/vace_preprocessor.py similarity index 99% rename from wan/utils/vace_preprocessor.py rename to shared/utils/vace_preprocessor.py index 7fdb8c9..947767e 100644 --- a/wan/utils/vace_preprocessor.py +++ b/shared/utils/vace_preprocessor.py @@ -184,7 +184,7 @@ class VaceVideoProcessor(object): def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, canvas_height, canvas_width, h, w, fit_into_canvas, crop_box, rng, max_frames= 0, start_frame =0): - from wan.utils.utils import resample + from shared.utils.utils import resample target_fps = self.max_fps diff --git a/wan/distributed/__init__.py b/wan/distributed/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/wan/trajectory_editor/templates/index.html b/wan/trajectory_editor/templates/index.html deleted file mode 100644 index 2ac8c78..0000000 --- a/wan/trajectory_editor/templates/index.html +++ /dev/null @@ -1,571 +0,0 @@ - - - - - - - Track Point Editor - - - -

[The deleted wan/trajectory_editor/templates/index.html (571 lines of Track Point Editor page markup per the hunk header) is not reproduced here; only whitespace and bare deletion markers survived extraction of this hunk.]
- - - - - diff --git a/wgp.py b/wgp.py index 904b8a9..ce55548 100644 --- a/wgp.py +++ b/wgp.py @@ -13,14 +13,15 @@ from datetime import datetime import gradio as gr import random import json -import wan -from wan.utils import notification_sound -from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS -from wan.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers -from wan.utils.utils import cache_video, convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video -from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, calculate_new_dimensions - -from wan.modules.attention import get_attention_modes, get_supported_attention_modes +import numpy as np +import importlib +from shared.utils import notification_sound +from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers +from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions +from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image +from shared.utils.audio_video import save_image_metadata, read_image_metadata +from shared.match_archi import match_nvidia_architecture +from shared.attention import get_attention_modes, get_supported_attention_modes from huggingface_hub import hf_hub_download, snapshot_download import torch import gc @@ -29,7 +30,7 @@ import math import typing import asyncio import inspect -from wan.utils import prompt_parser +from shared.utils import prompt_parser import base64 import io from PIL import Image @@ -45,13 +46,16 @@ from preprocessing.matanyone import app as matanyone_app from tqdm import tqdm import requests +# import torch._dynamo as dynamo +# dynamo.config.recompile_limit = 2000 # default is 256 +# dynamo.config.accumulated_recompile_limit = 2000 # or whatever limit you want global_queue_ref = [] AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 -target_mmgp_version = "3.5.6" -WanGP_version = "7.61" +target_mmgp_version = "3.5.10" +WanGP_version = "7.7777" settings_version = 2.23 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -118,7 +122,7 @@ def pil_to_base64_uri(pil_image, format="png", quality=75): return None if isinstance(pil_image, str): - from wan.utils.utils import get_video_frame + from shared.utils.utils import get_video_frame pil_image = get_video_frame(pil_image, 0) buffer = io.BytesIO() @@ -178,7 +182,7 @@ def process_prompt_and_add_tasks(state, model_choice): return get_queue_table(queue) model_def = get_model_def(model_type) image_outputs = inputs["image_mode"] == 1 - no_steps_skipping = model_def.get("no_steps_skipping", False) + any_steps_skipping = model_def.get("tea_cache", False) or model_def.get("mag_cache", False) model_type = get_base_model_type(model_type) inputs["model_filename"] = model_filename @@ -302,17 +306,11 @@ def process_prompt_and_add_tasks(state, model_choice): gr.Info(f"Error parsing Loras Multipliers: {errors}") return - if no_steps_skipping: skip_steps_cache_type = "" - if switch_threshold is not None and switch_threshold != 0 and len(skip_steps_cache_type) > 
0: - gr.Info("Steps skipping is not yet supported if Switch Threshold is not null") - return + if not any_steps_skipping: skip_steps_cache_type = "" if not model_def.get("lock_inference_steps", False) and model_type in ["ltxv_13B"] and num_inference_steps < 20: gr.Info("The minimum number of steps should be 20") return if skip_steps_cache_type == "mag": - if model_type in ["sky_df_1.3B", "sky_df_14B"]: - gr.Info("Mag Cache is not supported with Diffusion Forcing") - return if num_inference_steps > 50: gr.Info("Mag Cache maximum number of steps is 50") return @@ -321,7 +319,7 @@ def process_prompt_and_add_tasks(state, model_choice): audio_prompt_type = "" if "B" in audio_prompt_type or "X" in audio_prompt_type: - from wan.multitalk.multitalk import parse_speakers_locations + from models.wan.multitalk.multitalk import parse_speakers_locations speakers_bboxes, error = parse_speakers_locations(speakers_locations) if len(error) > 0: gr.Info(error) @@ -1398,6 +1396,12 @@ def _parse_args(): help="Path to a directory that contains flux images Loras" ) + parser.add_argument( + "--lora-dir-qwen", + type=str, + default="loras_qwen", + help="Path to a directory that contains qwen images Loras" + ) parser.add_argument( "--check-loras", @@ -1507,7 +1511,7 @@ def _parse_args(): "--perc-reserved-mem-max", type=float, default=0, - help="% of RAM allocated to Reserved RAM" + help="percent of RAM allocated to Reserved RAM" ) @@ -1617,7 +1621,8 @@ def _parse_args(): def get_lora_dir(model_type): model_family = get_model_family(model_type) - i2v = test_class_i2v(model_type) and not get_base_model_type(model_type) == "i2v_2_2" + base_model_type = get_base_model_type(model_type) + i2v = test_class_i2v(model_type) and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"] if model_family == "wan": lora_dir =args.lora_dir if i2v and len(lora_dir)==0: @@ -1630,6 +1635,10 @@ def get_lora_dir(model_type): lora_dir_1_3B = os.path.join(root_lora_dir, "1.3B") if os.path.isdir(lora_dir_1_3B ): return lora_dir_1_3B + elif base_model_type == "ti2v_2_2": + lora_dir_5B = os.path.join(root_lora_dir, "5B") + if os.path.isdir(lora_dir_5B ): + return lora_dir_5B else: lora_dir_14B = os.path.join(root_lora_dir, "14B") if os.path.isdir(lora_dir_14B ): @@ -1644,6 +1653,8 @@ def get_lora_dir(model_type): return args.lora_dir_hunyuan_i2v else: return args.lora_dir_hunyuan + elif model_family =="qwen": + return args.lora_dir_qwen else: raise Exception("loras unknown") @@ -1651,8 +1662,8 @@ attention_modes_installed = get_attention_modes() attention_modes_supported = get_supported_attention_modes() args = _parse_args() -major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None) -if major < 8: +gpu_major, gpu_minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None) +if gpu_major < 8: print("Switching to FP16 models when possible as GPU architecture doesn't support optimed BF16 Kernels") bfloat16_supported = False else: @@ -1703,10 +1714,10 @@ if not Path(server_config_filename).is_file(): "transformer_types": [], "transformer_quantization": "int8", "text_encoder_quantization" : "int8", - "save_path": "outputs", #os.path.join(os.getcwd(), + "save_path": "outputs", + "image_save_path": "outputs", "compile" : "", "metadata_type": "metadata", - "default_ui": "t2v", "boost" : 1, "clear_file_list" : 5, "vae_config": 0, @@ -1734,24 +1745,16 @@ for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion print(f"Removing old version of model '{path}'. 
A new version of this model will be downloaded next time you use it.") os.remove( os.path.join("ckpts" , path)) -families_infos = {"wan":(0, "Wan2.1"), "wan2_2":(1, "Wan2.2"), "ltxv":(10, "LTX Video"), "hunyuan":(20, "Hunyuan Video"), "flux":(30, "Flux 1"), "unknown": (100, "Unknown") } +for f, s in [("ckpts/Florence2/modeling_florence2.py", 127287)]: + try: + if os.path.isfile(f) and os.path.getsize(f) == s: + print(f"Removing old version of model '{f}'. A new version of this model will be downloaded next time you use it.") + os.remove(f) + except: pass models_def = {} +family_handlers = ["models.wan.wan_handler", "models.wan.df_handler", "models.hyvideo.hunyuan_handler", "models.ltx_video.ltxv_handler", "models.flux.flux_handler", "models.qwen.qwen_handler"] -modules_files = { - "vace_14B" : ["ckpts/wan2.1_Vace_14B_module_mbf16.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mfp16_int8.safetensors"], - "vace_1.3B" : ["ckpts/wan2.1_Vace_1_3B_module.safetensors"], - "fantasy": ["ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"], - "multitalk": ["ckpts/wan2.1_multitalk_14B_mbf16.safetensors", "ckpts/wan2.1_multitalk_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_multitalk_14B_quanto_mfp16_int8.safetensors"] -} - -# architectures supported -base_types = ["multitalk", "fantasy", "vace_14B", "vace_multitalk_14B", - "t2v_1.3B", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B", - "recam_1.3B", "sky_df_1.3B", "sky_df_14B", - "i2v", "i2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp", "ltxv_13B", - "hunyuan", "hunyuan_i2v", "hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_avatar", "flux" - ] # only needed for imported old settings files model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B", @@ -1762,36 +1765,50 @@ model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", " "hunyuan" : "hunyuan_video_720", "hunyuan_i2v" : "hunyuan_video_i2v_720", "hunyuan_custom" : "hunyuan_video_custom_720", "hunyuan_custom_audio" : "hunyuan_video_custom_audio", "hunyuan_custom_edit" : "hunyuan_video_custom_edit", "hunyuan_avatar" : "hunyuan_video_avatar" } + +def map_family_handlers(family_handlers): + base_types_handlers, families_infos, models_eqv_map, models_comp_map = {}, {"unknown": (100, "Unknown")}, {}, {} + for path in family_handlers: + handler = importlib.import_module(path).family_handler + for model_type in handler.query_supported_types(): + if model_type in base_types_handlers: + prev = base_types_handlers[model_type].__name__ + raise Exception(f"Model type {model_type} supported by {prev} and {handler.__name__}") + base_types_handlers[model_type] = handler + families_infos.update(handler.query_family_infos()) + eq_map, comp_map = handler.query_family_maps() + models_eqv_map.update(eq_map); models_comp_map.update(comp_map) + return base_types_handlers, families_infos, models_eqv_map, models_comp_map + +model_types_handlers, families_infos, models_eqv_map, models_comp_map = map_family_handlers(family_handlers) + def get_base_model_type(model_type): model_def = get_model_def(model_type) if model_def == None: - return model_type if model_type in base_types else None + return model_type if model_type in model_types_handlers else None # return model_type else: return model_def["architecture"] +def get_model_handler(model_type): + base_model_type = get_base_model_type(model_type) + if base_model_type is None: + 
raise Exception(f"Unknown model type {model_type}") + model_handler = model_types_handlers.get(base_model_type, None) + if model_handler is None: + raise Exception(f"No model handler found for base model type {base_model_type}") + return model_handler + def are_model_types_compatible(imported_model_type, current_model_type): imported_base_model_type = get_base_model_type(imported_model_type) curent_base_model_type = get_base_model_type(current_model_type) if imported_base_model_type == curent_base_model_type: return True - eqv_map = { - "flf2v_720p" : "i2v", - "t2v_1.3B" : "t2v", - "sky_df_1.3B" : "sky_df_14B", - } - if imported_base_model_type in eqv_map: - imported_base_model_type = eqv_map[imported_base_model_type] - comp_map = { - "vace_14B" : [ "vace_multitalk_14B"], - "t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B"], - "i2v" : [ "fantasy", "multitalk", "flf2v_720p" ], - "fantasy": ["multitalk"], - "sky_df_14B": ["sky_df_1.3B"], - "hunyuan_custom": ["hunyuan_custom_edit", "hunyuan_custom_audio"], - } - comp_list= comp_map.get(imported_base_model_type, None) + if imported_base_model_type in models_eqv_map: + imported_base_model_type = models_eqv_map[imported_base_model_type] + + comp_list= models_comp_map.get(imported_base_model_type, None) if comp_list == None: return False return curent_base_model_type in comp_list @@ -1817,51 +1834,32 @@ def get_model_family(model_type, for_ui = False): model_family = model_def.get("group", None) if model_family is not None and model_family in families_infos: return model_family - - if "hunyuan" in base_model_type : - return "hunyuan" - elif "ltxv" in base_model_type: - return "ltxv" - elif "flux" in base_model_type: - return "flux" - else: - return "wan" + handler = model_types_handlers.get(base_model_type, None) + if handler is None: + return "unknown" + return handler.query_model_family() -def test_class_i2v(model_type): - model_type = get_base_model_type(model_type) - return model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk" ] #"hunyuan_i2v", +def test_class_i2v(model_type): + model_def = get_model_def(model_type) + return model_def.get("i2v_class", False) def test_vace_module(model_type): - model_type = get_base_model_type(model_type) - return model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B"] + model_def = get_model_def(model_type) + return model_def.get("vace_class", False) def test_any_sliding_window(model_type): - model_type = get_base_model_type(model_type) - return test_vace_module(model_type) or model_type in ["sky_df_1.3B", "sky_df_14B", "ltxv_13B", "multitalk", "t2v", "fantasy"] or test_class_i2v(model_type) + model_def = get_model_def(model_type) + return model_def.get("sliding_window", False) def get_model_min_frames_and_step(model_type): - model_type = get_base_model_type(model_type) - if model_type in ["sky_df_14B"]: - return 17, 20 - elif model_type in ["ltxv_13B"]: - return 17, 8 - elif test_vace_module(model_type): - return 17, 4 - else: - return 5, 4 - + mode_def = get_model_def(model_type) + frames_minimum = mode_def.get("frames_minimum", 5) + frames_steps = mode_def.get("frames_steps", 4) + return frames_minimum, frames_steps + def get_model_fps(model_type): - model_type = get_base_model_type(model_type) - if model_type in ["hunyuan_avatar", "hunyuan_custom_audio", "multitalk", "vace_multitalk_14B"]: - fps = 25 - elif model_type in ["sky_df_14B", "hunyuan", "hunyuan_i2v", "hunyuan_custom_edit", "hunyuan_custom"]: - fps = 
24 - elif model_type in ["fantasy"]: - fps = 23 - elif model_type in ["ltxv_13B"]: - fps = 30 - else: - fps = 16 + mode_def = get_model_def(model_type) + fps= mode_def.get("fps", 16) return fps def get_computed_fps(force_fps, base_model_type , video_guide, video_source ): @@ -1912,10 +1910,13 @@ def get_model_recursive_prop(model_type, prop = "URLs", return_list = True, sta raise Exception(f"Unknown model type '{model_type}'") -def get_model_filename(model_type, quantization ="int8", dtype_policy = "", is_module = False, submodel_no = 1, stack=[]): - if is_module: - choices = modules_files.get(model_type, None) - if choices == None: raise Exception(f"Invalid Module Id '{model_type}'") +def get_model_filename(model_type, quantization ="int8", dtype_policy = "", module_type = None, submodel_no = 1, stack=[]): + if module_type is not None: + base_model_type = get_base_model_type(model_type) + model_type_handler = model_types_handlers[base_model_type] + modules_files = model_type_handler.query_modules_files() if hasattr(model_type_handler, "query_modules_files") else {} + choices = modules_files.get(module_type, None) + if choices == None: raise Exception(f"Invalid Module Id '{module_type}'") else: key_name = "URLs" if submodel_no <= 1 else f"URLs{submodel_no}" @@ -1971,11 +1972,11 @@ def get_settings_file_name(model_type): return os.path.join(args.settings, model_type + "_settings.json") def fix_settings(model_type, ui_defaults): - if model_type == None: return + if model_type is None: return - video_settings_version = ui_defaults.get("settings_version", 0) + settings_version = ui_defaults.get("settings_version", 0) model_def = get_model_def(model_type) - model_type = get_base_model_type(model_type) + base_model_type = get_base_model_type(model_type) prompts = ui_defaults.get("prompts", "") if len(prompts) > 0: @@ -1986,43 +1987,43 @@ def fix_settings(model_type, ui_defaults): image_prompt_type = "S" if image_prompt_type == 0 else "SE" # if model_type == "flf2v_720p" and not "E" in image_prompt_type: # image_prompt_type = "SE" - if video_settings_version <= 2: + if settings_version <= 2: image_prompt_type = image_prompt_type.replace("G","") ui_defaults["image_prompt_type"] = image_prompt_type if "lset_name" in ui_defaults: del ui_defaults["lset_name"] audio_prompt_type = ui_defaults.get("audio_prompt_type", None) - if video_settings_version < 2.2: - if not model_type in ["vace_1.3B","vace_14B", "sky_df_1.3B", "sky_df_14B", "ltxv_13B"]: + if settings_version < 2.2: + if not base_model_type in ["vace_1.3B","vace_14B", "sky_df_1.3B", "sky_df_14B", "ltxv_13B"]: for p in ["sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames"]: if p in ui_defaults: del ui_defaults[p] if audio_prompt_type == None : - if any_audio_track(model_type): + if any_audio_track(base_model_type): audio_prompt_type ="A" ui_defaults["audio_prompt_type"] = audio_prompt_type video_prompt_type = ui_defaults.get("video_prompt_type", "") any_reference_image = model_def.get("reference_image", False) - if model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar", "phantom_14B", "phantom_1.3B"] or any_reference_image: + if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar", "phantom_14B", "phantom_1.3B"] or any_reference_image: if not "I" in video_prompt_type: # workaround for settings corruption video_prompt_type += "I" - if model_type in ["hunyuan"]: + if base_model_type in 
["hunyuan"]: video_prompt_type = video_prompt_type.replace("I", "") - if model_type in ["flux"] and video_settings_version < 2.23: + if base_model_type in ["flux"] and settings_version < 2.23: video_prompt_type = video_prompt_type.replace("K", "").replace("I", "KI") remove_background_images_ref = ui_defaults.get("remove_background_images_ref", 1) - if video_settings_version < 2.22: + if settings_version < 2.22: if "I" in video_prompt_type: if remove_background_images_ref == 2: video_prompt_type = video_prompt_type.replace("I", "KI") if remove_background_images_ref != 0: remove_background_images_ref = 1 - if model_type in ["hunyuan_avatar"]: remove_background_images_ref = 0 + if base_model_type in ["hunyuan_avatar"]: remove_background_images_ref = 0 ui_defaults["remove_background_images_ref"] = remove_background_images_ref ui_defaults["video_prompt_type"] = video_prompt_type @@ -2043,6 +2044,10 @@ def fix_settings(model_type, ui_defaults): del ui_defaults["tea_cache_start_step_perc"] ui_defaults["skip_steps_start_step_perc"] = tea_cache_start_step_perc + model_handler = get_model_handler(base_model_type) + if hasattr(model_handler, "fix_settings"): + model_handler.fix_settings(base_model_type, settings_version, model_def, ui_defaults) + def get_default_settings(model_type): def get_default_prompt(i2v): if i2v: @@ -2063,7 +2068,6 @@ def get_default_settings(model_type): "repeat_generation": 1, "multi_images_gen_type": 0, "guidance_scale": 5.0, - "embedded_guidance_scale" : 6.0, "flow_shift": 7.0 if not "720" in base_model_type and i2v else 5.0, "negative_prompt": "", "activated_loras": [], @@ -2076,87 +2080,8 @@ def get_default_settings(model_type): "slg_start_perc": 10, "slg_end_perc": 90 } - if base_model_type in ["fantasy"]: - ui_defaults["audio_guidance_scale"] = 5.0 - elif base_model_type in ["multitalk"]: - ui_defaults.update({ - "guidance_scale": 5.0, - "flow_shift": 7, # 11 for 720p - "audio_guidance_scale": 4, - "sliding_window_discard_last_frames" : 4, - "sample_solver" : "euler", - "adaptive_switch" : 1, - }) - - elif base_model_type in ["hunyuan","hunyuan_i2v"]: - ui_defaults.update({ - "guidance_scale": 7.0, - }) - - elif base_model_type in ["flux"]: - ui_defaults.update({ - "embedded_guidance": 2.5, - }) - if model_def.get("reference_image", False): - ui_defaults.update({ - "video_prompt_type": "KI", - }) - elif base_model_type in ["sky_df_1.3B", "sky_df_14B"]: - ui_defaults.update({ - "guidance_scale": 6.0, - "flow_shift": 8, - "sliding_window_discard_last_frames" : 0, - "resolution": "1280x720" if "720" in base_model_type else "960x544", - "sliding_window_size" : 121 if "720" in base_model_type else 97, - "RIFLEx_setting": 2, - "guidance_scale": 6, - "flow_shift": 8, - }) - - - elif base_model_type in ["phantom_1.3B", "phantom_14B"]: - ui_defaults.update({ - "guidance_scale": 7.5, - "flow_shift": 5, - "remove_background_images_ref": 1, - "video_prompt_type": "I", - # "resolution": "1280x720" - }) - - elif base_model_type in ["hunyuan_custom"]: - ui_defaults.update({ - "guidance_scale": 7.5, - "flow_shift": 13, - "resolution": "1280x720", - "video_prompt_type": "I", - }) - elif base_model_type in ["hunyuan_custom_audio"]: - ui_defaults.update({ - "guidance_scale": 7.5, - "flow_shift": 13, - "video_prompt_type": "I", - }) - elif base_model_type in ["hunyuan_custom_edit"]: - ui_defaults.update({ - "guidance_scale": 7.5, - "flow_shift": 13, - "video_prompt_type": "MVAI", - "sliding_window_size": 129, - }) - elif base_model_type in ["hunyuan_avatar"]: - ui_defaults.update({ - 
"guidance_scale": 7.5, - "flow_shift": 5, - "remove_background_images_ref": 0, - "skip_steps_start_step_perc": 25, - "video_length": 129, - "video_prompt_type": "I", - }) - elif base_model_type in ["vace_14B", "vace_multitalk_14B"]: - ui_defaults.update({ - "sliding_window_discard_last_frames": 0, - }) - + model_handler = get_model_handler(model_type) + model_handler.update_default_settings(base_model_type, model_def, ui_defaults) ui_defaults_update = model_def.get("settings", None) if ui_defaults_update is not None: ui_defaults.update(ui_defaults_update) @@ -2182,27 +2107,13 @@ def get_default_settings(model_type): ui_defaults["num_inference_steps"] = default_number_steps return ui_defaults -def get_model_query_handler(model_type): - base_model_type = get_base_model_type(model_type) - model_family= get_model_family(base_model_type) - if model_family == "wan": - if base_model_type in ("sky_df_1.3B", "sky_df_14B"): - from wan.diffusion_forcing import query_model_def - else: - from wan.any2video import query_model_def - elif model_family == "hunyuan": - from hyvideo.hunyuan import query_model_def - elif model_family == "ltxv": - from ltx_video.ltxv import query_model_def - elif model_family == "flux": - from flux.flux_main import query_model_def - else: - raise Exception(f"Unknown / unsupported model type {model_type}") - return query_model_def def init_model_def(model_type, model_def): - query_handler = get_model_query_handler(model_type) - default_model_def = query_handler(model_type, model_def) + base_model_type = get_base_model_type(model_type) + family_handler = model_types_handlers.get(base_model_type, None) + if family_handler is None: + raise Exception(f"Unknown model type {model_type}") + default_model_def = family_handler.query_model_def(base_model_type, model_def) if default_model_def is None: return model_def default_model_def.update(model_def) return default_model_def @@ -2282,8 +2193,11 @@ if len(args.vae_config) > 0: vae_config = int(args.vae_config) reload_needed = False -default_ui = server_config.get("default_ui", "t2v") -save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs")) +save_path = server_config.get("save_path", os.path.join(os.getcwd(), "outputs")) +image_save_path = server_config.get("image_save_path", os.path.join(os.getcwd(), "outputs")) +if not "video_output_codec" in server_config: server_config["video_output_codec"]= "libx264_8" +if not "image_output_codec" in server_config: server_config["image_output_codec"]= "jpeg_95" + preload_model_policy = server_config.get("preload_model_policy", []) @@ -2331,7 +2245,7 @@ def save_model(model, model_type, dtype, config_file, submodel_no = 1): model_filename = os.path.basename(url) break if model_filename is None: - print(f"No target filename mentioned in {url_key}") + print(f"No target filename with bf16 or fp16 in its name is mentioned in {url_key}") return if not os.path.isfile(model_filename): offload.save_model(model, os.path.join("ckpts",model_filename), config_file_path=config_file) @@ -2395,27 +2309,6 @@ def get_loras_preprocessor(transformer, model_type): return preprocessor_wrapper -def get_wan_text_encoder_filename(text_encoder_quantization): - text_encoder_filename = "ckpts/umt5-xxl/models_t5_umt5-xxl-enc-bf16.safetensors" - if text_encoder_quantization =="int8": - text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_int8") - return text_encoder_filename - -def get_ltxv_text_encoder_filename(text_encoder_quantization): - text_encoder_filename = 
"ckpts/T5_xxl_1.1/T5_xxl_1.1_enc_bf16.safetensors" - if text_encoder_quantization =="int8": - text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_bf16_int8") - return text_encoder_filename - -def get_hunyuan_text_encoder_filename(text_encoder_quantization): - if text_encoder_quantization =="int8": - text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_quanto_int8.safetensors" - else: - text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_fp16.safetensors" - - return text_encoder_filename - - def process_files_def(repoId, sourceFolderList, fileList): targetRoot = "ckpts/" for sourceFolder, files in zip(sourceFolderList,fileList ): @@ -2440,7 +2333,7 @@ def download_mmaudio(): } process_files_def(**enhancer_def) -def download_models(model_filename, model_type, submodel_no = 1): +def download_models(model_filename, model_type, module_type = None, submodel_no = 1): def computeList(filename): if filename == None: return [] @@ -2451,7 +2344,7 @@ def download_models(model_filename, model_type, submodel_no = 1): from urllib.request import urlretrieve - from wan.utils.utils import create_progress_hook + from shared.utils.utils import create_progress_hook shared_def = { "repoId" : "DeepBeepMeep/Wan2.1", @@ -2495,17 +2388,21 @@ def download_models(model_filename, model_type, submodel_no = 1): else: urlretrieve(url,filename, create_progress_hook(filename)) - model_family = get_model_family(model_type) + base_model_type = get_base_model_type(model_type) model_def = get_model_def(model_type) source = model_def.get("source", None) - - + model_type_handler = model_types_handlers[base_model_type] + key_name = "URLs" if submodel_no <= 1 else f"URLs{submodel_no}" if source is not None: model_filename = None - elif not model_type in modules_files: - if not os.path.isfile(model_filename ): + elif module_type is not None: + modules_files = model_type_handler.query_modules_files() if hasattr(model_type_handler, "query_modules_files") else {} + if module_type not in modules_files: + raise Exception(f"Unknown module {model_type} for model type {model_type}") + else: + if not os.path.isfile(model_filename): URLs = get_model_recursive_prop(model_type, key_name, return_list= False) if isinstance(URLs, str): raise Exception("Missing model " + URLs) @@ -2547,55 +2444,7 @@ def download_models(model_filename, model_type, submodel_no = 1): except Exception as e: if os.path.isfile(filename): os.remove(filename) raise Exception(f"Lora URL '{url}' is invalid: {str(e)}'") - - if model_family == "wan": - text_encoder_filename = get_wan_text_encoder_filename(text_encoder_quantization) - model_files = { - "repoId" : "DeepBeepMeep/Wan2.1", - "sourceFolderList" : ["xlm-roberta-large", "umt5-xxl", "" ], - "fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(model_filename) ] - } - elif model_family == "ltxv": - text_encoder_filename = get_ltxv_text_encoder_filename(text_encoder_quantization) - model_files = { - "repoId" : "DeepBeepMeep/LTX_Video", - "sourceFolderList" : ["T5_xxl_1.1", "" ], - "fileList" : [ ["added_tokens.json", "special_tokens_map.json", "spiece.model", "tokenizer_config.json"] + computeList(text_encoder_filename), 
["ltxv_0.9.7_VAE.safetensors", "ltxv_0.9.7_spatial_upscaler.safetensors", "ltxv_scheduler.json"] + computeList(model_filename) ] - } - elif model_family == "hunyuan": - text_encoder_filename = get_hunyuan_text_encoder_filename(text_encoder_quantization) - model_files = { - "repoId" : "DeepBeepMeep/HunyuanVideo", - "sourceFolderList" : [ "llava-llama-3-8b", "clip_vit_large_patch14", "whisper-tiny" , "det_align", "" ], - "fileList" :[ ["config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "preprocessor_config.json"] + computeList(text_encoder_filename) , - ["config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"], - ["config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json"], - ["detface.pt"], - [ "hunyuan_video_720_quanto_int8_map.json", "hunyuan_video_custom_VAE_fp32.safetensors", "hunyuan_video_custom_VAE_config.json", "hunyuan_video_VAE_fp32.safetensors", "hunyuan_video_VAE_config.json" , "hunyuan_video_720_quanto_int8_map.json" ] + computeList(model_filename) - ] - } - elif model_family == "flux": - text_encoder_filename = get_ltxv_text_encoder_filename(text_encoder_quantization) - model_files = [ - { - "repoId" : "DeepBeepMeep/Flux", - "sourceFolderList" : [""], - "fileList" : [ ["flux_vae.safetensors"] ] - }, - { - "repoId" : "DeepBeepMeep/LTX_Video", - "sourceFolderList" : ["T5_xxl_1.1"], - "fileList" : [ ["added_tokens.json", "special_tokens_map.json", "spiece.model", "tokenizer_config.json"] + computeList(text_encoder_filename) ] - }, - { - "repoId" : "DeepBeepMeep/HunyuanVideo", - "sourceFolderList" : [ "clip_vit_large_patch14", ], - "fileList" :[ - ["config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"], - ] - } - ] - + model_files = model_type_handler.query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization) if not isinstance(model_files, list): model_files = [model_files] for one_repo in model_files: process_files_def(**one_repo) @@ -2690,116 +2539,6 @@ def setup_loras(model_type, transformer, lora_dir, lora_preselected_preset, spl print(error[:200]) return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset - -def load_wan_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): - if test_class_i2v(base_model_type): - cfg = WAN_CONFIGS['i2v-14B'] - else: - cfg = WAN_CONFIGS['t2v-14B'] - # cfg = WAN_CONFIGS['t2v-1.3B'] - if base_model_type in ("sky_df_1.3B", "sky_df_14B"): - model_factory = wan.DTT2V - else: - model_factory = wan.WanAny2V - - wan_model = model_factory( - config=cfg, - checkpoint_dir="ckpts", - model_filename=model_filename, - model_type = model_type, - model_def = model_def, - base_model_type=base_model_type, - text_encoder_filename= get_wan_text_encoder_filename(text_encoder_quantization), - quantizeTransformer = quantizeTransformer, - dtype = dtype, - VAE_dtype = VAE_dtype, - mixed_precision_transformer = mixed_precision_transformer, - save_quantized = save_quantized - ) - - pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model } - if 
hasattr(wan_model,"model2") and wan_model.model2 is not None: - pipe["transformer2"] = wan_model.model2 - if hasattr(wan_model, "clip"): - pipe["text_encoder_2"] = wan_model.clip.model - return wan_model, pipe - -def load_ltxv_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): - from ltx_video.ltxv import LTXV - - ltxv_model = LTXV( - model_filepath = model_filename, - text_encoder_filepath = get_ltxv_text_encoder_filename(text_encoder_quantization), - model_type = model_type, - base_model_type = base_model_type, - model_def = model_def, - dtype = dtype, - # quantizeTransformer = quantizeTransformer, - VAE_dtype = VAE_dtype, - mixed_precision_transformer = mixed_precision_transformer - ) - - pipeline = ltxv_model.pipeline - pipe = {"transformer" : pipeline.video_pipeline.transformer, "vae" : pipeline.vae, "text_encoder" : pipeline.video_pipeline.text_encoder, "latent_upsampler" : pipeline.latent_upsampler} - - return ltxv_model, pipe - - -def load_flux_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): - from flux.flux_main import model_factory - - flux_model = model_factory( - checkpoint_dir="ckpts", - model_filename=model_filename, - model_type = model_type, - model_def = model_def, - base_model_type=base_model_type, - text_encoder_filename= get_ltxv_text_encoder_filename(text_encoder_quantization), - quantizeTransformer = quantizeTransformer, - dtype = dtype, - VAE_dtype = VAE_dtype, - mixed_precision_transformer = mixed_precision_transformer, - save_quantized = save_quantized - ) - - pipe = { "transformer": flux_model.model, "vae" : flux_model.vae, "text_encoder" : flux_model.clip, "text_encoder_2" : flux_model.t5} - - return flux_model, pipe - -def load_hunyuan_model(model_filename, model_type = None, base_model_type = None, model_def = None, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): - from hyvideo.hunyuan import HunyuanVideoSampler - - hunyuan_model = HunyuanVideoSampler.from_pretrained( - model_filepath = model_filename, - model_type = model_type, - base_model_type = base_model_type, - text_encoder_filepath = get_hunyuan_text_encoder_filename(text_encoder_quantization), - dtype = dtype, - quantizeTransformer = quantizeTransformer, - VAE_dtype = VAE_dtype, - mixed_precision_transformer = mixed_precision_transformer, - save_quantized = save_quantized - ) - - pipe = { "transformer" : hunyuan_model.model, "text_encoder" : hunyuan_model.text_encoder, "text_encoder_2" : hunyuan_model.text_encoder_2, "vae" : hunyuan_model.vae } - - if hunyuan_model.wav2vec != None: - pipe["wav2vec"] = hunyuan_model.wav2vec - - - # if hunyuan_model.align_instance != None: - # pipe["align_instance"] = hunyuan_model.align_instance.facedet.model - - - from hyvideo.modules.models import get_linear_split_map - - split_linear_modules_map = get_linear_split_map() - hunyuan_model.model.split_linear_modules_map = split_linear_modules_map - offload.split_linear_modules(hunyuan_model.model, split_linear_modules_map ) - - - return hunyuan_model, pipe - def get_transformer_model(model, submodel_no = 1): if submodel_no > 1: model_key = f"model{submodel_no}" @@ -2848,17 +2587,20 @@ def load_models(model_type): preload = 
server_config.get("preload_in_VRAM", 0) model_file_list = [model_filename] model_type_list = [model_type] + module_type_list = [None] model_submodel_no_list = [1] if model_filename2 != None: model_file_list += [model_filename2] model_type_list += [model_type] + module_type_list += [None] model_submodel_no_list += [2] for module_type in modules: - model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype, is_module= True)) - model_type_list.append(module_type) + model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, module_type= module_type)) + model_type_list.append(model_type) + module_type_list.append(module_type) model_submodel_no_list.append(0) - for filename, file_model_type, submodel_no in zip(model_file_list, model_type_list, model_submodel_no_list): - download_models(filename, file_model_type, submodel_no) + for filename, file_model_type, file_module_type, submodel_no in zip(model_file_list, model_type_list, module_type_list, model_submodel_no_list): + download_models(filename, file_model_type, file_module_type, submodel_no) VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float mixed_precision_transformer = server_config.get("mixed_precision","0") == "1" transformer_type = None @@ -2868,16 +2610,10 @@ def load_models(model_type): else: print(f"Loading Module '{filename}' ...") - if model_family == "wan" : - wan_model, pipe = load_wan_model(model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) - elif model_family == "ltxv": - wan_model, pipe = load_ltxv_model(model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) - elif model_family == "flux": - wan_model, pipe = load_flux_model(model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) - elif model_family == "hunyuan": - wan_model, pipe = load_hunyuan_model(model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) - else: - raise Exception(f"Model '{model_filename}' not supported.") + wan_model, pipe = model_types_handlers[base_model_type].load_model( + model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, text_encoder_quantization = text_encoder_quantization, + dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) + kwargs = { "extraModelsToQuantize": None } loras_transformer = ["transformer"] if profile in (2, 4, 5): @@ -2943,12 +2679,15 @@ def generate_header(model_type, compile, attention_mode): model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) or "" description = description_container[0] header = f"
{description}
" - - header += "
Attention mode " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() ) + overridden_attention = get_overridden_attention(model_type) + attn_mode = attention_mode if overridden_attention == None else overridden_attention + header += "
Attention mode " + (attn_mode if attn_mode!="auto" else "auto/" + get_auto_attention() ) if attention_mode not in attention_modes_installed: header += " -NOT INSTALLED-" elif attention_mode not in attention_modes_supported: header += " -NOT SUPPORTED-" + elif overridden_attention is not None and attention_mode != overridden_attention: + header += " -MODEL SPECIFIC-" header += "" if compile: @@ -2971,6 +2710,7 @@ def apply_changes( state, VAE_precision_choice, mixed_precision_choice, save_path_choice, + image_save_path_choice, attention_choice, compile_choice, profile_choice, @@ -2990,6 +2730,9 @@ def apply_changes( state, notification_sound_volume_choice = 50, max_frames_multiplier_choice = 1, display_stats_choice = 0, + video_output_codec_choice = None, + image_output_codec_choice = None, + audio_output_codec_choice = None, last_resolution_choice = None, ): if args.lock_config: @@ -3002,6 +2745,7 @@ def apply_changes( state, "transformer_types": transformer_types_choices, "text_encoder_quantization" : text_encoder_quantization_choice, "save_path" : save_path_choice, + "image_save_path" : image_save_path_choice, "compile" : compile_choice, "profile" : profile_choice, "vae_config" : vae_config_choice, @@ -3023,6 +2767,9 @@ def apply_changes( state, "notification_sound_volume" : notification_sound_volume_choice, "max_frames_multiplier" : max_frames_multiplier_choice, "display_stats" : display_stats_choice, + "video_output_codec" : video_output_codec_choice, + "image_output_codec" : image_output_codec_choice, + "audio_output_codec" : audio_output_codec_choice, "last_model_type" : state["model_type"], "last_model_per_family": state["last_model_per_family"], "last_advanced_choice": state["advanced"], @@ -3056,6 +2803,7 @@ def apply_changes( state, vae_config = server_config["vae_config"] boost = server_config["boost"] save_path = server_config["save_path"] + image_save_path = server_config["image_save_path"] preload_model_policy = server_config["preload_model_policy"] transformer_quantization = server_config["transformer_quantization"] transformer_dtype_policy = server_config["transformer_dtype_policy"] @@ -3063,7 +2811,9 @@ def apply_changes( state, transformer_types = server_config["transformer_types"] model_filename = get_model_filename(transformer_type, transformer_quantization, transformer_dtype_policy) state["model_filename"] = model_filename - if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant", "notification_sound_enabled", "notification_sound_volume", "mmaudio_enabled", "max_frames_multiplier", "display_stats"] for change in changes ): + if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant", + "notification_sound_enabled", "notification_sound_volume", "mmaudio_enabled", "max_frames_multiplier", "display_stats", + "video_output_codec", "image_output_codec", "audio_output_codec"] for change in changes ): model_family = gr.Dropdown() model_choice = gr.Dropdown() else: @@ -3074,18 +2824,6 @@ def apply_changes( state, mmaudio_enabled = server_config["mmaudio_enabled"] > 0 return "
The new configuration has been successfully applied
", header, model_family, model_choice, gr.Row(visible= server_config["enhancer_enabled"] == 1), gr.Row(visible= mmaudio_enabled), gr.Column(visible= mmaudio_enabled) - - -from moviepy.editor import ImageSequenceClip -import numpy as np - -def save_video(final_frames, output_path, fps=24): - assert final_frames.ndim == 4 and final_frames.shape[3] == 3, f"invalid shape: {final_frames} (need t h w c)" - if final_frames.dtype != np.uint8: - final_frames = (final_frames * 255).astype(np.uint8) - ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False) - - def get_gen_info(state): cache = state.get("gen", None) if cache == None: @@ -3387,6 +3125,7 @@ def select_video(state, input_file_list, event_data: gr.EventData): + [ v for s,v in map_audio_prompt.items() if all_letters(video_audio_prompt_type,s)] video_model_type = configs.get("model_type", "t2v") model_family = get_model_family(video_model_type) + model_def = get_model_def(video_model_type) video_other_prompts = ", ".join(video_other_prompts) video_resolution = configs.get("resolution", "") + f" (real: {width}x{height})" video_length = configs.get("video_length", 0) @@ -3405,8 +3144,8 @@ def select_video(state, input_file_list, event_data: gr.EventData): video_guidance_scale = configs.get("guidance_scale", None) video_guidance2_scale = configs.get("guidance2_scale", None) video_switch_threshold = configs.get("switch_threshold", 0) - video_embedded_guidance_scale = configs.get("embedded_guidance_scale ", None) - if model_family in ["hunyuan", "flux"]: + video_embedded_guidance_scale = configs.get("embedded_guidance_scale", None) + if model_def.get("embedded_guidance", False): video_guidance_scale = video_embedded_guidance_scale video_guidance_label = "Embedded Guidance Scale" else: @@ -3439,8 +3178,8 @@ def select_video(state, input_file_list, event_data: gr.EventData): values += [video_outpainting] labels += ["Outpainting"] video_sample_solver = configs.get("sample_solver", "") - if model_family == "wan": - values += ["unipc" if len(video_sample_solver) ==0 else video_sample_solver] + if model_def.get("sample_solvers", None) is not None and len(video_sample_solver) > 0 : + values += [video_sample_solver] labels += ["Sampler Solver"] values += [video_resolution, video_length_summary, video_seed, video_guidance_scale, video_flow_shift, video_num_inference_steps] labels += [ "Resolution", video_length_label, "Seed", video_guidance_label, "Shift Scale", "Num Inference steps"] @@ -3506,7 +3245,7 @@ def convert_image(image): return cast(Image, ImageOps.exif_transpose(image)) def get_resampled_video(video_in, start_frame, max_frames, target_fps, bridge='torch'): - from wan.utils.utils import resample + from shared.utils.utils import resample import decord decord.bridge.set_bridge(bridge) @@ -3610,7 +3349,7 @@ def process_images_multithread(image_processor, items, process_type, wrap_in_lis return results def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1): - from wan.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions + from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, 
get_outpainting_full_area_dimensions def mask_to_xyxy_box(mask): rows, cols = np.where(mask == 255) @@ -3895,7 +3634,7 @@ def perform_temporal_upsampling(sample, previous_last_frame, temporal_upsampling def perform_spatial_upsampling(sample, spatial_upsampling): - from wan.utils.utils import resize_lanczos + from shared.utils.utils import resize_lanczos if spatial_upsampling == "lanczos1.5": scale = 1.5 else: @@ -3916,7 +3655,7 @@ def perform_spatial_upsampling(sample, spatial_upsampling): def any_audio_track(model_type): base_model_type = get_base_model_type(model_type) - return base_model_type in ["fantasy", "multitalk", "hunyuan_avatar", "hunyuan_custom_audio", "vace_multitalk_14B"] + return base_model_type in ["fantasy", "hunyuan_avatar", "hunyuan_custom_audio"] or get_model_def(model_type).get("multitalk_class", False) def get_available_filename(target_path, video_source, suffix = "", force_extension = None): name, extension = os.path.splitext(os.path.basename(video_source)) @@ -3990,7 +3729,7 @@ def edit_video( seed = set_seed(seed) - from wan.utils.utils import get_video_info + from shared.utils.utils import get_video_info fps, width, height, frames_count = get_video_info(video_source) frames_count = min(frames_count, max_source_video_frames) sample = None @@ -4025,7 +3764,7 @@ def edit_video( any_change = False if sample != None: video_path =get_available_filename(save_path, video_source, "_tmp") if any_mmaudio or has_already_audio else get_available_filename(save_path, video_source, "_post") - cache_video( tensor=sample[None], save_file=video_path, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1)) + save_video( tensor=sample[None], save_file=video_path, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1), codec_type= server_config.get("video_output_codec", None)) if any_mmaudio or has_already_audio: tmp_path = video_path any_change = True @@ -4093,6 +3832,17 @@ def edit_video( cleanup_temp_audio_files(audio_tracks) clear_status(state) +def get_overridden_attention(model_type): + model_def = get_model_def(model_type) + override_attention = model_def.get("attention", None) + if override_attention is None: return None + gpu_version = gpu_major * 10 + gpu_minor + attention_list = match_nvidia_architecture(override_attention, gpu_version) + if len(attention_list ) == 0: return None + override_attention = attention_list[0] + if override_attention is not None and override_attention not in attention_modes_supported: return None + return override_attention + def get_transformer_loras(model_type): model_def = get_model_def(model_type) transformer_loras_filenames = get_model_recursive_prop(model_type, "loras", return_list=True) @@ -4102,6 +3852,36 @@ def get_transformer_loras(model_type): transformer_loras_multipliers = transformer_loras_multipliers[:len(transformer_loras_filenames)] return transformer_loras_filenames, transformer_loras_multipliers +class DynamicClass: + def __init__(self, **kwargs): + self._data = {} + # Preassign default properties from kwargs + for key, value in kwargs.items(): + self._data[key] = value + + def __getattr__(self, name): + if name in self._data: + return self._data[name] + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def __setattr__(self, name, value): + if name.startswith('_'): + super().__setattr__(name, value) + else: + if not hasattr(self, '_data'): + super().__setattr__('_data', {}) + self._data[name] = value + + def assign(self, **kwargs): + """Assign multiple properties at once""" + for key, 
value in kwargs.items(): + self._data[key] = value + return self # For method chaining + + def update(self, dict): + """Alias for assign() - more dict-like""" + return self.assign(**dict) + def generate_video( task, send_cmd, @@ -4184,7 +3964,13 @@ def generate_video( model_filename, mode, ): - + # import os + # os.environ.pop("TORCH_LOGS", None) # make sure no env var is suppressing/overriding + # import torch._logging as tlog + # tlog.set_logs(recompiles=True, guards=True, graph_breaks=True) + + + def remove_temp_filenames(temp_filenames_list): for temp_filename in temp_filenames_list: if temp_filename!= None and os.path.isfile(temp_filename): @@ -4194,7 +3980,7 @@ def generate_video( process_map_video_guide = { "P": "pose", "D" : "depth", "S": "scribble", "E": "canny", "L": "flow", "C": "gray", "M": "inpaint", "U": "identity"} processes_names = { "pose": "Open Pose", "depth": "Depth Mask", "scribble" : "Shapes", "flow" : "Flow Map", "gray" : "Gray Levels", "inpaint" : "Inpaint Mask", "identity": "Identity Mask", "raw" : "Raw Format", "canny" : "Canny Edges"} - global wan_model, offloadobj, reload_needed, save_path + global wan_model, offloadobj, reload_needed gen = get_gen_info(state) torch.set_grad_enabled(False) if mode.startswith("edit_"): @@ -4223,10 +4009,12 @@ def generate_video( temp_filenames_list.append(video_mask) image_mask = None + base_model_type = get_base_model_type(model_type) fit_canvas = server_config.get("fit_canvas", 0) + model_handler = get_model_handler(base_model_type) + block_size = model_handler.get_vae_block_size(base_model_type) if hasattr(model_handler, "get_vae_block_size") else 16 - if "P" in preload_model_policy and not "U" in preload_model_policy: while wan_model == None: time.sleep(1) @@ -4241,19 +4029,18 @@ def generate_video( wan_model, offloadobj = load_models(model_type) send_cmd("status", "Model loaded") reload_needed= False - - if attention_mode == "auto": + overridden_attention = get_overridden_attention(model_type) + # if overridden_attention is not None and overridden_attention != attention_mode: print(f"Attention mode has been overriden to {overridden_attention} for model type '{model_type}'") + attn = overridden_attention if overridden_attention is not None else attention_mode + if attn == "auto": attn = get_auto_attention() - elif attention_mode in attention_modes_supported: - attn = attention_mode - else: + elif not attn in attention_modes_supported: send_cmd("info", f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. 
You should either install it or switch to the default 'sdpa' attention.") send_cmd("exit") return width, height = resolution.split("x") width, height = int(width), int(height) - resolution_reformated = str(height) + "*" + str(width) default_image_size = (height, width) if slg_switch == 0: @@ -4261,12 +4048,14 @@ def generate_video( offload.shared_state["_attention"] = attn device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576 - VAE_tile_size = wan_model.vae.get_VAE_tile_size(vae_config, device_mem_capacity, server_config.get("vae_precision", "16") == "32") + if hasattr(wan_model.vae, "get_VAE_tile_size"): + VAE_tile_size = wan_model.vae.get_VAE_tile_size(vae_config, device_mem_capacity, server_config.get("vae_precision", "16") == "32") + else: + VAE_tile_size = None trans = get_transformer_model(wan_model) trans2 = get_transformer_model(wan_model, 2) audio_sampling_rate = 16000 - base_model_type = get_base_model_type(model_type) prompts = prompt.split("\n") prompts = [part for part in prompts if len(prompt)>0] @@ -4325,11 +4114,11 @@ def generate_video( hunyuan_custom_edit = hunyuan_custom and "edit" in model_filename hunyuan_avatar = "hunyuan_video_avatar" in model_filename fantasy = base_model_type in ["fantasy"] - multitalk = base_model_type in ["multitalk", "vace_multitalk_14B"] + multitalk = model_def.get("multitalk_class", False) flux = base_model_type in ["flux"] if "B" in audio_prompt_type or "X" in audio_prompt_type: - from wan.multitalk.multitalk import parse_speakers_locations + from models.wan.multitalk.multitalk import parse_speakers_locations speakers_bboxes, error = parse_speakers_locations(speakers_locations) else: speakers_bboxes = None @@ -4373,7 +4162,7 @@ def generate_video( for i, pos in enumerate(frames_positions_list): frames_to_inject[pos] = image_refs[i] if video_guide == None and video_source == None and not "L" in image_prompt_type and (nb_frames_positions > 0 or "K" in video_prompt_type) : - from wan.utils.utils import get_outpainting_full_area_dimensions + from shared.utils.utils import get_outpainting_full_area_dimensions w, h = image_refs[0].size if outpainting_dims != None: h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims) @@ -4384,55 +4173,31 @@ def generate_video( if remove_background_images_ref > 0: send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg") - from wan.utils.utils import resize_and_remove_background + from shared.utils.utils import resize_and_remove_background image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (vace or hunyuan_avatar or flux) ) # no fit for vace ref images as it is done later update_task_thumbnails(task, locals()) send_cmd("output") joint_pass = boost ==1 #and profile != 1 and profile != 3 - trans.enable_cache = None if len(skip_steps_cache_type) == 0 else skip_steps_cache_type - if trans2 is not None: - trans2.enable_cache = None + + skip_steps_cache = None if len(skip_steps_cache_type) == 0 else DynamicClass(cache_type = skip_steps_cache_type) - if trans.enable_cache != None: - trans.cache_multiplier = skip_steps_multiplier - trans.cache_start_step = int(skip_steps_start_step_perc*num_inference_steps/100) - - if trans.enable_cache == "mag": - trans.magcache_thresh = 0 - trans.magcache_K = 2 - def_mag_ratios = 
model_def.get("magcache_ratios", None) if model_def != None else None - if def_mag_ratios != None: - trans.def_mag_ratios = def_mag_ratios - elif get_model_family(model_type) == "wan": - if i2v: - trans.def_mag_ratios = np.array([1.0]*2+[1.0124, 1.02213, 1.00166, 1.0041, 0.99791, 1.00061, 0.99682, 0.99762, 0.99634, 0.99685, 0.99567, 0.99586, 0.99416, 0.99422, 0.99578, 0.99575, 0.9957, 0.99563, 0.99511, 0.99506, 0.99535, 0.99531, 0.99552, 0.99549, 0.99541, 0.99539, 0.9954, 0.99536, 0.99489, 0.99485, 0.99518, 0.99514, 0.99484, 0.99478, 0.99481, 0.99479, 0.99415, 0.99413, 0.99419, 0.99416, 0.99396, 0.99393, 0.99388, 0.99386, 0.99349, 0.99349, 0.99309, 0.99304, 0.9927, 0.9927, 0.99228, 0.99226, 0.99171, 0.9917, 0.99137, 0.99135, 0.99068, 0.99063, 0.99005, 0.99003, 0.98944, 0.98942, 0.98849, 0.98849, 0.98758, 0.98757, 0.98644, 0.98643, 0.98504, 0.98503, 0.9836, 0.98359, 0.98202, 0.98201, 0.97977, 0.97978, 0.97717, 0.97718, 0.9741, 0.97411, 0.97003, 0.97002, 0.96538, 0.96541, 0.9593, 0.95933, 0.95086, 0.95089, 0.94013, 0.94019, 0.92402, 0.92414, 0.90241, 0.9026, 0.86821, 0.86868, 0.81838, 0.81939])#**(0.5)# In our papaer, we utilize the sqrt to smooth the ratio, which has little impact on the performance and can be deleted. - else: - trans.def_mag_ratios = np.array([1.0]*2+[1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189]) + if skip_steps_cache != None: + skip_steps_cache.update({ + "multiplier" : skip_steps_multiplier, + "start_step": int(skip_steps_start_step_perc*num_inference_steps/100) + }) + model_handler.set_cache_parameters(skip_steps_cache_type, base_model_type, model_def, locals(), skip_steps_cache) + if skip_steps_cache_type == "mag": + def_mag_ratios = model_def.get("magcache_ratios", None) if model_def != None else None + if def_mag_ratios is not None: skip_steps_cache.def_mag_ratios = def_mag_ratios + elif skip_steps_cache_type == "tea": + def_tea_coefficients = model_def.get("teacache_coefficients", None) if model_def != None else None + if def_tea_coefficients is not None: skip_steps_cache.coefficients = def_tea_coefficients else: - if width * height >= 1280* 720: - trans.def_mag_ratios = np.array([1.0]+[1.0754, 1.27807, 1.11596, 1.09504, 1.05188, 1.00844, 1.05779, 1.00657, 1.04142, 1.03101, 1.00679, 1.02556, 1.00908, 1.06949, 1.05438, 1.02214, 1.02321, 1.03019, 1.00779, 1.03381, 1.01886, 1.01161, 1.02968, 1.00544, 1.02822, 1.00689, 1.02119, 1.0105, 1.01044, 1.01572, 1.02972, 1.0094, 1.02368, 1.0226, 0.98965, 1.01588, 1.02146, 1.0018, 1.01687, 0.99436, 1.00283, 1.01139, 0.97122, 0.98251, 0.94513, 0.97656, 0.90943, 0.85703, 0.75456]) - else: - trans.def_mag_ratios = np.array([1.0]+[1.06971, 1.29073, 1.11245, 1.09596, 1.05233, 1.01415, 1.05672, 1.00848, 1.03632, 1.02974, 
1.00984, 1.03028, 1.00681, 1.06614, 1.05022, 1.02592, 1.01776, 1.02985, 1.00726, 1.03727, 1.01502, 1.00992, 1.03371, 0.9976, 1.02742, 1.0093, 1.01869, 1.00815, 1.01461, 1.01152, 1.03082, 1.0061, 1.02162, 1.01999, 0.99063, 1.01186, 1.0217, 0.99947, 1.01711, 0.9904, 1.00258, 1.00878, 0.97039, 0.97686, 0.94315, 0.97728, 0.91154, 0.86139, 0.76592]) + raise Exception(f"unknown cache type {skip_steps_cache_type}") + trans.cache = skip_steps_cache + if trans2 is not None: trans2.cache = skip_steps_cache - elif trans.enable_cache == "tea": - trans.rel_l1_thresh = 0 - model_def = get_model_def(model_type) - def_tea_coefficients = model_def.get("teacache_coefficients", None) if model_def != None else None - if def_tea_coefficients != None: - trans.coefficients = def_tea_coefficients - elif get_model_family(model_type) == "wan": - if i2v: - if '720p' in model_filename: - trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] - else: - trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01] - else: - if '1.3B' in model_filename: - trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01] - elif '14B' in model_filename: - trans.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] - else: - raise gr.Error("Teacache not supported for this model") output_new_audio_data = None output_new_audio_filepath = None original_audio_guide = audio_guide @@ -4441,7 +4206,7 @@ def generate_video( audio_scale = None audio_context_lens = None if (fantasy or multitalk or hunyuan_avatar or hunyuan_custom_audio) and audio_guide != None: - from wan.fantasytalking.infer import parse_audio + from models.wan.fantasytalking.infer import parse_audio import librosa duration = librosa.get_duration(path=audio_guide) combination_type = "add" @@ -4465,7 +4230,7 @@ def generate_video( # audio_proj_split_full, audio_context_lens_full = parse_audio(audio_guide, num_frames= max_source_video_frames, fps= fps, padded_frames_for_embeddings= (reuse_frames if reset_control_aligment else 0), device= processing_device ) audio_scale = 1.0 elif multitalk: - from wan.multitalk.multitalk import get_full_audio_embeddings + from models.wan.multitalk.multitalk import get_full_audio_embeddings # pad audio_proj_full if aligned to beginning of window to simulate source window overlap audio_proj_full, output_new_audio_data = get_full_audio_embeddings(audio_guide1 = audio_guide, audio_guide2= audio_guide2, combination_type= combination_type , num_frames= max_source_video_frames, sr= audio_sampling_rate, fps =fps, padded_frames_for_embeddings = (reuse_frames if reset_control_aligment else 0)) if output_new_audio_filepath is not None: output_new_audio_data = None @@ -4483,6 +4248,7 @@ def generate_video( torch.set_grad_enabled(False) os.makedirs(save_path, exist_ok=True) + os.makedirs(image_save_path, exist_ok=True) gc.collect() torch.cuda.empty_cache() wan_model._interrupt = False @@ -4538,7 +4304,7 @@ def generate_video( if prompt_enhancer_image_caption_model != None and prompt_enhancer !=None and len(prompt_enhancer)>0: text_encoder_max_tokens = 256 send_cmd("progress", [0, get_latest_status(state, "Enhancing Prompt")]) - from ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt + from models.ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt prompt_images = [] if "I" in prompt_enhancer: if image_start != None: @@ -4548,7 +4314,7 @@ def generate_video( if 
len(original_prompts) == 0 and not "T" in prompt_enhancer: pass else: - from wan.utils.utils import seed_everything + from shared.utils.utils import seed_everything seed_everything(seed) # for i, original_prompt in enumerate(original_prompts): prompts = generate_cinematic_prompt( @@ -4607,9 +4373,9 @@ def generate_video( image_end_tensor = torch.from_numpy(np.array(image_end_tensor).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0) else: if "L" in image_prompt_type: - from wan.utils.utils import get_video_frame + from shared.utils.utils import get_video_frame refresh_preview["video_source"] = get_video_frame(video_source, 0) - prefix_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, target_fps = fps, block_size = 32 if ltxv else 16) + prefix_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, target_fps = fps, block_size = block_size ) prefix_video = prefix_video.permute(3, 0, 1, 2) prefix_video = prefix_video.float().div_(127.5).sub_(1.) # c, f, h, w pre_video_guide = prefix_video[:, -reuse_frames:] @@ -4628,7 +4394,7 @@ def generate_video( if fantasy: audio_proj_split , audio_context_lens = parse_audio(audio_guide, start_frame = aligned_window_start_frame, num_frames= current_video_length, fps= fps, device= processing_device ) if multitalk: - from wan.multitalk.multitalk import get_window_audio_embeddings + from models.wan.multitalk.multitalk import get_window_audio_embeddings # special treatment for start frame pos when alignement to first frame requested as otherwise the start frame number will be negative due to overlapped frames (has been previously compensated later with padding) audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length) @@ -4643,7 +4409,7 @@ def generate_video( status_info = "Extracting " + processes_names[preprocess_type] send_cmd("progress", [0, get_latest_status(state, status_info)]) # start one frame ealier to faciliate latents merging later - src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =32 ) + src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size ) if src_video != None: src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ] refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) 
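
Editor's note, not part of the patch: the hunk above trims the preprocessed guide video with `src_video[:(len(src_video)-1)//latent_size*latent_size + 1]`. A minimal sketch of why that particular count is used follows. It assumes a causal video VAE in the Wan/Hunyuan style, which encodes the first frame on its own and every further group of `latent_size` frames (typically 4) as one latent, so only frame counts of the form 1 + k * latent_size map cleanly onto the latent grid; `latent_size = 4` here is an illustrative assumption, not a value taken from the patch.

import torch

def trim_to_latent_grid(frames: torch.Tensor, latent_size: int = 4) -> torch.Tensor:
    # Keep the first frame plus whole groups of `latent_size` frames,
    # mirroring the "(n - 1) // latent_size * latent_size + 1" trim above.
    usable = (frames.shape[0] - 1) // latent_size * latent_size + 1
    return frames[:usable]

video = torch.zeros(30, 480, 832, 3)        # 30 dummy frames, (f, h, w, c) like src_video
print(trim_to_latent_grid(video).shape[0])  # -> 29, i.e. 1 + 7 * 4
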
@@ -4754,12 +4520,13 @@ def generate_video( progress_args = [0, merge_status_context(status, "Encoding Prompt")] send_cmd("progress", progress_args) - if trans.enable_cache != None: - trans.num_steps = num_inference_steps - trans.cache_skipped_steps = 0 - trans.previous_residual = None - trans.previous_modulated_input = None - + if skip_steps_cache != None: + skip_steps_cache.update({ + "num_steps" : num_inference_steps, + "skipped_steps" : 0, + "previous_residual": None, + "previous_modulated_input": None, + }) # samples = torch.empty( (1,2)) #for testing # if False: @@ -4824,6 +4591,7 @@ def generate_video( speakers_bboxes =speakers_bboxes, image_mode = image_mode, video_prompt_type= video_prompt_type, + window_no = window_no, offloadobj = offloadobj, ) except Exception as e: @@ -4831,8 +4599,12 @@ def generate_video( cleanup_temp_audio_files(control_audio_tracks + source_audio_tracks) remove_temp_filenames(temp_filenames_list) offloadobj.unload_all() + trans.cache = None offload.unload_loras_from_model(trans) - if trans is not None: offload.unload_loras_from_model(trans) + if trans2 is not None: + trans2.cache = None + offload.unload_loras_from_model(trans2) + skip_steps_cache = None # if compile: # cache_size = torch._dynamo.config.cache_size_limit # torch.compiler.reset() @@ -4859,12 +4631,11 @@ def generate_video( send_cmd("error", new_error) clear_status(state) return - finally: - trans.previous_residual = None - trans.previous_modulated_input = None - if trans.enable_cache != None : - print(f"Skipped Steps:{trans.cache_skipped_steps}/{trans.num_steps}" ) + if skip_steps_cache != None : + skip_steps_cache.previous_residual = None + skip_steps_cache.previous_modulated_input = None + print(f"Skipped Steps:{skip_steps_cache.skipped_steps}/{skip_steps_cache.num_steps}" ) if samples != None: if isinstance(samples, dict): @@ -4938,7 +4709,7 @@ def generate_video( time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss") save_prompt = original_prompts[0] - from wan.utils.utils import truncate_for_filesystem + from shared.utils.utils import truncate_for_filesystem extension = "jpg" if is_image else "mp4" if os.name == 'nt': @@ -4948,18 +4719,19 @@ def generate_video( video_path = os.path.join(save_path, file_name) any_mmaudio = MMAudio_setting != 0 and server_config.get("mmaudio_enabled", 0) != 0 and sample.shape[1] >=fps - if is_image: - sample = sample.permute(1,2,3,0) #c f h w -> f h w c - new_video_path = [] + if is_image: + image_path = os.path.join(image_save_path, file_name) + sample = sample.transpose(1,0) #c f h w -> f c h w + new_image_path = [] for no, img in enumerate(sample): - img = Image.fromarray((127.5 * (img + 1.0)).cpu().byte().numpy()) - img_path = os.path.splitext(video_path)[0] + ("" if no==0 else f"_{no}") + ".jpg" - new_video_path.append(img_path) - img.save(img_path) - video_path= new_video_path + img_path = os.path.splitext(image_path)[0] + ("" if no==0 else f"_{no}") + ".jpg" + new_image_path.append(save_image(img, save_file = img_path, quality = server_config.get("image_output_codec", None))) + + video_path= new_image_path elif len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0 or output_new_audio_filepath is not None or any_mmaudio or output_new_audio_data is not None or audio_source is not None: + video_path = os.path.join(save_path, file_name) save_path_tmp = video_path[:-4] + "_tmp.mp4" - cache_video( tensor=sample[None], save_file=save_path_tmp, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1)) + save_video( 
tensor=sample[None], save_file=save_path_tmp, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1), codec_type = server_config.get("video_output_codec", None)) output_new_audio_temp_filepath = None new_audio_from_start = reset_control_aligment source_audio_duration = source_video_frames_count / fps @@ -4986,7 +4758,7 @@ def generate_video( if output_new_audio_temp_filepath is not None: os.remove(output_new_audio_temp_filepath) else: - cache_video( tensor=sample[None], save_file=video_path, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1)) + save_video( tensor=sample[None], save_file=video_path, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1), codec_type= server_config.get("video_output_codec", None)) end_time = time.time() @@ -4996,6 +4768,11 @@ def generate_video( inputs.pop("mode") inputs["model_type"] = model_type inputs["model_filename"] = original_filename + if is_image: + inputs["image_quality"] = server_config.get("image_output_codec", None) + else: + inputs["video_quality"] = server_config.get("video_output_codec", None) + modules = get_model_recursive_prop(model_type, "modules", return_list= True) if len(modules) > 0 : inputs["modules"] = modules if len(transformer_loras_filenames) > 0: @@ -5018,8 +4795,7 @@ def generate_video( json.dump(configs, f, indent=4) elif metadata_choice == "metadata": if is_image: - with Image.open(path) as img: - img.save(path, comment=json.dumps(configs)) + save_image_metadata(path, configs) else: from mutagen.mp4 import MP4 file = MP4(path) @@ -5048,8 +4824,10 @@ def generate_video( seed = set_seed(-1) clear_status(state) + trans.cache = None offload.unload_loras_from_model(trans) if not trans2 is None: + trans2.cache = None offload.unload_loras_from_model(trans2) if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0: @@ -5064,222 +4842,17 @@ def prepare_generate_video(state): else: return gr.Button(visible= False), gr.Button(visible= True), gr.Column(visible= True), gr.update(visible= False) -def generate_preview(latents): + +def generate_preview(model_type, latents): import einops - # thanks Comfyui for the rgb factors - model_family = get_model_family(transformer_type) - if model_family == "wan": - latent_channels = 16 - latent_dimensions = 3 - latent_rgb_factors = [ - [-0.1299, -0.1692, 0.2932], - [ 0.0671, 0.0406, 0.0442], - [ 0.3568, 0.2548, 0.1747], - [ 0.0372, 0.2344, 0.1420], - [ 0.0313, 0.0189, -0.0328], - [ 0.0296, -0.0956, -0.0665], - [-0.3477, -0.4059, -0.2925], - [ 0.0166, 0.1902, 0.1975], - [-0.0412, 0.0267, -0.1364], - [-0.1293, 0.0740, 0.1636], - [ 0.0680, 0.3019, 0.1128], - [ 0.0032, 0.0581, 0.0639], - [-0.1251, 0.0927, 0.1699], - [ 0.0060, -0.0633, 0.0005], - [ 0.3477, 0.2275, 0.2950], - [ 0.1984, 0.0913, 0.1861] - ] - - # credits for the rgb factors to ComfyUI ? 
- - latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360] - - # latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761] - elif model_family =="flux": - scale_factor = 0.3611 - shift_factor = 0.1159 - latent_rgb_factors =[ - [-0.0346, 0.0244, 0.0681], - [ 0.0034, 0.0210, 0.0687], - [ 0.0275, -0.0668, -0.0433], - [-0.0174, 0.0160, 0.0617], - [ 0.0859, 0.0721, 0.0329], - [ 0.0004, 0.0383, 0.0115], - [ 0.0405, 0.0861, 0.0915], - [-0.0236, -0.0185, -0.0259], - [-0.0245, 0.0250, 0.1180], - [ 0.1008, 0.0755, -0.0421], - [-0.0515, 0.0201, 0.0011], - [ 0.0428, -0.0012, -0.0036], - [ 0.0817, 0.0765, 0.0749], - [-0.1264, -0.0522, -0.1103], - [-0.0280, -0.0881, -0.0499], - [-0.1262, -0.0982, -0.0778] - ] - latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851] - - elif model_family == "ltxv": - latent_channels = 128 - latent_dimensions = 3 - - latent_rgb_factors = [ - [ 1.1202e-02, -6.3815e-04, -1.0021e-02], - [ 8.6031e-02, 6.5813e-02, 9.5409e-04], - [-1.2576e-02, -7.5734e-03, -4.0528e-03], - [ 9.4063e-03, -2.1688e-03, 2.6093e-03], - [ 3.7636e-03, 1.2765e-02, 9.1548e-03], - [ 2.1024e-02, -5.2973e-03, 3.4373e-03], - [-8.8896e-03, -1.9703e-02, -1.8761e-02], - [-1.3160e-02, -1.0523e-02, 1.9709e-03], - [-1.5152e-03, -6.9891e-03, -7.5810e-03], - [-1.7247e-03, 4.6560e-04, -3.3839e-03], - [ 1.3617e-02, 4.7077e-03, -2.0045e-03], - [ 1.0256e-02, 7.7318e-03, 1.3948e-02], - [-1.6108e-02, -6.2151e-03, 1.1561e-03], - [ 7.3407e-03, 1.5628e-02, 4.4865e-04], - [ 9.5357e-04, -2.9518e-03, -1.4760e-02], - [ 1.9143e-02, 1.0868e-02, 1.2264e-02], - [ 4.4575e-03, 3.6682e-05, -6.8508e-03], - [-4.5681e-04, 3.2570e-03, 7.7929e-03], - [ 3.3902e-02, 3.3405e-02, 3.7454e-02], - [-2.3001e-02, -2.4877e-03, -3.1033e-03], - [ 5.0265e-02, 3.8841e-02, 3.3539e-02], - [-4.1018e-03, -1.1095e-03, 1.5859e-03], - [-1.2689e-01, -1.3107e-01, -2.1005e-01], - [ 2.6276e-02, 1.4189e-02, -3.5963e-03], - [-4.8679e-03, 8.8486e-03, 7.8029e-03], - [-1.6610e-03, -4.8597e-03, -5.2060e-03], - [-2.1010e-03, 2.3610e-03, 9.3796e-03], - [-2.2482e-02, -2.1305e-02, -1.5087e-02], - [-1.5753e-02, -1.0646e-02, -6.5083e-03], - [-4.6975e-03, 5.0288e-03, -6.7390e-03], - [ 1.1951e-02, 2.0712e-02, 1.6191e-02], - [-6.3704e-03, -8.4827e-03, -9.5483e-03], - [ 7.2610e-03, -9.9326e-03, -2.2978e-02], - [-9.1904e-04, 6.2882e-03, 9.5720e-03], - [-3.7178e-02, -3.7123e-02, -5.6713e-02], - [-1.3373e-01, -1.0720e-01, -5.3801e-02], - [-5.3702e-03, 8.1256e-03, 8.8397e-03], - [-1.5247e-01, -2.1437e-01, -2.1843e-01], - [ 3.1441e-02, 7.0335e-03, -9.7541e-03], - [ 2.1528e-03, -8.9817e-03, -2.1023e-02], - [ 3.8461e-03, -5.8957e-03, -1.5014e-02], - [-4.3470e-03, -1.2940e-02, -1.5972e-02], - [-5.4781e-03, -1.0842e-02, -3.0204e-03], - [-6.5347e-03, 3.0806e-03, -1.0163e-02], - [-5.0414e-03, -7.1503e-03, -8.9686e-04], - [-8.5851e-03, -2.4351e-03, 1.0674e-03], - [-9.0016e-03, -9.6493e-03, 1.5692e-03], - [ 5.0914e-03, 1.2099e-02, 1.9968e-02], - [ 1.3758e-02, 1.1669e-02, 8.1958e-03], - [-1.0518e-02, -1.1575e-02, -4.1307e-03], - [-2.8410e-02, -3.1266e-02, -2.2149e-02], - [ 2.9336e-03, 3.6511e-02, 1.8717e-02], - [-1.6703e-02, -1.6696e-02, -4.4529e-03], - [ 4.8818e-02, 4.0063e-02, 8.7410e-03], - [-1.5066e-02, -5.7328e-04, 2.9785e-03], - [-1.7613e-02, -8.1034e-03, 1.3086e-02], - [-9.2633e-03, 1.0803e-02, -6.3489e-03], - [ 3.0851e-03, 4.7750e-04, 1.2347e-02], - [-2.2785e-02, -2.3043e-02, -2.6005e-02], - [-2.4787e-02, -1.5389e-02, -2.2104e-02], - [-2.3572e-02, 1.0544e-03, 1.2361e-02], - [-7.8915e-03, -1.2271e-03, -6.0968e-03], - [-1.1478e-02, -1.2543e-03, 6.2679e-03], - [-5.4229e-02, 2.6644e-02, 
6.3394e-03], - [ 4.4216e-03, -7.3338e-03, -1.0464e-02], - [-4.5013e-03, 1.6082e-03, 1.4420e-02], - [ 1.3673e-02, 8.8877e-03, 4.1253e-03], - [-1.0145e-02, 9.0072e-03, 1.5695e-02], - [-5.6234e-03, 1.1847e-03, 8.1261e-03], - [-3.7171e-03, -5.3538e-03, 1.2590e-03], - [ 2.9476e-02, 2.1424e-02, 3.0424e-02], - [-3.4925e-02, -2.4340e-02, -2.5316e-02], - [-3.4127e-02, -2.2406e-02, -1.0589e-02], - [-1.7342e-02, -1.3249e-02, -1.0719e-02], - [-2.1478e-03, -8.6051e-03, -2.9878e-03], - [ 1.2089e-03, -4.2391e-03, -6.8569e-03], - [ 9.0411e-04, -6.6886e-03, -6.7547e-05], - [ 1.6048e-02, -1.0057e-02, -2.8929e-02], - [ 1.2290e-03, 1.0163e-02, 1.8861e-02], - [ 1.7264e-02, 2.7257e-04, 1.3785e-02], - [-1.3482e-02, -3.6427e-03, 6.7481e-04], - [ 4.6782e-03, -5.2423e-03, 2.4467e-03], - [-5.9113e-03, -6.2244e-03, -1.8162e-03], - [ 1.5496e-02, 1.4582e-02, 1.9514e-03], - [ 7.4958e-03, 1.5886e-03, -8.2305e-03], - [ 1.9086e-02, 1.6360e-03, -3.9674e-03], - [-5.7021e-03, -2.7307e-03, -4.1066e-03], - [ 1.7450e-03, 1.4602e-02, 2.5794e-02], - [-8.2788e-04, 2.2902e-03, 4.5161e-03], - [ 1.1632e-02, 8.9193e-03, -7.2813e-03], - [ 7.5721e-03, 2.6784e-03, 1.1393e-02], - [ 5.1939e-03, 3.6903e-03, 1.4049e-02], - [-1.8383e-02, -2.2529e-02, -2.4477e-02], - [ 5.8842e-04, -5.7874e-03, -1.4770e-02], - [-1.6125e-02, -8.6101e-03, -1.4533e-02], - [ 2.0540e-02, 2.0729e-02, 6.4338e-03], - [ 3.3587e-03, -1.1226e-02, -1.6444e-02], - [-1.4742e-03, -1.0489e-02, 1.7097e-03], - [ 2.8130e-02, 2.3546e-02, 3.2791e-02], - [-1.8532e-02, -1.2842e-02, -8.7756e-03], - [-8.0533e-03, -1.0771e-02, -1.7536e-02], - [-3.9009e-03, 1.6150e-02, 3.3359e-02], - [-7.4554e-03, -1.4154e-02, -6.1910e-03], - [ 3.4734e-03, -1.1370e-02, -1.0581e-02], - [ 1.1476e-02, 3.9281e-03, 2.8231e-03], - [ 7.1639e-03, -1.4741e-03, -3.8066e-03], - [ 2.2250e-03, -8.7552e-03, -9.5719e-03], - [ 2.4146e-02, 2.1696e-02, 2.8056e-02], - [-5.4365e-03, -2.4291e-02, -1.7802e-02], - [ 7.4263e-03, 1.0510e-02, 1.2705e-02], - [ 6.2669e-03, 6.2658e-03, 1.9211e-02], - [ 1.6378e-02, 9.4933e-03, 6.6971e-03], - [ 1.7173e-02, 2.3601e-02, 2.3296e-02], - [-1.4568e-02, -9.8279e-03, -1.1556e-02], - [ 1.4431e-02, 1.4430e-02, 6.6362e-03], - [-6.8230e-03, 1.8863e-02, 1.4555e-02], - [ 6.1156e-03, 3.4700e-03, -2.6662e-03], - [-2.6983e-03, -5.9402e-03, -9.2276e-03], - [ 1.0235e-02, 7.4173e-03, -7.6243e-03], - [-1.3255e-02, 1.9322e-02, -9.2153e-04], - [ 2.4222e-03, -4.8039e-03, -1.5759e-02], - [ 2.6244e-02, 2.5951e-02, 2.0249e-02], - [ 1.5711e-02, 1.8498e-02, 2.7407e-03], - [-2.1714e-03, 4.7214e-03, -2.2443e-02], - [-7.4747e-03, 7.4166e-03, 1.4430e-02], - [-8.3906e-03, -7.9776e-03, 9.7927e-03], - [ 3.8321e-02, 9.6622e-03, -1.9268e-02], - [-1.4605e-02, -6.7032e-03, 3.9675e-03] - ] - latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512] - - elif model_family == "hunyuan": - latent_channels = 16 - latent_dimensions = 3 - scale_factor = 0.476986 - latent_rgb_factors = [ - [-0.0395, -0.0331, 0.0445], - [ 0.0696, 0.0795, 0.0518], - [ 0.0135, -0.0945, -0.0282], - [ 0.0108, -0.0250, -0.0765], - [-0.0209, 0.0032, 0.0224], - [-0.0804, -0.0254, -0.0639], - [-0.0991, 0.0271, -0.0669], - [-0.0646, -0.0422, -0.0400], - [-0.0696, -0.0595, -0.0894], - [-0.0799, -0.0208, -0.0375], - [ 0.1166, 0.1627, 0.0962], - [ 0.1165, 0.0432, 0.0407], - [-0.2315, -0.1920, -0.1355], - [-0.0270, 0.0401, -0.0821], - [-0.0616, -0.0997, -0.0727], - [ 0.0249, -0.0469, -0.1703] - ] - - latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761] + if latents is None: return None + model_handler = get_model_handler(model_type) + base_model_type = 
get_base_model_type(model_type) + if hasattr(model_handler, "get_rgb_factors"): + latent_rgb_factors, latent_rgb_factors_bias = model_handler.get_rgb_factors(base_model_type ) else: - raise Exception("preview not supported") + return None + if latent_rgb_factors is None: return None latents = latents.unsqueeze(0) nb_latents = latents.shape[2] latents_to_preview = 4 @@ -5310,7 +4883,7 @@ def generate_preview(latents): def process_tasks(state): - from wan.utils.thread_utils import AsyncStream, async_run + from shared.utils.thread_utils import AsyncStream, async_run gen = get_gen_info(state) queue = gen.get("queue", []) @@ -5392,7 +4965,7 @@ def process_tasks(state): # progress(*data) elif cmd == "preview": torch.cuda.current_stream().synchronize() - preview= None if data== None else generate_preview(data) + preview= None if data== None else generate_preview(params["model_type"], data) gen["preview"] = preview # yield time.time() , gr.Text() else: @@ -5942,11 +5515,11 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None if "force_fps" in inputs and len(inputs["force_fps"])== 0: pop += ["force_fps"] - if not get_model_family(model_type) == "wan" or diffusion_forcing: + if model_def.get("sample_solvers", None) is None: pop += ["sample_solver"] - if not (test_class_i2v(base_model_type) or diffusion_forcing or ltxv or recammaster or vace): - pop += ["image_prompt_type"] + # if not (test_class_i2v(base_model_type) or diffusion_forcing or ltxv or recammaster or vace): + # pop += ["image_prompt_type"] if any_audio_track(base_model_type) or server_config.get("mmaudio_enabled", 0) == 0: pop += ["MMAudio_setting", "MMAudio_prompt", "MMAudio_neg_prompt"] @@ -5962,8 +5535,8 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None pop += ["model_mode"] if not vace and not phantom and not hunyuan_video_custom: - unsaved_params = ["keep_frames_video_guide", "video_prompt_type", "remove_background_images_ref", "mask_expand"] - if base_model_type in ["t2v"]: unsaved_params = unsaved_params[2:] + unsaved_params = ["keep_frames_video_guide", "remove_background_images_ref", "mask_expand"] #"video_prompt_type", + if base_model_type in ["t2v"]: unsaved_params = unsaved_params[1:] pop += unsaved_params if not vace: pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", "control_net_weight2", "min_frames_if_references"] @@ -5974,27 +5547,42 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None if not test_any_sliding_window( base_model_type): pop += ["sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames", "sliding_window_color_correction_strength"] - if not base_model_type in ["fantasy", "multitalk", "vace_multitalk_14B"]: + if not (base_model_type in ["fantasy"] or model_def.get("multitalk_class", False)): pop += ["audio_guidance_scale", "speakers_locations"] - if not model_family in ["hunyuan", "flux"] or model_def.get("no_guidance", False): + if not model_def.get("embedded_guidance", False) or model_def.get("no_guidance", False): pop += ["embedded_guidance_scale"] - if not model_family in ["hunyuan", "wan"]: + if not (model_def.get("tea_cache", False) or model_def.get("mag_cache", False)) : pop += ["skip_steps_cache_type", "skip_steps_multiplier", "skip_steps_start_step_perc"] - if model_def.get("no_guidance", False) or ltxv or model_family in ["hunyuan", "flux"] : + if model_def.get("no_guidance", False) : pop += ["guidance_scale", 
"guidance2_scale", "switch_threshold", "audio_guidance_scale"] + + if not model_def.get("guidance_max_phases",1) >1: + pop += ["guidance2_scale", "switch_threshold"] + if model_def.get("image_outputs", False) or ltxv: pop += ["flow_shift"] - if model_def.get("no_negative_prompt", False) or model_family in ["flux"]: - pop += ["negative_prompt", "apg_switch", "cfg_star_switch", "cfg_zero_step", ] + if model_def.get("no_negative_prompt", False) : + pop += ["negative_prompt" ] + if not model_def.get("skip_layer_guidance", False): + pop += ["slg_switch", "slg_layers", "slg_start_perc", "slg_end_perc"] + + if not model_def.get("cfg_zero", False): + pop += [ "cfg_zero_step" ] + + if not model_def.get("cfg_star", False): + pop += ["cfg_star_switch" ] + + if not model_def.get("adaptive_projected_guidance", False): + pop += ["apg_switch"] if not model_family == "wan" or diffusion_forcing: - pop +=["NAG_scale", "NAG_tau", "NAG_alpha", "slg_switch", "slg_layers", "slg_start_perc", "slg_end_perc" ] + pop +=["NAG_scale", "NAG_tau", "NAG_alpha" ] for k in pop: if k in inputs: inputs.pop(k) @@ -6118,7 +5706,7 @@ def has_video_file_extension(filename): def has_image_file_extension(filename): extension = os.path.splitext(filename)[-1] - return extension in [".jpeg", ".jpg", ".png", ".bmp", ".tiff"] + return extension in [".jpeg", ".jpg", ".png", ".webp", ".bmp", ".tiff"] def add_videos_to_gallery(state, input_file_list, choice, files_to_load): gen = get_gen_info(state) @@ -6223,7 +5811,7 @@ def use_video_settings(state, input_file_list, choice): def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, switch_type_if_compatible): configs = None - tags = None + any_image_or_video = False if file_path.endswith(".json") and allow_json: try: with open(file_path, 'r', encoding='utf-8') as f: @@ -6235,22 +5823,22 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw try: file = MP4(file_path) tags = file.tags['©cmt'][0] + configs = json.loads(tags) + any_image_or_video = True except: pass elif has_image_file_extension(file_path): try: - with Image.open(file_path) as img: - tags = img.info["comment"] + configs = read_image_metadata(file_path) + any_image_or_video = True except: pass - if tags is not None: - try: - configs = json.loads(tags) - if not "WanGP" in configs.get("type", ""): configs = None - except: - configs = None - if configs == None: - return None, False + if configs is None: return None, False + try: + if not "WanGP" in configs.get("type", ""): configs = None + except: + configs = None + current_model_filename = state["model_filename"] current_model_type = state["model_type"] @@ -6276,7 +5864,7 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw configs = defaults configs["model_type"] = model_type - return configs, tags != None + return configs, any_image_or_video def record_image_mode_tab(state, evt:gr.SelectData): state["image_mode_tab"] = 0 if evt.index ==0 else 1 @@ -6692,11 +6280,13 @@ def refresh_video_guide_outpainting_row(video_guide_outpainting_checkbox, video_ return gr.update(visible=video_guide_outpainting_checkbox), video_guide_outpainting custom_resolutions = None -def get_resolution_choices(current_resolution_choice): +def get_resolution_choices(current_resolution_choice, model_resolutions= None): global custom_resolutions resolution_file = "resolutions.json" - if custom_resolutions == None and os.path.isfile(resolution_file) : + if model_resolutions is not None: + resolution_choices = model_resolutions 
+ elif custom_resolutions == None and os.path.isfile(resolution_file) : with open(resolution_file, 'r', encoding='utf-8') as f: try: resolution_choices = json.load(f) @@ -6756,9 +6346,13 @@ def get_resolution_choices(current_resolution_choice): if current_resolution_choice == res: found = True break - if not found: - resolution_choices.append( (current_resolution_choice, current_resolution_choice )) - return resolution_choices + if not found: + if model_resolutions is None: + resolution_choices.append( (current_resolution_choice, current_resolution_choice )) + else: + current_resolution_choice = resolution_choices[0][1] + + return resolution_choices, current_resolution_choice group_thresholds = { "360p": 320 * 640, @@ -6795,7 +6389,10 @@ def group_resolutions(resolutions, selected_resolution): return available_groups, selected_group_resolutions, selected_group def change_resolution_group(state, selected_group): - resolution_choices = get_resolution_choices(None) + model_type = state["model_type"] + model_def = get_model_def(model_type) + model_resolutions = model_def.get("resolutions", None) + resolution_choices, _ = get_resolution_choices(None, model_resolutions) group_resolution_choices = [ resolution for resolution in resolution_choices if categorize_resolution(resolution[1]) == selected_group ] last_resolution_per_group = state["last_resolution_per_group"] @@ -6929,12 +6526,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ltxv = "ltxv" in model_filename lock_inference_steps = model_def.get("lock_inference_steps", False) model_reference_image = model_def.get("reference_image", False) - no_steps_skipping = model_def.get("no_steps_skipping", False) + any_tea_cache = model_def.get("tea_cache", False) + any_mag_cache = model_def.get("mag_cache", False) recammaster = base_model_type in ["recam_1.3B"] vace = test_vace_module(base_model_type) phantom = base_model_type in ["phantom_1.3B", "phantom_14B"] fantasy = base_model_type in ["fantasy"] - multitalk = base_model_type in ["multitalk", "vace_multitalk_14B"] + multitalk = model_def.get("multitalk_class", False) hunyuan_t2v = "hunyuan_video_720" in model_filename hunyuan_i2v = "hunyuan_video_i2v" in model_filename hunyuan_video_custom = "hunyuan_video_custom" in model_filename @@ -6947,7 +6545,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non sliding_window_enabled = test_any_sliding_window(model_type) multi_prompts_gen_type_value = ui_defaults.get("multi_prompts_gen_type_value",0) prompt_label, wizard_prompt_label = get_prompt_labels(multi_prompts_gen_type_value, image_outputs) - any_video_source = True + any_video_source = False fps = get_model_fps(base_model_type) image_prompt_type_value = "" video_prompt_type_value = "" @@ -6955,6 +6553,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non any_end_image = False any_reference_image = False v2i_switch_supported = (vace or t2v) and not image_outputs + ti2v_2_2 = base_model_type in ["ti2v_2_2"] + image_mode_value = ui_defaults.get("image_mode", 1 if image_outputs else 0 ) if not v2i_switch_supported and not image_outputs: image_mode_value = 0 @@ -6969,7 +6569,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non pass - with gr.Column(visible= test_class_i2v(model_type) or hunyuan_i2v or diffusion_forcing or ltxv or recammaster or vace) as image_prompt_column: + with gr.Column(visible= test_class_i2v(model_type) or hunyuan_i2v or diffusion_forcing or ltxv or 
recammaster or vace or ti2v_2_2) as image_prompt_column: if vace: image_prompt_type_value= ui_defaults.get("image_prompt_type","") image_prompt_type_value = "" if image_prompt_type_value == "S" else image_prompt_type_value @@ -6980,14 +6580,17 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_source = gr.Video(label= "Video Source", visible = "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None)) model_mode = gr.Dropdown(visible = False) keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VLG"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" ) + any_video_source = True - elif diffusion_forcing or ltxv: + elif diffusion_forcing or ltxv or ti2v_2_2: image_prompt_type_value= ui_defaults.get("image_prompt_type","T") # image_prompt_type = gr.Radio( [("Start Video with Image", "S"),("Start and End Video with Images", "SE"), ("Continue Video", "V"),("Text Prompt Only", "T")], value =image_prompt_type_value, label="Location", show_label= False, visible= True, scale= 3) image_prompt_type_choices = [("Text Prompt Only", "T"),("Start Video with Image", "S")] if ltxv: image_prompt_type_choices += [("Use both a Start and an End Image", "SE")] - image_prompt_type_choices += [("Continue Video", "V")] + if sliding_window_enabled: + any_video_source = True + image_prompt_type_choices += [("Continue Video", "V")] image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= True , scale= 3) # image_start = gr.Image(label= "Image as a starting point for a new video", type ="pil",value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) @@ -6998,7 +6601,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non label="Images as ending points for new videos", type ="pil", #file_types= "image", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) - if ltxv: + if not diffusion_forcing: model_mode = gr.Dropdown( choices=[ ], value=None, @@ -7045,8 +6648,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non image_prompt_type_choices = [("Start Video with Image", "S")] image_prompt_type_choices += [("Use both a Start and an End Image", "SE")] if not hunyuan_i2v: + any_video_source = True image_prompt_type_choices += [("Continue Video", "V")] - + image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= not hunyuan_i2v, scale= 3) any_start_image = True any_end_image = True @@ -7061,13 +6665,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_source = gr.Video(value=None, visible=False) else: video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) - any_video_source = True else: image_prompt_type = gr.Radio(choices=[("", "")], value="") image_start = gr.Gallery(value=None) image_end = gr.Gallery(value=None) video_source = gr.Video(value=None, visible=False) - 
any_video_source = False model_mode = gr.Dropdown(value=None, visible=False) keep_frames_video_source = gr.Text(visible=False) @@ -7147,7 +6749,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ], value= filter_letters(video_prompt_type_value, "NA"), visible= "V" in video_prompt_type_value, - label="Area Processed", scale = 2 + label="Area Processed", scale = 2, show_label= True, ) elif ltxv: video_prompt_type_video_mask = gr.Dropdown( @@ -7160,7 +6762,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ], value= filter_letters(video_prompt_type_value, "XNA"), visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value, - label="Area Processed", scale = 2 + label="Area Processed", scale = 2, show_label= True, ) else: video_prompt_type_video_mask = gr.Dropdown( @@ -7179,7 +6781,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ], value= filter_letters(video_prompt_type_value, "XYZWNA"), visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and not hunyuan_video_custom and not ltxv, - label="Area Processed", scale = 2 + label="Area Processed", scale = 2, show_label= True, ) if t2v: video_prompt_type_image_refs = gr.Dropdown(value="", label="Ref Image", choices=[""], visible =False) @@ -7193,7 +6795,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ], value=filter_letters(video_prompt_type_value, "KFI"), visible = True, - label="Reference Images", scale = 2 + label="Reference Images", show_label= True, scale = 2 ) @@ -7327,7 +6929,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non else: label = "Max Resolution (Pixels will be reallocated depending on the output width / height ratio)" current_resolution_choice = ui_defaults.get("resolution","832x480") if update_form or last_resolution is None else last_resolution - resolution_choices= get_resolution_choices(current_resolution_choice) + model_resolutions = model_def.get("resolutions", None) + resolution_choices, current_resolution_choice = get_resolution_choices(current_resolution_choice, model_resolutions) available_groups, selected_group_resolutions, selected_group = group_resolutions(resolution_choices, current_resolution_choice) resolution_group = gr.Dropdown( choices = available_groups, @@ -7366,29 +6969,29 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non with gr.Tab("General"): with gr.Column(): seed = gr.Slider(-1, 999999999, value=ui_defaults.get("seed",-1), step=1, label="Seed (-1 for random)") + any_embedded_guidance = model_def.get("embedded_guidance", False) with gr.Row(visible = not ltxv and not (no_guidance and image_outputs)) as guidance_row: - guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance (CFG)", visible=not (hunyuan_t2v or hunyuan_i2v or flux) and not no_guidance) + guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance (CFG)", visible=not (hunyuan_t2v or hunyuan_i2v or any_embedded_guidance) and not no_guidance) audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale", 5 if fantasy else 4), step=0.5, label="Audio Guidance", visible=(fantasy or multitalk) and not no_guidance) - embedded_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("embedded_guidance", 2.5 if flux else 6.0), step=0.5, label="Embedded Guidance Scale", 
visible=(hunyuan_t2v or hunyuan_i2v or flux) and not no_guidance) + embedded_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("embedded_guidance", 2.5 if flux else 6.0), step=0.5, label="Embedded Guidance Scale", visible=(hunyuan_t2v or hunyuan_i2v or any_embedded_guidance) and not no_guidance) flow_shift = gr.Slider(1.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale", visible = not image_outputs) - with gr.Row(visible = not ltxv and not (no_guidance and image_outputs)) as guidance_row2: - guidance2_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance2_scale",5), step=0.5, label="Guidance2 (CFG)", visible=not (hunyuan_t2v or hunyuan_i2v or flux) and not no_guidance) - switch_threshold = gr.Slider(0, 1000, value=ui_defaults.get("switch_threshold",0), step=1, label="Guidance / Model Switch Threshold", visible=not (hunyuan_t2v or hunyuan_i2v or flux) and not no_guidance) + with gr.Row(visible = model_def.get("guidance_max_phases",1) >1 and not (no_guidance and image_outputs)) as guidance_row2: + guidance2_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance2_scale",5), step=0.5, label="Guidance2 (CFG)", visible=not (hunyuan_t2v or hunyuan_i2v or any_embedded_guidance) and not no_guidance) + switch_threshold = gr.Slider(0, 1000, value=ui_defaults.get("switch_threshold",0), step=1, label="Guidance / Model Switch Threshold", visible=not (hunyuan_t2v or hunyuan_i2v or any_embedded_guidance) and not no_guidance) + sample_solver_choices = model_def.get("sample_solvers", None) + with gr.Row(visible = sample_solver_choices is not None ) as sample_solver_row: + if sample_solver_choices is None: + sample_solver = gr.Dropdown( value="", choices=[ ("", ""), ], visible= False, label= "Sampler Solver / Scheduler" ) + else: + sample_solver = gr.Dropdown( value=ui_defaults.get("sample_solver", sample_solver_choices[0][1]), + choices= sample_solver_choices, visible= True, label= "Sampler Solver / Scheduler" + ) - with gr.Row(visible = get_model_family(model_type) == "wan" and not diffusion_forcing ) as sample_solver_row: - sample_solver = gr.Dropdown( value=ui_defaults.get("sample_solver",""), - choices=[ - ("unipc", ""), - ("euler", "euler"), - ("dpm++", "dpm++"), - ("flowmatch causvid", "causvid"), - ], visible= True, label= "Sampler Solver / Scheduler" - ) with gr.Row(visible = vace) as control_net_weights_row: control_net_weight = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight",1), step=0.1, label="Control Net Weight #1", visible=vace) control_net_weight2 = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label="Control Net Weight #2", visible=vace) - negative_prompt = gr.Textbox(label="Negative Prompt (ignored if no Guidance that is if CFG = 1)", value=ui_defaults.get("negative_prompt", ""), visible = not (hunyuan_t2v or hunyuan_i2v or flux or no_negative_prompt) ) + negative_prompt = gr.Textbox(label="Negative Prompt (ignored if no Guidance that is if CFG = 1)", value=ui_defaults.get("negative_prompt", ""), visible = not (hunyuan_t2v or hunyuan_i2v or no_negative_prompt) ) with gr.Column(visible = vace or t2v or test_class_i2v(model_type)) as NAG_col: gr.Markdown("NAG enforces Negative Prompt even if no Guidance is set (CFG = 1), set NAG Scale to > 1 to enable it") with gr.Row(): @@ -7415,18 +7018,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non label="Activated Loras" ) loras_multipliers = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by Space chars 
or CR, lines that start with # are ignored", value=launch_multis_str) - with gr.Tab("Steps Skipping", visible = not (ltxv or image_outputs) and not no_steps_skipping) as speed_tab: + with gr.Tab("Steps Skipping", visible = any_tea_cache or any_mag_cache) as speed_tab: with gr.Column(): gr.Markdown("Tea Cache and Mag Cache accelerate the Video Generation by skipping intelligently some steps, the more steps are skipped the lower the quality of the video.") gr.Markdown("Steps Skipping consumes also VRAM. It is recommended not to skip at least the first 10% steps.") - + steps_skipping_choices = [("None", "")] + if any_tea_cache: steps_skipping_choices += [("Tea Cache", "tea")] + if any_mag_cache: steps_skipping_choices += [("Mag Cache", "mag")] skip_steps_cache_type = gr.Dropdown( - choices=[ - ("None", ""), - ("Tea Cache", "tea"), - ("Mag Cache", "mag"), - ], - value=ui_defaults.get("skip_steps_cache_type",""), + choices= steps_skipping_choices, + value="" if not (any_tea_cache or any_mag_cache) else ui_defaults.get("skip_steps_cache_type",""), visible=True, label="Skip Steps Cache Type" ) @@ -7516,9 +7117,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non gr.Markdown("Add Custom Soundtrack to Video") audio_source = gr.Audio(value= ui_defaults.get("audio_source", None), type="filepath", label="Soundtrack", show_download_button= True) + any_skip_layer_guidance = model_def.get("skip_layer_guidance", False) + any_cfg_zero = model_def.get("cfg_zero", False) + any_cfg_star = model_def.get("cfg_star", False) + any_apg = model_def.get("adaptive_projected_guidance", False) - with gr.Tab("Quality", visible = not (ltxv and no_negative_prompt or flux)) as quality_tab: - with gr.Column(visible = not (hunyuan_i2v or hunyuan_t2v or hunyuan_video_custom or hunyuan_video_avatar or ltxv) ) as skip_layer_guidance_row: + with gr.Tab("Quality", visible = vace and image_outputs or any_skip_layer_guidance or any_cfg_zero or any_cfg_star or any_apg ) as quality_tab: + with gr.Column(visible = any_skip_layer_guidance ) as skip_layer_guidance_row: gr.Markdown("Skip Layer Guidance (improves video quality, requires guidance > 1)") with gr.Row(): slg_switch = gr.Dropdown( @@ -7544,7 +7149,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non slg_start_perc = gr.Slider(0, 100, value=ui_defaults.get("slg_start_perc",10), step=1, label="Denoising Steps % start") slg_end_perc = gr.Slider(0, 100, value=ui_defaults.get("slg_end_perc",90), step=1, label="Denoising Steps % end") - with gr.Column(visible= not no_negative_prompt and (vace or multitalk or t2v or test_class_i2v(model_type) or ltxv) ) as apg_col: + with gr.Column(visible= any_apg ) as apg_col: gr.Markdown("Correct Progressive Color Saturation during long Video Generations") apg_switch = gr.Dropdown( choices=[ @@ -7557,7 +7162,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non label="Adaptive Projected Guidance (requires Guidance > 1) " ) - with gr.Column(visible = not ltxv) as cfg_free_guidance_col: + with gr.Column(visible = any_cfg_star) as cfg_free_guidance_col: gr.Markdown("Classifier-Free Guidance Zero Star, better adherence to Text Prompt") cfg_star_switch = gr.Dropdown( choices=[ @@ -7570,7 +7175,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non label="Classifier-Free Guidance Star (requires Guidance > 1)" ) with gr.Row(): - cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG 
Zero below this Layer (Extra Process)", visible = not (hunyuan_i2v or hunyuan_t2v or hunyuan_video_avatar or hunyuan_i2v or hunyuan_video_custom )) + cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)", visible = any_cfg_zero) with gr.Column(visible = vace and image_outputs) as min_frames_if_references_col: gr.Markdown("If using Reference Images, generating a single Frame alone may not be sufficient to preserve Identity") @@ -8259,10 +7864,6 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice label="User Interface Theme. You will need to restart the App the see new Theme." ) - save_path_choice = gr.Textbox( - label="Output Folder for Generated Videos (need to restart app to be taken into account)", - value=server_config.get("save_path", save_path) - ) with gr.Tab("Performance"): @@ -8386,6 +7987,53 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice label="MMAudio (if enabled, 10 GB of extra models will be downloaded)" ) + with gr.Tab("Outputs"): + + video_output_codec_choice = gr.Dropdown( + choices=[ + ("x265 Balanced Quality (CRF 28)", 'libx265_28'), + ("x264 Balanced Quality (Level 8)", 'libx264_8'), + ("x265 High Quality (CRF 8)", 'libx265_8'), + ("x264 High Quality (Level 10)", 'libx264_10'), + ("x264 Lossless", 'libx264_lossless'), + ], + value=server_config.get("video_output_codec", "libx264_8"), + label="Video Codec to use" + ) + + image_output_codec_choice = gr.Dropdown( + choices=[ + ("JPEG Quality 85", 'jpeg_85'), + ("WEBP Quality 85", 'webp_85'), + ("JPEG Quality 95", 'jpeg_95'), + ("WEBP Quality 95", 'webp_95'), + ("WEBP Lossless", 'webp_lossless'), + ("PNG Lossless", 'png'), + ], + value=server_config.get("image_output_codec", "jpeg_95"), + label="Image Codec to use" + ) + + audio_output_codec_choice = gr.Dropdown( + choices=[ + ("AAC 128 kbit", 'aac_128'), + ], + value=server_config.get("audio_output_codec", "aac_128"), + visible = False, + label="Audio Codec to use" + ) + + video_save_path_choice = gr.Textbox( + label="Output Folder for Generated Videos (need to restart app to be taken into account)", + value=server_config.get("save_path", save_path) + ) + + image_save_path_choice = gr.Textbox( + label="Output Folder for Generated Images (need to restart app to be taken into account)", + value=server_config.get("image_save_path", image_save_path) + ) + + with gr.Tab("Notifications"): gr.Markdown("### Notification Settings") notification_sound_enabled_choice = gr.Dropdown( @@ -8418,7 +8066,8 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice text_encoder_quantization_choice, VAE_precision_choice, mixed_precision_choice, - save_path_choice, + video_save_path_choice, + image_save_path_choice, attention_choice, compile_choice, profile_choice, @@ -8438,6 +8087,9 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice notification_sound_volume_choice, max_frames_multiplier_choice, display_stats_choice, + video_output_codec_choice, + image_output_codec_choice, + audio_output_codec_choice, resolution, ], outputs= [msg , header, model_family, model_choice, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col] @@ -9002,7 +8654,7 @@ def create_ui(): } """ if server_config.get("display_stats", 0) == 1: - from wan.utils.stats import SystemStatsApp + from shared.utils.stats import SystemStatsApp stats_app = SystemStatsApp() else: stats_app = None @@ -9036,7 +8688,7 @@ def create_ui(): 
with gr.Tab("Guides", id="info") as info_tab: generate_info_tab() with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator: - matanyone_app.display(main_tabs, tab_state, video_guide, image_guide, video_mask, image_mask, image_refs) + matanyone_app.display(main_tabs, tab_state, server_config, video_guide, image_guide, video_mask, image_mask, image_refs) if not args.lock_config: with gr.Tab("Downloads", id="downloads") as downloads_tab: generate_download_tab(lset_name, loras_choices, state) @@ -9072,5 +8724,4 @@ if __name__ == "__main__": else: url = "http://" + server_name webbrowser.open(url + ":" + str(server_port), new = 0, autoraise = True) - demo.launch(favicon_path="favicon.png", server_name=server_name, server_port=server_port, share=args.share, allowed_paths=[save_path]) - + demo.launch(favicon_path="favicon.png", server_name=server_name, server_port=server_port, share=args.share, allowed_paths=list({save_path, image_save_path})) \ No newline at end of file