diff --git a/.gitignore b/.gitignore index d95eb33..a84122c 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,7 @@ *.pth *.ckpt *.safetensors -*.json +#*.json # *.txt *.backup *.pkl @@ -36,6 +36,7 @@ Wan2.1-T2V-1.3B/ Wan2.1-I2V-14B-480P/ Wan2.1-I2V-14B-720P/ outputs/ +outputs2/ gradio_outputs/ ckpts/ loras/ diff --git a/LICENSE.txt b/LICENSE.txt index ada4a22..1262c13 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,17 +1,46 @@ -FREE for Non Commercial USE +WanGP NON-COMMERCIAL EVALUATION LICENSE 1.0 -You are free to: -- Share — copy and redistribute the material in any medium or format -- Adapt — remix, transform, and build upon the material -The licensor cannot revoke these freedoms as long as you follow the license terms. +Definitions +1.1 “Software” means the source code, binaries, libraries, utilities and UI released under this license. +1.2 “Output” means images, videos or other media produced by running the Software. +1.3 “Commercial Use” means: +a) selling, sublicensing, renting, leasing, or otherwise distributing the Software, in whole or in part, for a fee or other consideration; or +b) offering the Software (or any derivative) as part of a paid product or hosted service; or +c) using the Software (or any derivative) to provide cloud-based or backend services, where end users access or pay for those services. -Under the following terms: -- Attribution — You must give appropriate credit , provide a link to the license, and indicate if changes were made . You may do so in any reasonable manner, but not in any way that suggests the licensor endorses you or your use. -NonCommercial — You may not use the material for commercial purposes . +License Grant +Subject to Section 3: +a) You are granted a worldwide, non-exclusive, royalty-free, revocable license to use, reproduce, modify and distribute the Software for non-commercial purposes only. +b) You are granted a worldwide, non-exclusive, royalty-free, irrevocable license to use, reproduce, modify and distribute the Output for any purpose, including commercial sale, provided that any commercial distribution of the Output includes a clear notice that the Output was produced (in whole or in part) using WanGP, along with a hyperlink to the WanGP application’s About tab or repository. -- No additional restrictions — You may not apply legal terms or technological measures that legally restrict others from doing anything the license permits. -Notices: +Restrictions +3.1 You MAY NOT distribute, sublicense or otherwise make available the Software (or any derivative) for Commercial Use. +3.2 You MAY sell, license or otherwise commercially exploit the Output without restriction. +3.3 If you wish to use the Software for Commercial Use, you must obtain a separate commercial license from the Licensor. -- You do not have to comply with the license for elements of the material in the public domain or where your use is permitted by an applicable exception or limitation . +Third-Party Components 4.1 The Software includes components licensed under various open-source licenses (e.g., Apache 2.0, MIT, BSD). 4.2 You must comply with all applicable terms of those third-party licenses, including preservation of copyright notices, inclusion of required license texts, and patent-grant provisions. 4.3 You can find the full text of each third-party license via the “About” tab in the WanGP application, which provides links to their original GitHub repositories. 
+ +Attribution +5.1 You must give appropriate credit by including: +• a copy of this license (or a link to it), and +• a notice that your use is based on “WanGP”. +5.2 You may do so in any reasonable manner, but not in any way that suggests the Licensor endorses you or your use. + +Disclaimer of Warranty & Liability +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE. + +Commercial Licensing The Licensor may offer commercial licenses for the Software, which grant rights to use the Software for Commercial Use. Please contact [deepbeepmeep@yahoo.com] for terms and pricing. + +Effective Date & Previous Versions +8.1 This license is effective as of the date the LICENSE file is updated in the WanGP repository. +8.2 Any copies of the Software obtained under prior license terms before this Effective Date remain governed by those prior terms; such granted rights are irrevocable. +8.3 Use of the Software after the release of any subsequent version by the Licensor is subject to the terms of the then-current license, unless a separate agreement is in place. + +Acceptable Use / Moral Clause +9.1 You MAY NOT use the Software or the Output to facilitate or produce content that is illegal, harmful, violent, harassing, defamatory, fraudulent, or otherwise violates applicable laws or fundamental human rights. +9.2 You MAY NOT deploy the Software or Output in contexts that promote hate speech, extremist ideology, human rights abuses, or other actions that could foreseeably cause significant harm to individuals or groups. +9.3 The Licensor reserves the right to terminate the rights granted under this license if a licensee materially breaches this Acceptable Use clause. + +END OF LICENSE -No warranties are given. The license may not give you all of the permissions necessary for your intended use. For example, other rights such as publicity, privacy, or moral rights may limit how you use the material. \ No newline at end of file diff --git a/README.md b/README.md index fb04ce1..dff2873 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,157 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep -## 🔥 Latest Updates +## 🔥 Latest Updates: +### August 29 2025: WanGP v8.2 - Here Goes Your Weekend + +- **InfiniteTalk Video to Video**: this feature can be used for Video Dubbing. Keep in mind that it is a *Sparse Video to Video*, that is, internally only one image is used per Sliding Window. However thanks to the new *Smooth Transition* mode, each new clip is connected to the previous one and all the camera work is done by InfiniteTalk. If you don't get any transition, increase the number of frames of a Sliding Window (81 frames recommended) + +- **StandIn**: a very light model specialized in Identity Transfer. I have provided two versions of StandIn: a basic one derived from the text 2 video model and another based on Vace. If used with Vace, the last reference frame given to Vace will also be used for StandIn + +- **Flux USO**: a new Flux-derived *Image Editing tool*, but this one is specialized both in *Identity Transfer* and *Style Transfer*.
Style has to be understood in its broadest sense: give a reference picture of a person and another one of sushi, and you will turn this person into sushi. + +### August 24 2025: WanGP v8.1 - the RAM Liberator + +- **Reserved RAM entirely freed when switching models**: you should get far fewer RAM-related out-of-memory errors. I have also added a button in *Configuration / Performance* that will release most of the RAM used by WanGP if you want to use another application without quitting WanGP +- **InfiniteTalk** support: an improved version of Multitalk that supposedly supports very long video generations based on an audio track. It exists in two flavors (*Single Speaker* and *Multi Speakers*) but doesn't seem to be compatible with Vace. One key new feature compared to Multitalk is that you can have different visual shots associated with the same audio: each Reference Frame you provide will be associated with a new Sliding Window. If only one Reference Frame is provided, it will be used for all windows. When Continuing a video, you can either continue the current shot (no Reference Frame) or add new shots (one or more Reference Frames).\ +If you are not into audio, you can still use this model to generate infinitely long image2video; just select "no speaker". Last but not least, InfiniteTalk works with all the Lora accelerators. +- **Flux Chroma 1 HD** support: an uncensored Flux-based model, lighter than Flux (8.9B versus 12B), that can fit entirely in VRAM with only 16 GB of VRAM. Unfortunately it is not distilled, so you will need CFG and a minimum of 20 steps + +### August 21 2025: WanGP v8.01 - the killer of seven + +- **Qwen Image Edit**: a Flux Kontext challenger (prompt-driven image editing). Best results (including Identity preservation) will be obtained at 720p. Beyond that you may get image outpainting and/or lose identity preservation. Below 720p prompt adherence will be worse. Qwen Image Edit works with the Qwen Lora Lightning 4 steps. I have also unlocked all the resolutions for Qwen models. Bonus Zone: support for multiple image compositions, but identity preservation won't be as good. +- **On demand Prompt Enhancer** (needs to be enabled in the Configuration Tab) that you can use to Enhance a Text Prompt before starting a Generation. You can refine the Enhanced Prompt or change the original Prompt. +- Choice of a **Non censored Prompt Enhancer**. Beware, this one is VRAM hungry and will require 12 GB of VRAM to work +- **Memory Profile customizable per model**: useful to set for instance Profile 3 (preload the model entirely in VRAM) only with Image Generation models, if you have 24 GB of VRAM. In that case Generation will be much faster because with Image generators (contrary to Video generators) a lot of time is otherwise wasted in offloading +- **Expert Guidance Mode**: change the Guidance during the generation up to 2 times. Very useful with Wan 2.2 Lightning to reduce the slow motion effect. The idea is to insert a CFG phase before the 2 accelerated phases that follow and have no Guidance. I have added the finetune *Wan2.2 Vace Lightning 3 Phases 14B* with a prebuilt configuration. Please note that it is an 8-step process although the Lightning lora is 4 steps. This expert guidance mode is also available with Wan 2.1. + +*WanGP 8.01 update: improved Qwen Image Edit Identity Preservation* +### August 12 2025: WanGP v7.7777 - Lucky Day(s) + +This is your lucky day!
Thanks to new configuration options that let you store generated Videos and Images in lossless compressed formats, you will find that they in fact look two times better without doing anything! + +Just kidding, they will be only marginally better, but at least this opens the way to professional editing. + +Support: +- Video: x264, x264 lossless, x265 +- Images: jpeg, png, webp, webp lossless +Generation Settings are stored in each of the above regardless of the format (that was the hard part). + +Also you can now choose different output directories for images and videos. + +Unexpected luck: fixed Lightning 8 steps for Qwen, and Lightning 4 steps for Wan 2.2; now you just need a 1x multiplier, no weird numbers. +*update 7.777: oops, got a crash with FastWan? Luck comes and goes, try a new update, maybe you will have a better chance this time* +*update 7.7777: Sometimes good luck seems to last forever. For instance, what if Qwen Lightning 4 steps could also work with WanGP?* +- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors (Qwen Lightning 4 steps) +- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors (new improved version of Qwen Lightning 8 steps) + + +### August 10 2025: WanGP v7.76 - Faster than the VAE ... +We have a funny one here today: FastWan 2.2 5B, the Fastest Video Generator, only 20s to generate 121 frames at 720p. The snag is that the VAE is twice as slow... +Thanks to Kijai for extracting the Lora that is used to build the corresponding finetune. + +*WanGP 7.76: fixed the mess I made with the i2v models (the loras path was wrong for Wan2.2 and Clip was broken)* + +### August 9 2025: WanGP v7.74 - Qwen Rebirth part 2 +Added support for the Qwen Lightning lora for an 8-step generation (https://huggingface.co/lightx2v/Qwen-Image-Lightning/blob/main/Qwen-Image-Lightning-8steps-V1.0.safetensors). The Lora is not normalized and you can use a multiplier around 0.1. + +Mag Cache support for all the Wan2.2 models. Don't forget to set guidance to 1 and 8 denoising steps; your gen will be 7x faster! + +### August 8 2025: WanGP v7.73 - Qwen Rebirth +Ever wondered what impact not using Guidance has on a model that expects it? Just look at Qwen Image in WanGP 7.71, whose outputs were erratic. Somehow I had convinced myself that Qwen was a distilled model. In fact Qwen was dying for a negative prompt. And in WanGP 7.72 there is at last one for him. + +As Qwen is not so picky after all, I have also added a quantized text encoder which reduces the RAM requirements of Qwen by 10 GB (the quantized text encoder produced garbage before) + +Unfortunately the Sage bug is still there for older GPU architectures. Added an Sdpa fallback for these architectures. + +*7.73 update: still the Sage / Sage2 bug for GPUs before RTX 40xx. I have added a detection mechanism that forces Sdpa attention if that's the case* + + +### August 6 2025: WanGP v7.71 - Picky, picky + +This release comes with two new models: +- Qwen Image: a Commercial grade Image generator capable of injecting full sentences into the generated Image while still offering incredible visuals +- Wan 2.2 TextImage to Video 5B: the last Wan 2.2 needed if you want to complete your Wan 2.2 collection (loras for this model can be stored in "\loras\5B") + +There is a catch though: they are very picky if you want to get good generations. First, they both need lots of steps (50?) to show what they have to offer.
Then for Qwen Image I had to hardcode the supported resolutions, because if you try anything else, you will get garbage. Likewise, Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p. + +*7.71 update: Added VAE Tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, for low VRAM usage during a whole gen.* + + +### August 4 2025: WanGP v7.6 - Remuxed + +With this new version you won't have any excuse if there is no sound in your video. + +*Continue Video* now works with any video that already has some sound (hint: Multitalk). + +Also, on top of MMaudio and the various sound driven models, I have added the ability to use your own soundtrack. + +As a result you can apply a different sound source to each new video segment when doing a *Continue Video*. + +For instance: +- first video part: use Multitalk with two people speaking +- second video part: you apply your own soundtrack, which will gently follow the Multitalk conversation +- third video part: you use a Vace effect and its corresponding control audio will be concatenated to the rest of the audio + +To multiply the combinations I have also implemented *Continue Video* with the various image2video models. + +Also: +- End Frame support added for LTX Video models +- Loras can now be targeted specifically at the High noise or Low noise models with Wan 2.2, check the Loras and Finetune guides +- Flux Krea Dev support + +### July 30 2025: WanGP v7.5: Just another release ... Wan 2.2 part 2 +Here is now Wan 2.2 image2video, a very good model if you want to set Start and End frames. Two Wan 2.2 models delivered, only one to go ... + +Please note that although it is an image2video model, it is structurally very close to Wan 2.2 text2video (same layers with only a different initial projection). Given that Wan 2.1 image2video loras don't work too well (half of their tensors are not supported), I have decided that this model will look for its loras in the text2video loras folder instead of the image2video folder. + +I have also optimized RAM management with Wan 2.2 so that loras and modules will be loaded only once in RAM and Reserved RAM; this saves up to 5 GB of RAM, which can make a difference... + +And this time I really removed Vace Cocktail Light, which produced blurry results. + +### July 29 2025: WanGP v7.4: Just another release ... Wan 2.2 Preview +Wan 2.2 is here. The good news is that WanGP won't require a single byte of extra VRAM to run it and it will be as fast as Wan 2.1. The bad news is that you will need much more RAM if you want to leverage this new model entirely, since it has twice as many parameters. + +So here is a preview version of Wan 2.2, without the 5B model and Wan 2.2 image to video for the moment. + +However, as I felt bad delivering only half of the wares, I gave you instead ..... **Wan 2.2 Vace Experimental Cocktail**! + +Very good surprise indeed, the loras and Vace partially work with Wan 2.2. We will need to wait for the official Vace 2.2 release since some Vace features, like identity preservation, are broken + +Bonus zone: Flux multi image conditioning has been added, or maybe not if I broke everything as I have been distracted by Wan... + +7.4 update: I forgot to update the version number. I also removed Vace Cocktail Light which didn't work well. + +### July 27 2025: WanGP v7.3: Interlude +While waiting for Wan 2.2, you will appreciate the model selection hierarchy, which is very useful for collecting even more models.
You will also appreciate that WanGP remembers which model you used last in each model family. + +### July 26 2025: WanGP v7.2: Ode to Vace +I am really convinced that Vace can do everything the other models can do, and in a better way, especially as Vace can be combined with Multitalk. + +Here are some new Vace improvements: +- I have provided a default finetune named *Vace Cocktail*, which is a model created on the fly using the Wan text 2 video model and the Loras used to build FusioniX. The weight of the *Detail Enhancer* Lora has been reduced to improve identity preservation. Copy the model definition in *defaults/vace_14B_cocktail.json* into the *finetunes/* folder to change the Cocktail composition. Cocktail already contains some Lora accelerators, so there is no need to add an AccVid, CausVid or FusioniX Lora on top. The whole point of Cocktail is to be able to build your own FusioniX (which originally is a combination of 4 loras) but without the inconveniences of FusioniX. +- Talking about identity preservation, it tends to go away when one generates a single Frame instead of a Video, which is a shame for our Vace photoshop. But there is a solution: I have added an Advanced Quality option that tells WanGP to generate a little more than a frame (it will still keep only the first frame). It will be a little slower, but you will be amazed how Vace Cocktail combined with this option will preserve identities (bye bye *Phantom*). +- As in practice I have observed that one switches frequently between *Vace text2video* and *Vace text2image*, I have put them in the same place: they are now just one tab away, no need to reload the model. Likewise *Wan text2video* and *Wan text2image* have been merged. +- Color fixing when using Sliding Windows. A new postprocessing step, *Color Correction*, applied automatically by default (you can disable it in the *Advanced tab / Sliding Window*), will try to match the colors of the new window with those of the previous window. It doesn't fix all the unwanted artifacts of the new window, but at least this makes the transition smoother. Thanks to the Multitalk team for the original code. + +Also you will enjoy our new real time statistics (CPU / GPU usage, RAM / VRAM used, ...). Many thanks to **Redtash1** for providing the framework for this new feature! You need to go to the Config tab to enable real time stats. + + +### July 21 2025: WanGP v7.12 +- Flux Family Reunion: *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate that, Loras support for the Flux *diffusers* format has also been added. + +- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video!) in one go without a sliding window. With the distilled model it will take only 5 minutes with an RTX 4090 (you will need 22 GB of VRAM though). I have added options to select higher frame counts if you want to experiment (go to Configuration Tab / General / Increase the Max Number of Frames, change the value and restart the App) + +- LTX Video ControlNet: it is a ControlNet that allows you for instance to transfer Human motion or Depth from a control video. It is not as powerful as Vace but can produce interesting things, especially as you can now quickly generate a 1 min video. Under the hood, IC-Loras (see below) for Pose, Depth and Canny are automatically loaded for you, no need to add them.
+ +- LTX IC-Lora support: these are special Loras that consume a conditional image or video. +Besides the pose, depth and canny IC-Loras that are transparently loaded, there is the *detailer* (https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8) which is basically an upsampler. Add the *detailer* as a Lora and use LTX Raw Format as the control net choice to use it. + +- Matanyone is now also for the GPU Poor as its VRAM requirements have been divided by 2! (7.12 shadow update) + +- Easier way to select video resolution + + ### July 15 2025: WanGP v7.0 is an AI Powered Photoshop This release turns the Wan models into Image Generators. This goes way more than allowing to generate a video made of single frame : - Multiple Images generated at the same time so that you can choose the one you like best.It is Highly VRAM optimized so that you can generate for instance 4 720p Images at the same time with less than 10 GB @@ -86,84 +236,6 @@ Taking care of your life is not enough, you want new stuff to play with ? **If you had upgraded to v6.5 please upgrade again to 6.5.1 as this will fix a bug that ignored Loras beyond the first one** -### June 23 2025: WanGP v6.3, Vace Unleashed. Thought we couldnt squeeze Vace even more ? -- Multithreaded preprocessing when possible for faster generations -- Multithreaded frames Lanczos Upsampling as a bonus -- A new Vace preprocessor : *Flow* to extract fluid motion -- Multi Vace Controlnets: you can now transfer several properties at the same time. This opens new possibilities to explore, for instance if you transfer *Human Movement* and *Shapes* at the same time for some reasons the lighting of your character will take into account much more the environment of your character. -- Injected Frames Outpainting, in case you missed it in WanGP 6.21 - -Don't know how to use all of the Vace features ? Check the Vace Guide embedded in WanGP as it has also been updated. - - -### June 19 2025: WanGP v6.2, Vace even more Powercharged -👋 Have I told you that I am a big fan of Vace ? Here are more goodies to unleash its power: -- If you ever wanted to watch Star Wars in 4:3, just use the new *Outpainting* feature and it will add the missing bits of image at the top and the bottom of the screen. The best thing is *Outpainting* can be combined with all the other Vace modifications, for instance you can change the main character of your favorite movie at the same time -- More processing can combined at the same time (for instance the depth process can be applied outside the mask) -- Upgraded the depth extractor to Depth Anything 2 which is much more detailed - -As a bonus, I have added two finetunes based on the Safe-Forcing technology (which requires only 4 steps to generate a video): Wan 2.1 text2video Self-Forcing and Vace Self-Forcing. I know there is Lora around but the quality of the Lora is worse (at least with Vace) compared to the full model. Don't hesitate to share your opinion about this on the discord server. -### June 17 2025: WanGP v6.1, Vace Powercharged -👋 Lots of improvements for Vace the Mother of all Models: -- masks can now be combined with on the fly processing of a control video, for instance you can extract the motion of a specific person defined by a mask -- on the fly modification of masks : reversed masks (with the same mask you can modify the background instead of the people covered by the masks), enlarged masks (you can cover more area if for instance the person you are trying to inject is larger than the one in the mask), ...
-- view these modified masks directly inside WanGP during the video generation to check they are really as expected -- multiple frames injections: multiples frames can be injected at any location of the video -- expand past videos in on click: just select one generated video to expand it - -Of course all these new stuff work on all Vace finetunes (including Vace Fusionix). - -Thanks also to Reevoy24 for adding a Notfication sound at the end of a generation and for fixing the background color of the current generation summary. - -### June 12 2025: WanGP v6.0 -👋 *Finetune models*: You find the 20 models supported by WanGP not sufficient ? Too impatient to wait for the next release to get the support for a newly released model ? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add by yourself the support for this model in WanGP by just creating a finetune model definition. You can then store this model in the cloud (for instance in Huggingface) and the very light finetune definition file can be easily shared with other users. WanGP will download automatically the finetuned model for them. - -To celebrate the new finetunes support, here are a few finetune gifts (directly accessible from the model selection menu): -- *Fast Hunyuan Video* : generate model t2v in only 6 steps -- *Hunyuan Vido AccVideo* : generate model t2v in only 5 steps -- *Wan FusioniX*: it is a combo of AccVideo / CausVid ans other models and can generate high quality Wan videos in only 8 steps - -One more thing... - -The new finetune system can be used to combine complementaty models : what happens when you combine Fusionix Text2Video and Vace Control Net ? - -You get **Vace FusioniX**: the Ultimate Vace Model, Fast (10 steps, no need for guidance) and with a much better quality Video than the original slower model (despite being the best Control Net out there). Here goes one more finetune... - -Check the *Finetune Guide* to create finetune models definitions and share them on the WanGP discord server. - -### June 11 2025: WanGP v5.5 -👋 *Hunyuan Video Custom Audio*: it is similar to Hunyuan Video Avatar except there isn't any lower limit on the number of frames and you can use your reference images in a different context than the image itself\ -*Hunyuan Video Custom Edit*: Hunyuan Video Controlnet, use it to do inpainting and replace a person in a video while still keeping his poses. Similar to Vace but less restricted than the Wan models in terms of content... - - -### June 6 2025: WanGP v5.41 -👋 Bonus release: Support for **AccVideo** Lora to speed up x2 Video generations in Wan models. Check the Loras documentation to get the usage instructions of AccVideo.\ -You will need to do a *pip install -r requirements.txt* - -### June 6 2025: WanGP v5.4 -👋 World Exclusive : **Hunyuan Video Avatar** Support ! You won't need 80 GB of VRAM nor 32 GB oF VRAM, just 10 GB of VRAM will be sufficient to generate up to 15s of high quality speech / song driven Video at a high speed with no quality degradation. Support for TeaCache included.\ -Here is a link to the original repo where you will find some very interesting documentation and examples. https://github.com/Tencent-Hunyuan/HunyuanVideo-Avatar. Kudos to the Hunyuan Video Avatar team for the best model of its kind.\ -Also many thanks to Reevoy24 for his repackaging / completing the documentation - -### May 28 2025: WanGP v5.31 -👋 Added **Phantom 14B**, a model that you can use to transfer objects / people in the video. 
My preference goes to Vace that remains the king of controlnets. -VACE improvements: Better sliding window transitions, image mask support in Matanyone, new Extend Video feature, and enhanced background removal options. - -### May 26, 2025: WanGP v5.3 -👋 Settings management revolution! Now you can: -- Select any generated video and click *Use Selected Video Settings* to instantly reuse its configuration -- Drag & drop videos to automatically extract their settings metadata -- Export/import settings as JSON files for easy sharing and backup - -### May 20, 2025: WanGP v5.2 -👋 **CausVid support** - Generate videos in just 4-12 steps with the new distilled Wan model! Also added experimental MoviiGen for 1080p generation (20GB+ VRAM required). Check the Loras documentation to get the usage instructions of CausVid. - -### May 18, 2025: WanGP v5.1 -👋 **LTX Video 13B Distilled** - Generate high-quality videos in less than one minute! - -### May 17, 2025: WanGP v5.0 -👋 **One App to Rule Them All!** Added Hunyuan Video and LTX Video support, plus Vace 14B and integrated prompt enhancer. - See full changelog: **[Changelog](docs/CHANGELOG.md)** ## 📋 Table of Contents @@ -211,7 +283,7 @@ git clone https://github.com/deepbeepmeep/Wan2GP.git cd Wan2GP conda create -n wan2gp python=3.10.9 conda activate wan2gp -pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124 +pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128 pip install -r requirements.txt ``` @@ -229,6 +301,7 @@ git pull pip install -r requirements.txt ``` + ## 📦 Installation For detailed installation instructions for different GPU generations: @@ -251,6 +324,12 @@ For detailed installation instructions for different GPU generations: - **[Changelog](docs/CHANGELOG.md)** - Latest updates and version history - **[Troubleshooting](docs/TROUBLESHOOTING.md)** - Common issues and solutions +## 📚 Video Guides +- Nice Video that explain how to use Vace:\ +https://www.youtube.com/watch?v=FMo9oN2EAvE +- Another Vace guide:\ +https://www.youtube.com/watch?v=T5jNiEhf9xk + ## 🔗 Related Projects ### Other Models for the GPU Poor diff --git a/assets/comp_effic.png b/assets/comp_effic.png deleted file mode 100644 index ea0e3b2..0000000 Binary files a/assets/comp_effic.png and /dev/null differ diff --git a/assets/data_for_diff_stage.jpg b/assets/data_for_diff_stage.jpg deleted file mode 100644 index af98046..0000000 Binary files a/assets/data_for_diff_stage.jpg and /dev/null differ diff --git a/assets/i2v_res.png b/assets/i2v_res.png deleted file mode 100644 index fb13d61..0000000 Binary files a/assets/i2v_res.png and /dev/null differ diff --git a/assets/logo.png b/assets/logo.png deleted file mode 100644 index 0c55854..0000000 Binary files a/assets/logo.png and /dev/null differ diff --git a/assets/t2v_res.jpg b/assets/t2v_res.jpg deleted file mode 100644 index 6a58388..0000000 Binary files a/assets/t2v_res.jpg and /dev/null differ diff --git a/assets/vben_vs_sota.png b/assets/vben_vs_sota.png deleted file mode 100644 index 4f09de6..0000000 Binary files a/assets/vben_vs_sota.png and /dev/null differ diff --git a/assets/video_dit_arch.jpg b/assets/video_dit_arch.jpg deleted file mode 100644 index a13e499..0000000 Binary files a/assets/video_dit_arch.jpg and /dev/null differ diff --git a/assets/video_vae_res.jpg b/assets/video_vae_res.jpg deleted file mode 100644 index e1bfb11..0000000 Binary files a/assets/video_vae_res.jpg and /dev/null differ diff --git 
a/configs/i2v_2_2.json b/configs/i2v_2_2.json new file mode 100644 index 0000000..a64a868 --- /dev/null +++ b/configs/i2v_2_2.json @@ -0,0 +1,14 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.33.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 36, + "model_type": "i2v2_2", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512 +} \ No newline at end of file diff --git a/configs/i2v_2_2_multitalk.json b/configs/i2v_2_2_multitalk.json new file mode 100644 index 0000000..7206fdc --- /dev/null +++ b/configs/i2v_2_2_multitalk.json @@ -0,0 +1,15 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.33.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 36, + "model_type": "i2v2_2", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + "multitalk_output_dim": 768 +} \ No newline at end of file diff --git a/configs/i2v_720p.json b/configs/infinitetalk.json similarity index 82% rename from configs/i2v_720p.json rename to configs/infinitetalk.json index f5a12b2..2724759 100644 --- a/configs/i2v_720p.json +++ b/configs/infinitetalk.json @@ -10,5 +10,6 @@ "num_heads": 40, "num_layers": 40, "out_dim": 16, - "text_len": 512 + "text_len": 512, + "multitalk_output_dim": 768 } diff --git a/configs/qwen_image_20B.json b/configs/qwen_image_20B.json new file mode 100644 index 0000000..4bff1e5 --- /dev/null +++ b/configs/qwen_image_20B.json @@ -0,0 +1,18 @@ +{ + "_class_name": "QwenImageTransformer2DModel", + "_diffusers_version": "0.34.0.dev0", + "attention_head_dim": 128, + "axes_dims_rope": [ + 16, + 56, + 56 + ], + "guidance_embeds": false, + "in_channels": 64, + "joint_attention_dim": 3584, + "num_attention_heads": 24, + "num_layers": 60, + "out_channels": 16, + "patch_size": 2, + "pooled_projection_dim": 768 +} diff --git a/configs/standin.json b/configs/standin.json new file mode 100644 index 0000000..1c7027a --- /dev/null +++ b/configs/standin.json @@ -0,0 +1,15 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.30.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 16, + "model_type": "t2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + "standin": true +} diff --git a/configs/ti2v_2_2.json b/configs/ti2v_2_2.json new file mode 100644 index 0000000..d58edcc --- /dev/null +++ b/configs/ti2v_2_2.json @@ -0,0 +1,14 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.33.0", + "dim": 3072, + "eps": 1e-06, + "ffn_dim": 14336, + "freq_dim": 256, + "in_dim": 48, + "model_type": "ti2v2_2", + "num_heads": 24, + "num_layers": 30, + "out_dim": 48, + "text_len": 512 +} diff --git a/configs/vace_standin_14B.json b/configs/vace_standin_14B.json new file mode 100644 index 0000000..bf83e98 --- /dev/null +++ b/configs/vace_standin_14B.json @@ -0,0 +1,17 @@ +{ + "_class_name": "VaceWanModel", + "_diffusers_version": "0.30.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 16, + "model_type": "t2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35], + "vace_in_dim": 96, + "standin": true +} diff --git a/defaults/fantasy.json b/defaults/fantasy.json index dbab1b2..3917cab 100644 --- a/defaults/fantasy.json +++ b/defaults/fantasy.json @@ -3,10 +3,9 @@ { "name": "Fantasy Talking 720p", "architecture" : "fantasy", - "modules": ["fantasy"], + "modules": [ 
["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_fantasy_speaking_14B_bf16.safetensors"]], "description": "The Fantasy Talking model corresponds to the original Wan image 2 video model combined with the Fantasy Speaking module to process an audio Input.", - "URLs": "i2v_720p", - "teacache_coefficients" : [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] + "URLs": "i2v_720p" }, "resolution": "1280x720" } diff --git a/defaults/flux.json b/defaults/flux.json new file mode 100644 index 0000000..87bab0f --- /dev/null +++ b/defaults/flux.json @@ -0,0 +1,16 @@ +{ + "model": { + "name": "Flux 1 Dev 12B", + "architecture": "flux", + "description": "FLUX.1 Dev is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev_quanto_bf16_int8.safetensors" + ], + "image_outputs": true, + "flux-model": "flux-dev" + }, + "prompt": "draw a hat", + "resolution": "1280x720", + "batch_size": 1 +} \ No newline at end of file diff --git a/defaults/flux_chroma.json b/defaults/flux_chroma.json new file mode 100644 index 0000000..7ffc426 --- /dev/null +++ b/defaults/flux_chroma.json @@ -0,0 +1,18 @@ +{ + "model": { + "name": "Flux 1 Chroma 1 HD 8.9B", + "architecture": "flux", + "description": "FLUX.1 Chroma is a 8.9 billion parameters model. As a base model, Chroma1 is intentionally designed to be an excellent starting point for finetuning. It provides a strong, neutral foundation for developers, researchers, and artists to create specialized models..", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-chroma_hd_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-chroma_hd_quanto_bf16_int8.safetensors" + ], + "image_outputs": true, + "flux-model": "flux-chroma" + }, + "prompt": "draw a hat", + "resolution": "1280x720", + "guidance_scale": 3.0, + "num_inference_steps": 20, + "batch_size": 1 +} \ No newline at end of file diff --git a/defaults/flux_dev_kontext.json b/defaults/flux_dev_kontext.json index 14006b1..8945918 100644 --- a/defaults/flux_dev_kontext.json +++ b/defaults/flux_dev_kontext.json @@ -1,16 +1,19 @@ { "model": { - "name": "Flux Dev Kontext 12B", - "architecture": "flux_dev_kontext", - "description": "FLUX.1 Kontext is a 12 billion parameter rectified flow transformer capable of editing images based on instructions stored in the Prompt. Please be aware that Flux Kontext is picky on the resolution of the input image the output dimensions may not match the dimensions of the input image.", + "name": "Flux 1 Dev Kontext 12B", + "architecture": "flux", + "description": "FLUX.1 Kontext is a 12 billion parameter rectified flow transformer capable of editing images based on instructions stored in the Prompt. 
Please be aware that Flux Kontext is picky on the resolution of the input image and the output dimensions may not match the dimensions of the input image.", "URLs": [ "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors" - ] + ], + "image_outputs": true, + "reference_image": true, + "flux-model": "flux-dev-kontext" }, "prompt": "add a hat", "resolution": "1280x720", - "video_length": 1 + "batch_size": 1 } \ No newline at end of file diff --git a/defaults/flux_dev_uso.json b/defaults/flux_dev_uso.json new file mode 100644 index 0000000..ab5ac54 --- /dev/null +++ b/defaults/flux_dev_uso.json @@ -0,0 +1,19 @@ +{ + "model": { + "name": "Flux 1 Dev USO 12B", + "architecture": "flux", + "description": "FLUX.1 Dev USO is a model specialized to Edit Images with a specialization in Style Transfers (up to two).", + "modules": [ ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_projector_bf16.safetensors"]], + "URLs": "flux", + "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_dit_lora_bf16.safetensors"], + "image_outputs": true, + "reference_image": true, + "flux-model": "flux-dev-uso" + }, + "prompt": "add a hat", + "embedded_guidance_scale": 4, + "resolution": "1024x1024", + "batch_size": 1 +} + + \ No newline at end of file diff --git a/defaults/flux_krea.json b/defaults/flux_krea.json new file mode 100644 index 0000000..3caba1a --- /dev/null +++ b/defaults/flux_krea.json @@ -0,0 +1,16 @@ +{ + "model": { + "name": "Flux 1 Krea Dev 12B", + "architecture": "flux", + "description": "Cutting-edge output quality, with a focus on aesthetic photography..", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-krea-dev_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-krea-dev_quanto_bf16_int8.safetensors" + ], + "image_outputs": true, + "flux-model": "flux-dev" + }, + "prompt": "draw a hat", + "resolution": "1280x720", + "batch_size": 1 +} \ No newline at end of file diff --git a/defaults/flux_schnell.json b/defaults/flux_schnell.json new file mode 100644 index 0000000..d7abcde --- /dev/null +++ b/defaults/flux_schnell.json @@ -0,0 +1,17 @@ +{ + "model": { + "name": "Flux 1 Schnell 12B", + "architecture": "flux", + "description": "FLUX.1 Schnell is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. 
As a distilled model it requires fewer denoising steps.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-schnell_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-schnell_quanto_bf16_int8.safetensors" + ], + "image_outputs": true, + "flux-model": "flux-schnell" + }, + "prompt": "draw a hat", + "resolution": "1280x720", + "num_inference_steps": 10, + "batch_size": 1 +} \ No newline at end of file diff --git a/defaults/hunyuan.json b/defaults/hunyuan.json index 5012c02..a6ba832 100644 --- a/defaults/hunyuan.json +++ b/defaults/hunyuan.json @@ -1,11 +1,11 @@ { "model": { - "name": "Hunyuan Video text2video 720p 13B", + "name": "Hunyuan Video Text2video 720p 13B", "architecture" : "hunyuan", "description": "Probably the best text 2 video model available.", "URLs": [ - "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_720_bf16.safetensors.safetensors", + "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_720_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_720_quanto_int8.safetensors" ] } diff --git a/defaults/hunyuan_i2v.json b/defaults/hunyuan_i2v.json index 400a6a3..44722da 100644 --- a/defaults/hunyuan_i2v.json +++ b/defaults/hunyuan_i2v.json @@ -1,7 +1,7 @@ { "model": { - "name": "Hunyuan Video image2video 720p 13B", + "name": "Hunyuan Video Image2video 720p 13B", "architecture" : "hunyuan_i2v", "description": "A good looking image 2 video model, but not so good in prompt adherence.", "URLs": [ diff --git a/defaults/hunyuan_t2v_accvideo.json b/defaults/hunyuan_t2v_accvideo.json index 2164744..23309d0 100644 --- a/defaults/hunyuan_t2v_accvideo.json +++ b/defaults/hunyuan_t2v_accvideo.json @@ -1,6 +1,6 @@ { "model": { - "name": "Hunyuan AccVideo 720p 13B", + "name": "Hunyuan Video AccVideo 720p 13B", "architecture": "hunyuan", "description": " AccVideo is a novel efficient distillation method to accelerate video diffusion models with synthetic datset. Our method is 8.5x faster than HunyuanVideo.", "URLs": [ diff --git a/defaults/hunyuan_t2v_fast.json b/defaults/hunyuan_t2v_fast.json index a7721fd..acba28e 100644 --- a/defaults/hunyuan_t2v_fast.json +++ b/defaults/hunyuan_t2v_fast.json @@ -1,6 +1,6 @@ { "model": { - "name": "Hunyuan Fast Video 720p 13B", + "name": "Hunyuan Video FastHunyuan 720p 13B", "architecture": "hunyuan", "description": "Fast Hunyuan is an accelerated HunyuanVideo model. It can sample high quality videos with 6 diffusion steps.", "URLs": [ diff --git a/defaults/i2v.json b/defaults/i2v.json index 33a4d55..ba10691 100644 --- a/defaults/i2v.json +++ b/defaults/i2v.json @@ -1,7 +1,7 @@ { "model": { - "name": "Wan2.1 image2video 480p 14B", + "name": "Wan2.1 Image2video 480p 14B", "architecture" : "i2v", "description": "The standard Wan Image 2 Video specialized to generate 480p images. It also offers Start and End Image support (End Image is not supported in the original model but seems to work well)", "URLs": [ diff --git a/defaults/i2v_2_2.json b/defaults/i2v_2_2.json new file mode 100644 index 0000000..a032333 --- /dev/null +++ b/defaults/i2v_2_2.json @@ -0,0 +1,25 @@ +{ + "model": + { + "name": "Wan2.2 Image2video 14B", + "architecture" : "i2v_2_2", + "description": "Wan 2.2 Image 2 Video model. Contrary to the Wan Image2video 2.1 this model is structurally close to the t2v model. 
You will need consequently to store Loras for this model in the t2v Lora Folder.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_image2video_14B_high_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_image2video_14B_high_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_image2video_14B_high_quanto_mfp16_int8.safetensors" + ], + "URLs2": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_image2video_14B_low_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_image2video_14B_low_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_image2video_14B_low_quanto_mfp16_int8.safetensors" + ], + "group": "wan2_2" + }, + "guidance_phases": 2, + "switch_threshold" : 900, + "guidance_scale" : 3.5, + "guidance2_scale" : 3.5, + "flow_shift" : 5 + +} \ No newline at end of file diff --git a/defaults/i2v_2_2_multitalk.json b/defaults/i2v_2_2_multitalk.json new file mode 100644 index 0000000..9326469 --- /dev/null +++ b/defaults/i2v_2_2_multitalk.json @@ -0,0 +1,18 @@ +{ + "model": + { + "name": "Wan2.2 Multitalk 14B", + "architecture" : "i2v_2_2_multitalk", + "description": "The Multitalk module of Wan 2.1 has been combined with the Wan 2.2 image 2 video. It lets you have up to two people have a conversation.", + "modules": ["multitalk"], + "URLs": "i2v_2_2", + "URLs2": "i2v_2_2", + "group": "wan2_2", + "visible": false + }, + "switch_threshold" : 900, + "guidance_scale" : 3.5, + "guidance2_scale" : 3.5, + "flow_shift" : 5 + +} \ No newline at end of file diff --git a/defaults/i2v_720p.json b/defaults/i2v_720p.json index 90523de..844aab9 100644 --- a/defaults/i2v_720p.json +++ b/defaults/i2v_720p.json @@ -1,7 +1,7 @@ { "model": { - "name": "Wan2.1 image2video 720p 14B", + "name": "Wan2.1 Image2video 720p 14B", "architecture" : "i2v", "description": "The standard Wan Image 2 Video specialized to generate 720p images. It also offers Start and End Image support (End Image is not supported in the original model but seems to work well).", "URLs": [ diff --git a/defaults/i2v_fusionix.json b/defaults/i2v_fusionix.json index ffbb0a1..851d6cc 100644 --- a/defaults/i2v_fusionix.json +++ b/defaults/i2v_fusionix.json @@ -1,7 +1,7 @@ { "model": { - "name": "Wan2.1 image2video 480p FusioniX 14B", + "name": "Wan2.1 Image2video 480p FusioniX 14B", "architecture" : "i2v", "description": "A powerful merged image-to-video model based on the original WAN 2.1 I2V model, enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.", "URLs": "i2v", diff --git a/defaults/infinitetalk.json b/defaults/infinitetalk.json new file mode 100644 index 0000000..7fddea9 --- /dev/null +++ b/defaults/infinitetalk.json @@ -0,0 +1,16 @@ +{ + "model": { + "name": "Infinitetalk Single Speaker 480p", + "architecture": "infinitetalk", + "modules": [ + [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_infinitetalk_single_14B_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_infinitetalk_single_14B_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_infinitetalk_single_14B_quanto_mfp16_int8.safetensors" + ] + ], + "description": "The Infinitetalk model is an improved version of Multitalk that supports very long videos. This is the single speaker version. 
Sliding Window size must be 81 frames to get smooth transitions between shots.", + "one_speaker_only": true, + "URLs": "i2v" + } +} \ No newline at end of file diff --git a/defaults/infinitetalk_multi.json b/defaults/infinitetalk_multi.json new file mode 100644 index 0000000..97251e8 --- /dev/null +++ b/defaults/infinitetalk_multi.json @@ -0,0 +1,16 @@ +{ + "model": { + "name": "Infinitetalk Multi Speakers 480p", + "architecture": "infinitetalk", + "modules": [ + [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_infinitetalk_multi_14B_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_infinitetalk_multi_14B_quanto_mfp16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_infinitetalk_multi_14B_quanto_mbf16_int8.safetensors" + ] + ], + "description": "The Infinitetalk model is an improved version of Multitalk that supports very long videos. This is the multi speakers version.Sliding Window size must be 81 frames to get smooth transitions between shots", + "multi_speakers_only": true, + "URLs": "i2v" + } +} \ No newline at end of file diff --git a/defaults/ltxv_13B.json b/defaults/ltxv_13B.json index 7e45e9a..639442e 100644 --- a/defaults/ltxv_13B.json +++ b/defaults/ltxv_13B.json @@ -1,14 +1,19 @@ -{ +{ "model": { - "name": "LTX Video 0.9.7 13B", + "name": "LTX Video 0.9.8 13B", "architecture" : "ltxv_13B", - "description": "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).It is recommended to keep the number of steps to 30 or you will need to update the file 'ltxv_video/configs/ltxv-13b-0.9.7-dev.yaml'.The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer.", + "description": "LTX Video is a fast model that can be used to generate very very long videos (up to 1800 frames !).It is recommended to keep the number of steps to 30 or you will need to update the file 'ltxv_video/configs/ltxv-13b-0.9.8-dev.yaml'.The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer.", "URLs": [ - "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.7_13B_dev_bf16.safetensors", - "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.7_13B_dev_quanto_bf16_int8.safetensors" + "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.8_13B_dev_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.8_13B_dev_quanto_bf16_int8.safetensors" ], - "LTXV_config": "ltx_video/configs/ltxv-13b-0.9.7-dev.yaml" + "preload_URLs" : [ + "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv-097-ic-lora-pose-control-diffusers.safetensors", + "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv-097-ic-lora-depth-control-diffusers.safetensors", + "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv-097-ic-lora-canny-control-diffusers.safetensors" + ], + "LTXV_config": "models/ltx_video/configs/ltxv-13b-0.9.8-dev.yaml" }, "num_inference_steps": 30 } diff --git a/defaults/ltxv_distilled.json b/defaults/ltxv_distilled.json index 256ea81..c570057 100644 --- a/defaults/ltxv_distilled.json +++ b/defaults/ltxv_distilled.json @@ -1,14 +1,15 @@ { "model": { - "name": "LTX Video 0.9.7 Distilled 13B", + "name": "LTX Video 0.9.8 Distilled 13B", "architecture" : "ltxv_13B", - "description": "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).This distilled version is a very fast version and retains a high level 
of quality. The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer.", - "URLs": "ltxv_13B", - "loras": ["https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"], - "loras_multipliers": [ 1 ], - "lock_inference_steps": true, - "LTXV_config": "ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml" + "description": "LTX Video is a fast model that can be used to generate very long videos (up to 1800 frames !).This distilled version is a very fast version and retains a high level of quality. The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.8_13B_distilled_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.8_13B_distilled_quanto_bf16_int8.safetensors" + ], + "preload_URLs" : "ltxv_13B", + "LTXV_config": "models/ltx_video/configs/ltxv-13b-0.9.8-distilled.yaml" }, "num_inference_steps": 6 } diff --git a/defaults/multitalk.json b/defaults/multitalk.json index 9c389d5..1133657 100644 --- a/defaults/multitalk.json +++ b/defaults/multitalk.json @@ -3,7 +3,11 @@ { "name": "Multitalk 480p", "architecture" : "multitalk", - "modules": ["multitalk"], + "modules": [ + ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_multitalk_14B_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_multitalk_14B_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_multitalk_14B_quanto_mfp16_int8.safetensors"] + ], "description": "The Multitalk model corresponds to the original Wan image 2 video model combined with the Multitalk module. It lets you have up to two people have a conversation.", "URLs": "i2v", "teacache_coefficients" : [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01] diff --git a/defaults/qwen_image_20B.json b/defaults/qwen_image_20B.json new file mode 100644 index 0000000..27bee20 --- /dev/null +++ b/defaults/qwen_image_20B.json @@ -0,0 +1,21 @@ +{ + "model": { + "name": "Qwen Image 20B", + "architecture": "qwen_image_20B", + "description": "Qwen Image is generative model that will generate very high quality images. It is one of the few models capable to generate in the image very long texts.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_20B_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_20B_quanto_bf16_int8.safetensors" + ], + "xresolutions": [ ["1328x1328 (1:1)", "1328x1328"], + ["1664x928 (16:9)", "1664x928"], + ["928x1664 (9:16)", "928x1664"], + ["1472x1140 (4:3)", "1472x1140"], + ["1140x1472 (3:4)", "1140x1472"]], + "attention": {"<89" : "sdpa"}, + "image_outputs": true + }, + "prompt": "draw a hat", + "resolution": "1280x720", + "batch_size": 1 +} \ No newline at end of file diff --git a/defaults/qwen_image_edit_20B.json b/defaults/qwen_image_edit_20B.json new file mode 100644 index 0000000..2b24c72 --- /dev/null +++ b/defaults/qwen_image_edit_20B.json @@ -0,0 +1,19 @@ +{ + "model": { + "name": "Qwen Image Edit 20B", + "architecture": "qwen_image_edit_20B", + "description": "Qwen Image Edit is a generative model that can generate very high quality images with long texts in it. Best results will be at 720p. Use it to edit a Subject or combine multiple Subjects. 
", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_20B_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_20B_quanto_bf16_int8.safetensors" + ], + "attention": { + "<89": "sdpa" + }, + "reference_image": true, + "image_outputs": true + }, + "prompt": "add a hat", + "resolution": "1280x720", + "batch_size": 1 +} \ No newline at end of file diff --git a/defaults/standin.json b/defaults/standin.json new file mode 100644 index 0000000..1b5e324 --- /dev/null +++ b/defaults/standin.json @@ -0,0 +1,10 @@ +{ + "model": + { + "name": "Wan2.1 Standin 14B", + "modules": [ ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Stand-In_wan2.1_T2V_14B_ver1.0_bf16.safetensors"]], + "architecture" : "standin", + "description": "The original Wan Text 2 Video model combined with the StandIn module to improve Identity Preservation. You need to provide a Reference Image with white background which is a close up of person face to transfer this person in the Video.", + "URLs": "t2v" + } +} \ No newline at end of file diff --git a/defaults/t2i.json b/defaults/t2i.json deleted file mode 100644 index f49f426..0000000 --- a/defaults/t2i.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "model": { - "name": "Wan2.1 text2image 14B", - "architecture": "t2v", - "description": "The original Wan Text 2 Video model configured to generate an image instead of a video.", - "image_outputs": true, - "URLs": "t2v" - }, - "video_length": 1, - "resolution": "1280x720" -} - - \ No newline at end of file diff --git a/defaults/t2v.json b/defaults/t2v.json index 2ab946a..ef7f240 100644 --- a/defaults/t2v.json +++ b/defaults/t2v.json @@ -1,7 +1,7 @@ { "model": { - "name": "Wan2.1 text2video 14B", + "name": "Wan2.1 Text2video 14B", "architecture" : "t2v", "description": "The original Wan Text 2 Video model. Most other models have been built on top of it", "URLs": [ diff --git a/defaults/t2v_1.3B.json b/defaults/t2v_1.3B.json index 859304f..ca88bd9 100644 --- a/defaults/t2v_1.3B.json +++ b/defaults/t2v_1.3B.json @@ -1,11 +1,11 @@ { "model": { - "name": "Wan2.1 text2video 1.3B", + "name": "Wan2.1 Text2video 1.3B", "architecture" : "t2v_1.3B", "description": "The light version of the original Wan Text 2 Video model. 
Most other models have been built on top of it", "URLs": [ - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_text2video_1.3B_bf16.safetensors" + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_text2video_1.3B_mbf16.safetensors" ] } } \ No newline at end of file diff --git a/defaults/t2v_2_2.json b/defaults/t2v_2_2.json new file mode 100644 index 0000000..122eeb9 --- /dev/null +++ b/defaults/t2v_2_2.json @@ -0,0 +1,25 @@ +{ + "model": + { + "name": "Wan2.2 Text2video 14B", + "architecture" : "t2v", + "description": "Wan 2.2 Text 2 Video model", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_quanto_mfp16_int8.safetensors" + ], + "URLs2": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_quanto_mfp16_int8.safetensors" + ], + "group": "wan2_2" + }, + "guidance_phases": 2, + "switch_threshold" : 875, + "guidance_scale" : 4, + "guidance2_scale" : 3, + "flow_shift" : 12 + +} \ No newline at end of file diff --git a/defaults/t2v_fusionix.json b/defaults/t2v_fusionix.json index c27d1de..6ecdf0c 100644 --- a/defaults/t2v_fusionix.json +++ b/defaults/t2v_fusionix.json @@ -1,7 +1,7 @@ { "model": { - "name": "Wan2.1 text2video FusioniX 14B", + "name": "Wan2.1 Text2video FusioniX 14B", "architecture" : "t2v", "description": "A powerful merged text-to-video model based on the original WAN 2.1 T2V model, enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.", "URLs": [ diff --git a/defaults/t2v_sf.json b/defaults/t2v_sf.json index 2343237..2131413 100644 --- a/defaults/t2v_sf.json +++ b/defaults/t2v_sf.json @@ -1,6 +1,6 @@ { "model": { - "name": "Wan2.1 text2video Self-Forcing 14B", + "name": "Wan2.1 Text2video Self-Forcing 14B", "architecture": "t2v", "description": "This model is an advanced text-to-video generation model. 
This approach allows the model to generate videos with significantly fewer inference steps (4 or 8 steps) and without classifier-free guidance, substantially reducing video generation time while maintaining high quality outputs.", "URLs": [ diff --git a/defaults/ti2v_2_2.json b/defaults/ti2v_2_2.json new file mode 100644 index 0000000..ac329fa --- /dev/null +++ b/defaults/ti2v_2_2.json @@ -0,0 +1,17 @@ +{ + "model": { + "name": "Wan2.2 TextImage2video 5B", + "architecture": "ti2v_2_2", + "description": "Wan 2.2 Text 2 Video model 5B", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_5B_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_5B_quanto_mbf16_int8.safetensors" + ], + "group": "wan2_2" + }, + "video_length": 121, + "guidance_scale": 5, + "flow_shift": 5, + "num_inference_steps": 50, + "resolution": "1280x720" +} \ No newline at end of file diff --git a/defaults/ti2v_2_2_fastwan.json b/defaults/ti2v_2_2_fastwan.json new file mode 100644 index 0000000..064c2b4 --- /dev/null +++ b/defaults/ti2v_2_2_fastwan.json @@ -0,0 +1,15 @@ +{ + "model": { + "name": "Wan2.2 FastWan TextImage2video 5B", + "architecture": "ti2v_2_2", + "description": "FastWan2.2-TI2V-5B-Full-Diffusers is built upon Wan-AI/Wan2.2-TI2V-5B-Diffusers. It supports efficient 3-step inference and produces high-quality videos at 121×704×1280 resolution", + "URLs": "ti2v_2_2", + "loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"], + "group": "wan2_2" + }, + "video_length": 121, + "guidance_scale": 1, + "flow_shift": 3, + "num_inference_steps": 3, + "resolution": "1280x720" +} \ No newline at end of file diff --git a/defaults/vace_1.3B.json b/defaults/vace_1.3B.json index 716fbd0..406b9a7 100644 --- a/defaults/vace_1.3B.json +++ b/defaults/vace_1.3B.json @@ -3,9 +3,10 @@ { "name": "Vace ControlNet 1.3B", "architecture" : "vace_1.3B", - "description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video.", - "URLs": [ - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Vace_1.3B_mbf16.safetensors" - ] + "modules": [ + ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Vace_1_3B_module.safetensors"] + ], + "description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video.", + "URLs": "t2v_1.3B" } } \ No newline at end of file diff --git a/defaults/vace_14B.json b/defaults/vace_14B.json index 139bad4..b639664 100644 --- a/defaults/vace_14B.json +++ b/defaults/vace_14B.json @@ -3,7 +3,9 @@ "name": "Vace ControlNet 14B", "architecture": "vace_14B", "modules": [ - "vace_14B" + ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Vace_14B_module_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Vace_14B_module_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Vace_14B_module_quanto_mfp16_int8.safetensors"] ], "description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in 
the video.", "URLs": "t2v" diff --git a/defaults/vace_14B_cocktail.json b/defaults/vace_14B_cocktail.json new file mode 100644 index 0000000..87f2b78 --- /dev/null +++ b/defaults/vace_14B_cocktail.json @@ -0,0 +1,21 @@ +{ + "model": { + "name": "Vace Cocktail 14B", + "architecture": "vace_14B", + "modules": [ + "vace_14B" + ], + "description": "This model has been created on the fly using the Wan text 2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. Copy the model def in the finetune folder to change the Cocktail composition.", + "URLs": "t2v", + "loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors" + ], + "loras_multipliers": [1, 0.5, 0.5, 0.5] + }, + "num_inference_steps": 10, + "guidance_scale": 1, + "flow_shift": 2 +} \ No newline at end of file diff --git a/defaults/vace_14B_cocktail_2_2.json b/defaults/vace_14B_cocktail_2_2.json new file mode 100644 index 0000000..f821a9e --- /dev/null +++ b/defaults/vace_14B_cocktail_2_2.json @@ -0,0 +1,26 @@ +{ + "model": { + "name": "Wan2.2 Vace Experimental Cocktail 14B", + "architecture": "vace_14B", + "modules": [ + "vace_14B" + ], + "description": "This model has been created on the fly using the Wan text 2.2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. 
There is so far only PARTIAL support of Vace 2.1 which is currently used.", + "URLs": "t2v_2_2", + "URLs2": "t2v_2_2", + "loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors" + ], + "loras_multipliers": [1, 0.2, 0.5, 0.5], + "group": "wan2_2" + }, + "guidance_phases": 2, + "num_inference_steps": 10, + "guidance_scale": 1, + "guidance2_scale": 1, + "flow_shift": 2, + "switch_threshold" : 875 +} \ No newline at end of file diff --git a/defaults/vace_14B_fusionix_t2i.json b/defaults/vace_14B_fusionix_t2i.json deleted file mode 100644 index 75fbf42..0000000 --- a/defaults/vace_14B_fusionix_t2i.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "model": { - "name": "Vace FusioniX image2image 14B", - "architecture": "vace_14B", - "modules": [ - "vace_14B" - ], - "image_outputs": true, - "description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.", - "URLs": "t2v_fusionix" - }, - "resolution": "1280x720", - "guidance_scale": 1, - "num_inference_steps": 10, - "video_length": 1 -} \ No newline at end of file diff --git a/defaults/vace_14B_lightning_3p_2_2.json b/defaults/vace_14B_lightning_3p_2_2.json new file mode 100644 index 0000000..fca6cb8 --- /dev/null +++ b/defaults/vace_14B_lightning_3p_2_2.json @@ -0,0 +1,29 @@ +{ + "model": { + "name": "Wan2.2 Vace Lightning 3 Phases 14B", + "architecture": "vace_14B", + "modules": [ + "vace_14B" + ], + "description": "This finetune uses the Lightning 4 steps Loras Accelerator for Wan 2.2 but extend them to 8 steps in order to insert a CFG phase before the 2 accelerated phases with no Guidance. 
The ultimate goal is to reduce the slow motion effect of these Loras Accelerators.", + "URLs": "t2v_2_2", + "URLs2": "t2v_2_2", + "loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_HIGH_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_LOW_fp16.safetensors" + ], + "loras_multipliers": ["0;1;0", "0;0;1"], + "lock_guidance_phases": true, + "group": "wan2_2" + }, + "num_inference_steps": 8, + "guidance_phases": 3, + "guidance_scale": 3.5, + "guidance2_scale": 1, + "guidance3_scale": 1, + "switch_threshold": 965, + "switch_threshold2": 800, + "model_switch_phase": 2, + "flow_shift": 3, + "sample_solver": "euler" +} \ No newline at end of file diff --git a/defaults/vace_standin_14B.json b/defaults/vace_standin_14B.json new file mode 100644 index 0000000..6fc8a68 --- /dev/null +++ b/defaults/vace_standin_14B.json @@ -0,0 +1,9 @@ +{ + "model": { + "name": "Vace Standin 14B", + "architecture": "vace_standin_14B", + "modules": [ "vace_14B", "standin"], + "description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based on additional custom data : pose or depth video, images or objects you want to see in the video.", + "URLs": "t2v" + } +} \ No newline at end of file diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index c093bd7..5a89d93 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,16 +1,106 @@ # Changelog ## 🔥 Latest News +### July 21 2025: WanGP v7.1 +- Flux Family Reunion : *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate that, Loras support for the Flux *diffusers* format has also been added. + +- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video!) in one go without a sliding window. With the distilled model it will take only 5 minutes with an RTX 4090 (you will need 22 GB of VRAM though). I have added options to select a higher number of frames if you want to experiment + +- LTX Video ControlNet : it is a Control Net that allows you for instance to transfer Human motion or Depth from a control video. It is not as powerful as Vace but can produce interesting things, especially as you can now generate a 1 min video quickly. Under the hood, IC-Loras (see below) for Pose, Depth and Canny are automatically loaded for you, no need to add them. + +- LTX IC-Lora support: these are special Loras that consume a conditional image or video +Besides the pose, depth and canny IC-Loras that are transparently loaded, there is the *detailer* (https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8) which is basically an upsampler. Add the *detailer* as a Lora and use LTX Raw Format as control net choice to use it. + +And also: +- easier way to select video resolution +- started to optimize Matanyone to reduce VRAM requirements + + +### July 15 2025: WanGP v7.0 is an AI Powered Photoshop +This release turns the Wan models into Image Generators. 
This goes way beyond just allowing you to generate a video made of a single frame: +- Multiple Images generated at the same time so that you can choose the one you like best. It is highly VRAM optimized so that you can generate for instance 4 720p Images at the same time with less than 10 GB +- With *image2image* the original text2video WanGP becomes an image upsampler / restorer +- *Vace image2image* comes out of the box with image outpainting, person / object replacement, ... +- You can use in one click a newly generated Image as a Start Image or Reference Image for a Video generation + +And to complete the full suite of AI Image Generators, Ladies and Gentlemen please welcome for the first time in WanGP : **Flux Kontext**.\ +As a reminder Flux Kontext is an image editor : give it an image and a prompt and it will make the change for you.\ +This highly optimized version of Flux Kontext will make you feel that you have been cheated all this time as WanGP Flux Kontext requires only 8 GB of VRAM to generate 4 images at the same time with no need for quantization. + +WanGP v7 comes with *Image2image* vanilla and *Vace FusioniX*. However you can build your own finetune where you will combine a text2video or Vace model with any combination of Loras. + +Also in the news: +- You can now enter the *Bbox* for each speaker in *Multitalk* to precisely locate who is speaking. And to save some headaches the *Image Mask generator* will give you the *Bbox* coordinates of an area you have selected. +- *Film Grain* post processing to add a vintage look to your video +- The *First Last Frame to Video* model should work much better now as I recently discovered its implementation was not complete +- More power for the finetuners: you can now embed Loras directly in the finetune definition. You can also override the default models (titles, visibility, ...) with your own finetunes. Check the doc that has been updated. + + +### July 10 2025: WanGP v6.7, is NAG a game changer ? you tell me +Maybe you knew that already but most *Loras accelerators* we use today (Causvid, FusioniX) don't use *Guidance* at all (that is, *CFG* is set to 1). This helps to get much faster generations but the downside is that *Negative Prompts* are completely ignored (including the default ones set by the models). **NAG** (https://github.com/ChenDarYen/Normalized-Attention-Guidance) aims to solve that by injecting the *Negative Prompt* during the *attention* processing phase. + +So WanGP 6.7 gives you NAG, but not just any NAG, a *Low VRAM* implementation, since the default one ends up being VRAM greedy. You will find NAG in the *General* advanced tab for most Wan models. + +Use NAG especially when Guidance is set to 1. To turn it on set the **NAG scale** to something around 10. There are other NAG parameters, **NAG tau** and **NAG alpha**, which I recommend changing only if you don't get good results by just playing with the NAG scale. Don't hesitate to share on this discord server the best combinations for these 3 parameters. + +The authors of NAG claim that NAG can also be used with Guidance (CFG > 1) to improve prompt adherence. + +### July 8 2025: WanGP v6.6, WanGP offers you **Vace Multitalk Dual Voices Fusionix Infinite** : +**Vace**, our beloved super Control Net, has been combined with **Multitalk**, the new king in town that can animate up to two people speaking (**Dual Voices**). 
It is accelerated by the **Fusionix** model and thanks to *Sliding Windows* support and *Adaptive Projected Guidance* (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for a very long time (which is an **Infinite** amount of time in the field of video generation). + +Of course you will get as well *Multitalk* vanilla and also *Multitalk 720p* as a bonus. + +And since I am mister nice guy I have enclosed as an exclusivity an *Audio Separator* that will save you time isolating each voice when using Multitalk with two people. + +As I feel like resting a bit I haven't produced yet a nice sample Video to illustrate all these new capabilities. But here is the thing, I am sure you will publish your *Master Pieces* in the *Share Your Best Video* channel. The best ones will be added to the *Announcements Channel* and will bring eternal fame to their authors. + +But wait, there is more: +- Sliding Windows support has been added anywhere with Wan models, so imagine with text2video recently upgraded in 6.5 into a video2video, you can now upsample very long videos regardless of your VRAM. The good old image2video model can now reuse the last image to produce new videos (as requested by many of you) +- I have also added the capability to transfer the audio of the original control video (Misc. advanced tab) and an option to preserve the fps in the generated video, so from now on you will be able to upsample / restore your old family videos and keep the audio at its original pace. Be aware that the duration will be limited to 1000 frames as I still need to add streaming support for unlimited video sizes. + +Also of interest: +- Extract video info from Videos that have not been generated by WanGP, even better you can also apply post processing (Upsampling / MMAudio) on non WanGP videos +- Force the generated video fps to your liking, works very well with Vace when using a Control Video +- Ability to chain URLs of Finetune models (for instance put the URLs of a model in your main finetune and reference this finetune in other finetune models to save time) + +### July 2 2025: WanGP v6.5.1, WanGP takes care of you: lots of quality of life features: +- View directly inside WanGP the properties (seed, resolutions, length, most settings...) of the past generations +- In one click use the newly generated video as a Control Video or Source Video to be continued +- Manage multiple settings for the same model and switch between them using a dropdown box +- WanGP will keep the last generated videos in the Gallery and will remember the last model you used if you restart the app but kept the Web page open +- Custom resolutions : add a file in the WanGP folder with the list of resolutions you want to see in WanGP (look at the instruction readme in this folder) + +Taking care of your life is not enough, you want new stuff to play with ? +- MMAudio directly inside WanGP : add an audio soundtrack that matches the content of your video. By the way it is a low VRAM MMAudio and 6 GB of VRAM should be sufficient. You will need to go to the *Extensions* tab of the WanGP *Configuration* to enable MMAudio +- Forgot to upsample your video during the generation ? Want to try another MMAudio variation ? Fear not, you can also apply upsampling or add an MMAudio track once the video generation is done. 
Even better, you can ask WanGP for multiple variations of MMAudio to pick the one you like best +- MagCache support: a new step skipping approach, supposed to be better than TeaCache. Makes a difference if you usually generate with a high number of steps +- SageAttention2++ support : not just the compatibility but also a slightly reduced VRAM usage +- Video2Video in Wan Text2Video : this is the paradox, a text2video can become a video2video if you start the denoising process later on an existing video +- FusioniX upsampler: this is an illustration of Video2Video in Text2Video. Use the FusioniX text2video model with an output resolution of 1080p and a denoising strength of 0.25 and you will get one of the best upsamplers (in only 2/3 steps, you will need lots of VRAM though). Increase the denoising strength and you will get one of the best Video Restorers +- Choice of Wan Samplers / Schedulers +- More Lora formats support + +**If you had upgraded to v6.5 please upgrade again to 6.5.1 as this will fix a bug that ignored Loras beyond the first one** + +### June 23 2025: WanGP v6.3, Vace Unleashed. Thought we couldn't squeeze Vace even more ? +- Multithreaded preprocessing when possible for faster generations +- Multithreaded frames Lanczos Upsampling as a bonus +- A new Vace preprocessor : *Flow* to extract fluid motion +- Multi Vace Controlnets: you can now transfer several properties at the same time. This opens new possibilities to explore, for instance if you transfer *Human Movement* and *Shapes* at the same time, for some reason the lighting of your character will take the environment of your character much more into account. +- Injected Frames Outpainting, in case you missed it in WanGP 6.21 + +Don't know how to use all of the Vace features ? Check the Vace Guide embedded in WanGP as it has also been updated. + + ### June 19 2025: WanGP v6.2, Vace even more Powercharged -Have I told you that I am a big fan of Vace ? Here are more goodies to unleash its power: +👋 Have I told you that I am a big fan of Vace ? Here are more goodies to unleash its power: - If you ever wanted to watch Star Wars in 4:3, just use the new *Outpainting* feature and it will add the missing bits of image at the top and the bottom of the screen. The best thing is *Outpainting* can be combined with all the other Vace modifications, for instance you can change the main character of your favorite movie at the same time - More processing can be combined at the same time (for instance the depth process can be applied outside the mask) - Upgraded the depth extractor to Depth Anything 2 which is much more detailed As a bonus, I have added two finetunes based on the Safe-Forcing technology (which requires only 4 steps to generate a video): Wan 2.1 text2video Self-Forcing and Vace Self-Forcing. I know there is Lora around but the quality of the Lora is worse (at least with Vace) compared to the full model. Don't hesitate to share your opinion about this on the discord server. 
- ### June 17 2025: WanGP v6.1, Vace Powercharged -Lots of improvements for Vace the Mother of all Models: +👋 Lots of improvements for Vace the Mother of all Models: - masks can now be combined with on the fly processing of a control video, for instance you can extract the motion of a specific person defined by a mask - on the fly modification of masks : reversed masks (with the same mask you can modify the background instead of the people covered by the masks), enlarged masks (you can cover more area if for instance the person you are trying to inject is larger than the one in the mask), ... - view these modified masks directly inside WanGP during the video generation to check they are really as expected @@ -37,22 +127,6 @@ You get **Vace FusioniX**: the Ultimate Vace Model, Fast (10 steps, no need for Check the *Finetune Guide* to create finetune models definitions and share them on the WanGP discord server. -### June 12 2025: WanGP v5.6 -👋 *Finetune models*: You find the 20 models supported by WanGP not sufficient ? Too impatient to wait for the next release to get the support for a newly released model ? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add by yourself the support for this model in WanGP by just creating a finetune model definition. You can then store this model in the cloud (for instance in Huggingface) and the very light finetune definition file can be easily shared with other users. WanGP will download automatically the finetuned model for them. - -To celebrate the new finetunes support, here are a few finetune gifts (directly accessible from the model selection menu): -- *Fast Hunyuan Video* : generate model t2v in only 6 steps -- *Hunyuan Vido AccVideo* : generate model t2v in only 5 steps -- *Wan FusioniX*: it is a combo of AccVideo / CausVid ans other models and can generate high quality Wan videos in only 8 steps - -One more thing... - -The new finetune system can be used to combine complementaty models : what happens when you combine Fusionix Text2Video and Vace Control Net ? - -You get **Vace FusioniX**: the Ultimate Vace Model, Fast (10 steps, no need for guidance) and with a much better quality Video than the original slower model (despite being the best Control Net out there). Here goes one more finetune... - -Check the *Finetune Guide* to create finetune models definitions and share them on the WanGP discord server. - ### June 11 2025: WanGP v5.5 👋 *Hunyuan Video Custom Audio*: it is similar to Hunyuan Video Avatar excpet there isn't any lower limit on the number of frames and you can use your reference images in a different context than the image itself\ *Hunyuan Video Custom Edit*: Hunyuan Video Controlnet, use it to do inpainting and replace a person in a video while still keeping his poses. Similar to Vace but less restricted than the Wan models in terms of content... diff --git a/docs/FINETUNES.md b/docs/FINETUNES.md index 1c9ee6b..7f9dc3f 100644 --- a/docs/FINETUNES.md +++ b/docs/FINETUNES.md @@ -55,16 +55,39 @@ For instance if one adds a module *vace_14B* on top of a model with architecture - *architecture* : architecture Id of the base model of the finetune (see previous section) - *description*: description of the finetune that will appear at the top - *URLs*: URLs of all the finetune versions (quantized / non quantized). WanGP will pick the version that is the closest to the user preferences. 
You will need to follow a naming convention to help WanGP identify the content of each version (see next section). Right now WanGP supports only 8 bits quantized model that have been quantized using **quanto**. WanGP offers a command switch to build easily such a quantized model (see below). *URLs* can contain also paths to local file to allow testing. +- *URLs2*: URLs of all the finetune versions (quantized / non quantized) of the weights used for the second phase of a model. For instance with Wan 2.2, the first phase contains the High Noise model weights and the second phase contains the Low Noise model weights. This feature can be used with other models than Wan 2.2 to combine different model weights during the same video generation. - *modules*: this a list of modules to be combined with the models referenced by the URLs. A module is a model extension that is merged with a model to expand its capabilities. Supported models so far are : *vace_14B* and *multitalk*. For instance the full Vace model is the fusion of a Wan text 2 video and the Vace module. - *preload_URLs* : URLs of files to download no matter what (used to load quantization maps for instance) --*loras* : URLs of Loras that will applied before any other Lora specified by the user. These loras will be quite often Loras accelerator. For instance if you specified here the FusioniX Lora you will be able to reduce the number of generation steps to -*loras_multipliers* : a list of float numbers that defines the weight of each Lora mentioned above. +-*loras* : URLs of Loras that will applied before any other Lora specified by the user. These loras will be quite often Loras accelerators. For instance if you specify here the FusioniX Lora you will be able to reduce the number of generation steps to 10 +-*loras_multipliers* : a list of float numbers or strings that defines the weight of each Lora mentioned in *Loras*. The string syntax is used if you want your lora multiplier to change over the steps (please check the Loras doc) or if you want a multiplier to be applied on a specific High Noise phase or Low Noise phase of a Wan 2.2 model. For instance, here the multiplier will be only applied during the High Noise phase and for half of the steps of this phase the multiplier will be 1 and for the other half 1.1. +``` +"loras" : [ "my_lora.safetensors"], +"loras_multipliers" : [ "1,1.1;0"] +``` + - *auto_quantize*: if set to True and no quantized model URL is provided, WanGP will perform on the fly quantization if the user expects a quantized model -*visible* : by default assumed to be true. If set to false the model will no longer be visible. This can be useful if you create a finetune to override a default model and hide it. -*image_outputs* : turn any model that generates a video into a model that generates images. In fact it will adapt the user interface for image generation and ask the model to generate a video with a single frame. -In order to favor reusability the properties of *URLs*, *modules*, *loras* and *preload_URLs* can contain instead of a list of URLs a single text which corresponds to the id of a finetune or default model to reuse. +In order to favor reusability the properties of *URLs*, *modules*, *loras* and *preload_URLs* can contain instead of a list of URLs a single text which corresponds to the id of a finetune or default model to reuse. 
Instead of: +``` + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_quanto_mfp16_int8.safetensors" + ], + "URLs2": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_quanto_mfp16_int8.safetensors" + ], +``` + You can write: +``` + "URLs": "t2v_2_2", + "URLs2": "t2v_2_2", +``` -For example let’s say you have defined a *t2v_fusionix.json* file which contains the URLs to download the finetune. In the *vace_fusionix.json* you can write « URLs » : « fusionix » to reuse automatically the URLS already defined in the correspond file. Example of **model** subtree ``` diff --git a/docs/INSTALLATION.md b/docs/INSTALLATION.md index fa4c3a6..9f66422 100644 --- a/docs/INSTALLATION.md +++ b/docs/INSTALLATION.md @@ -8,9 +8,9 @@ This guide covers installation for different GPU generations and operating syste - Conda or Python venv - Compatible GPU (RTX 10XX or newer recommended) -## Installation for RTX 10XX to RTX 40XX (Stable) +## Installation for RTX 10XX to RTX 50XX (Stable) -This installation uses PyTorch 2.6.0 which is well-tested and stable. +This installation uses PyTorch 2.7.0 which is well-tested and stable. ### Step 1: Download and Setup Environment @@ -27,8 +27,8 @@ conda activate wan2gp ### Step 2: Install PyTorch ```shell -# Install PyTorch 2.6.0 with CUDA 12.4 -pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124 +# Install PyTorch 2.7.0 with CUDA 12.4 +pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128 ``` ### Step 3: Install Dependencies @@ -40,7 +40,7 @@ pip install -r requirements.txt ### Step 4: Optional Performance Optimizations -#### Sage Attention (30% faster) +#### Sage Attention (30% faster), don't install with RTX 50xx as it is not compatible ```shell # Windows only: Install Triton @@ -58,6 +58,7 @@ pip install triton-windows pip install https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu126torch2.6.0-cp310-cp310-win_amd64.whl # Linux (manual compilation required) +python -m pip install "setuptools<=75.8.2" --force-reinstall git clone https://github.com/thu-ml/SageAttention cd SageAttention pip install -e . @@ -70,61 +71,7 @@ pip install -e . pip install flash-attn==2.7.2.post1 ``` -## Installation for RTX 50XX (Beta) - -RTX 50XX GPUs require PyTorch 2.7.0 (beta). This version may be less stable. - -⚠️ **Important:** Use Python 3.10 for compatibility with pip wheels. 
- -### Step 1: Setup Environment - -```shell -# Clone and setup (same as above) -git clone https://github.com/deepbeepmeep/Wan2GP.git -cd Wan2GP -conda create -n wan2gp python=3.10.9 -conda activate wan2gp -``` - -### Step 2: Install PyTorch Beta - -```shell -# Install PyTorch 2.7.0 with CUDA 12.8 -pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128 -``` - -### Step 3: Install Dependencies - -```shell -pip install -r requirements.txt -``` - -### Step 4: Optional Optimizations for RTX 50XX - -#### Sage Attention - -```shell -# Windows -pip install triton-windows -pip install sageattention==1.0.6 - -# Linux -pip install sageattention==1.0.6 -``` - -#### Sage 2 Attention - -```shell -# Windows -pip install triton-windows -pip install https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu128torch2.7.0-cp310-cp310-win_amd64.whl - -# Linux (manual compilation) -git clone https://github.com/thu-ml/SageAttention -cd SageAttention -pip install -e . -``` - + ## Attention Modes WanGP supports several attention implementations: @@ -134,6 +81,12 @@ WanGP supports several attention implementations: - **Sage2**: 40% speed boost - **Flash**: Good performance, may be complex to install on Windows +### Attention GPU Compatibility + +- RTX 10XX, 20XX: SDPA +- RTX 30XX, 40XX: SDPA, Flash Attention, Xformers, Sage, Sage2 +- RTX 50XX: SDPA, Flash Attention, Xformers, Sage2 + ## Performance Profiles Choose a profile based on your hardware: @@ -161,10 +114,5 @@ If Sage attention doesn't work: - Use Profile 4 for lower VRAM usage - Consider using 1.3B models instead of 14B models -### GPU Compatibility -- RTX 10XX, 20XX: Supported with SDPA attention -- RTX 30XX, 40XX: Full feature support -- RTX 50XX: Beta support with PyTorch 2.7.0 - -For more troubleshooting, see [TROUBLESHOOTING.md](TROUBLESHOOTING.md) \ No newline at end of file +For more troubleshooting, see [TROUBLESHOOTING.md](TROUBLESHOOTING.md) diff --git a/docs/LORAS.md b/docs/LORAS.md index 0b2d034..89b2e59 100644 --- a/docs/LORAS.md +++ b/docs/LORAS.md @@ -7,18 +7,21 @@ Loras (Low-Rank Adaptations) allow you to customize video generation models by a Loras are organized in different folders based on the model they're designed for: ### Wan Text-to-Video Models -- `loras/` - General t2v loras +- `loras/` - General t2v loras for Wan 2.1 (t2v only) and for all Wan 2.2 models +Optional subfolders: - `loras/1.3B/` - Loras specifically for 1.3B models +- `loras/5B/` - Loras specifically for 5B models - `loras/14B/` - Loras specifically for 14B models ### Wan Image-to-Video Models -- `loras_i2v/` - Image-to-video loras +- `loras_i2v/` - Image-to-video loras for Wan 2.1 ### Other Models - `loras_hunyuan/` - Hunyuan Video t2v loras - `loras_hunyuan_i2v/` - Hunyuan Video i2v loras - `loras_ltxv/` - LTX Video loras - `loras_flux/` - Flux loras +- `loras_qwen/` - Qwen loras ## Custom Lora Directory @@ -40,7 +43,9 @@ python wgp.py --lora-dir-hunyuan /path/to/hunyuan/loras --lora-dir-ltxv /path/to 2. Launch WanGP 3. In the Advanced Tab, select the "Loras" section 4. Check the loras you want to activate -5. Set multipliers for each lora (default is 1.0) +5. Set multipliers for each lora (default is 1.0 if no multiplier is mentioned) + +If you store loras in the loras folder once WanGP has been launched, click the *Refresh* button at the top so that they become selectable. 
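+
+For reference, here is a minimal, hypothetical command line combining a custom lora folder with a preset loaded at startup. The folder path and preset name are placeholders; the `--lora-dir` and `--lora-preset` switches are the lora-related command line options listed at the end of this guide.
+
+```bash
+# Hypothetical example: point WanGP at a custom Wan t2v lora folder
+# and load a previously saved lora preset when the app starts.
+python wgp.py --lora-dir /path/to/my/wan_loras --lora-preset my_preset
+```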
### Lora Multipliers @@ -53,7 +58,7 @@ Multipliers control the strength of each lora's effect: - First lora: 1.2 strength - Second lora: 0.8 strength -#### Time-based Multipliers +#### Time-based and Phase-based Multipliers For dynamic effects over generation steps, use comma-separated values: ``` 0.9,0.8,0.7 @@ -63,6 +68,55 @@ For dynamic effects over generation steps, use comma-separated values: - First lora: 0.9 → 0.8 → 0.7 - Second lora: 1.2 → 1.1 → 1.0 +With models like Wan 2.2 that use two diffusion models internally (*High Noise* / *Low Noise*) you can specify which Loras you want to be applied for a specific phase by separating each phase with a ";". + +For instance, if you want to disable a lora for the *High Noise* phase and enable it only for the *Low Noise* phase: +``` +0;1 +``` + +Also with Wan 2.2, if you have two loras and you want the first one to be applied only during the High Noise phase and the second one during the Low Noise phase: +``` +1;0 0;1 +``` + +As usual, you can use any float for a multiplier and have a multiplier vary throughout one phase for one Lora: +``` +0.9,0.8;1.2,1.1,1 +``` +In this example multipliers 0.9 and 0.8 will be used during the *High Noise* phase and 1.2, 1.1 and 1 during the *Low Noise* phase. + +Here is another example for two loras: +``` +0.9,0.8;1.2,1.1,1 +0.5;0,0.7 +``` + +If one or several of your Lora multipliers are phase-based (that is, with a ";") and there are also Lora multipliers that are only time-based (no ";" but a ","), the time-only multipliers will ignore the phases. For instance, let's assume we have a 6-step denoising process in the following example: + +``` +1;0 +0;1 +0.8,0.7,0.5 +``` +Here the first lora will, as expected, only be used with the High Noise model and the second lora only with the Low Noise model. However for the third Lora: for steps 1-2 the multiplier will be (regardless of the phase) 0.8, then for steps 3-4 the multiplier will be 0.7 and finally for steps 5-6 the multiplier will be 0.5 + +You can use phased Lora multipliers even if you have a single model (that is, without any High / Low models) as Lora multiplier phases are aligned with Guidance phases. Let's assume you have defined 3 guidance phases (for instance guidance=3, then guidance=1.5 and at last guidance=1): +``` +0;1;0 +0;0;1 +``` +In that case no lora will be applied during the first phase when guidance is 3. Then the first lora will only be used when guidance is 1.5 and the second lora only when guidance is 1. + +Best of all you can combine 3 guidance phases with High / Low models. Let's take this practical example with *Lightning 4/8 steps loras accelerators for Wan 2.2* where we want to increase the motion by adding some guidance at the very beginning (in that case a first phase that lasts only 1 step should be sufficient): +``` +Guidances: 3.5, 1 and 1 +Model transition: Phase 2-3 +Loras Multipliers: 0;1;0 0;0;1 +``` +Here during the first phase with guidance 3.5, the High model will be used but there won't be any lora at all. Then during phase 2 only the High lora will be used (which requires setting the guidance to 1). At last in phase 3 WanGP will switch to the Low model and then only the Low lora will be used. + +*Note that the syntax for multipliers can also be used in a Finetune model definition file (except that each multiplier definition is a string in a json list); a minimal sketch is shown below.* ## Lora Presets Lora Presets are combinations of loras with predefined multipliers and prompts. 
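+
+To illustrate the note above about reusing this multiplier syntax in a finetune model definition, here is a minimal, hypothetical sketch of a Wan 2.2 finetune with two loras, the first applied only during the High Noise phase and the second only during the Low Noise phase. The model name and lora file names are placeholders; the properties mirror the ones described in the Finetune guide (*URLs*, *URLs2*, *loras*, *loras_multipliers*).
+
+```
+{
+  "model": {
+    "name": "My Wan 2.2 finetune (example)",
+    "architecture": "t2v",
+    "URLs": "t2v_2_2",
+    "URLs2": "t2v_2_2",
+    "loras": ["high_noise_lora.safetensors", "low_noise_lora.safetensors"],
+    "loras_multipliers": ["1;0", "0;1"]
+  }
+}
+```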
@@ -100,15 +154,22 @@ WanGP supports multiple lora formats: ## Loras Accelerators Most Loras are used to apply a specific style or to alter the content of the output of the generated video. However some Loras have been designed to transform a model into a distilled model which requires fewer steps to generate a video. +Loras accelerators usually require setting the Guidance to 1. Don't forget to do it: not only will the quality of the generated video be bad, but generation will also be two times slower. -You will find most *Loras Accelerators* here: +You will find most *Loras Accelerators* below: +- Wan 2.1 https://huggingface.co/DeepBeepMeep/Wan2.1/tree/main/loras_accelerators +- Wan 2.2 +https://huggingface.co/DeepBeepMeep/Wan2.2/tree/main/loras_accelerators +- Qwen: +https://huggingface.co/DeepBeepMeep/Qwen_image/tree/main/loras_accelerators + ### Setup Instructions 1. Download the Lora 2. Place it in your `loras/` directory if it is a t2v lora or in the `loras_i2v/` directory if it is an i2v lora -## FusioniX (or FusionX) Lora +## FusioniX (or FusionX) Lora for Wan 2.1 / Wan 2.2 If you need just one Lora accelerator use this one. It is a combination of multiple Loras accelerators (including Causvid below) and style loras. It will not only accelerate the video generation but it will also improve the quality. There are two versions of this lora depending on whether you use it for t2v or i2v ### Usage @@ -123,8 +184,8 @@ If you need just one Lora accelerator use this one. It is a combination of multi 5. Set generation steps from 8-10 6. Generate! -## Safe-Forcing lightx2v Lora (Video Generation Accelerator) -Safeforcing Lora has been created by Kijai from the Safe-Forcing lightx2v distilled Wan model and can generate videos with only 2 steps and offers also a 2x speed improvement since it doesnt require classifier free guidance. It works on both t2v and i2v models +## Self-Forcing lightx2v Lora (Video Generation Accelerator) for Wan 2.1 / Wan 2.2 +Self-Forcing Lora has been created by Kijai from the Self-Forcing lightx2v distilled Wan model and can generate videos with only 2 steps; it also offers a 2x speed improvement since it doesn't require classifier-free guidance. It works on both t2v and i2v models You will find it under the name of *Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors* ### Usage @@ -140,7 +201,7 @@ You will find it under the name of *Wan21_T2V_14B_lightx2v_cfg_step_distill_lora 6. Generate! -## CausVid Lora (Video Generation Accelerator) +## CausVid Lora (Video Generation Accelerator) for Wan 2.1 / Wan 2.2 CausVid is a distilled Wan model that generates videos in 4-12 steps with 2x speed improvement. ### Usage @@ -163,11 +224,10 @@ CausVid is a distilled Wan model that generates videos in 4-12 steps with 2x spe *Note: Lower steps = lower quality (especially motion)* -## AccVid Lora (Video Generation Accelerator) +## AccVid Lora (Video Generation Accelerator) for Wan 2.1 / Wan 2.2 AccVid is a distilled Wan model that generates videos with a 2x speed improvement since classifier free guidance is no longer needed (that is cfg = 1). - ### Usage 1. Select a Wan t2v model (e.g., Wan 2.1 text2video 13B or Vace 13B) or Wan i2v model 2. Enable Advanced Mode 3. In Advanced Generation Tab: - Set Guidance Scale = 1 - Set Shift Scale = 5 4. 
The number steps remain unchanged compared to what you would use with the original model but it will be two times faster since classifier free guidance is not needed +## Lightx2v 4 steps Lora (Video Generation Accelerator) for Wan 2.2 +This lora is in fact composed of two loras, one for the High model and one for the Low Wan 2.2 model. + +You need to select these two loras and set the following Loras multipliers: + +``` +1;0 0;1 (the High lora should be only enabled when only the High model is loaded, same for the Low lora) +``` + +Don't forget to set guidance to 1 ! +## Qwen Image Lightning 4 steps / Lightning 8 steps +Very powerful lora that you can use to reduce the number of steps from 30 to only 4 ! +Just install the lora in *lora_qwen* folder, select the lora and set Guidance to 1 and the number of steps to 4 or 8 + + https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors @@ -190,6 +265,7 @@ https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_T2V_14B_lightx2v_cfg - Loras are loaded on-demand to save VRAM - Multiple loras can be used simultaneously - Time-based multipliers don't use extra memory +- The order of Loras doesn't matter (as long as the loras multipliers are in the right order of course !) ## Finding Loras @@ -241,6 +317,7 @@ In the video, a man is presented. The man is in a city and looks at his watch. ## Troubleshooting ### Lora Not Working +0. If it is a lora accelerator, Guidance should be set to 1 1. Check if lora is compatible with your model size (1.3B vs 14B) 2. Verify lora format is supported 3. Try different multiplier values @@ -262,12 +339,13 @@ In the video, a man is presented. The man is in a city and looks at his watch. ```bash # Lora-related command line options ---lora-dir path # Path to t2v loras directory +--lora-dir path # Path to t2v loras directory --lora-dir-i2v path # Path to i2v loras directory --lora-dir-hunyuan path # Path to Hunyuan t2v loras --lora-dir-hunyuan-i2v path # Path to Hunyuan i2v loras --lora-dir-ltxv path # Path to LTX Video loras --lora-dir-flux path # Path to Flux loras +--lora-dir-qwen path # Path to Qwen loras --lora-preset preset # Load preset on startup --check-loras # Filter incompatible loras ``` \ No newline at end of file diff --git a/favicon.png b/favicon.png new file mode 100644 index 0000000..30d361d Binary files /dev/null and b/favicon.png differ diff --git a/flux/__init__.py b/flux/__init__.py deleted file mode 100644 index dddc6a3..0000000 --- a/flux/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -try: - from ._version import ( - version as __version__, # type: ignore - version_tuple, - ) -except ImportError: - __version__ = "unknown (no version information available)" - version_tuple = (0, 0, "unknown", "noinfo") - -from pathlib import Path - -PACKAGE = __package__.replace("_", "-") -PACKAGE_ROOT = Path(__file__).parent diff --git a/flux/flux_main.py b/flux/flux_main.py deleted file mode 100644 index 202eb44..0000000 --- a/flux/flux_main.py +++ /dev/null @@ -1,109 +0,0 @@ -import os -import re -import time -from dataclasses import dataclass -from glob import iglob -from mmgp import offload as offload -import torch -from wan.utils.utils import calculate_new_dimensions -from flux.sampling import denoise, get_schedule, prepare_kontext, unpack -from flux.modules.layers import get_linear_split_map -from flux.util import ( - aspect_ratio_to_height_width, - load_ae, - load_clip, - load_flow_model, - load_t5, - save_image, -) - -class model_factory: - def 
__init__( - self, - checkpoint_dir, - model_filename = None, - model_type = None, - base_model_type = None, - text_encoder_filename = None, - quantizeTransformer = False, - save_quantized = False, - dtype = torch.bfloat16, - VAE_dtype = torch.float32, - mixed_precision_transformer = False - ): - self.device = torch.device(f"cuda") - self.VAE_dtype = VAE_dtype - self.dtype = dtype - torch_device = "cpu" - - self.t5 = load_t5(torch_device, text_encoder_filename, max_length=512) - self.clip = load_clip(torch_device) - self.name= "flux-dev-kontext" - self.model = load_flow_model(self.name, model_filename[0], torch_device) - - self.vae = load_ae(self.name, device=torch_device) - - # offload.change_dtype(self.model, dtype, True) - if save_quantized: - from wgp import save_quantized_model - save_quantized_model(self.model, model_type, model_filename[0], dtype, None) - - split_linear_modules_map = get_linear_split_map() - self.model.split_linear_modules_map = split_linear_modules_map - offload.split_linear_modules(self.model, split_linear_modules_map ) - - - def generate( - self, - seed: int | None = None, - input_prompt: str = "replace the logo with the text 'Black Forest Labs'", - sampling_steps: int = 20, - input_ref_images = None, - width= 832, - height=480, - guide_scale: float = 2.5, - fit_into_canvas = None, - callback = None, - loras_slists = None, - batch_size = 1, - **bbargs - ): - - if self._interrupt: - return None - - device="cuda" - if input_ref_images != None and len(input_ref_images) > 0: - image_ref = input_ref_images[0] - w, h = image_ref.size - height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) - - inp, height, width = prepare_kontext( - t5=self.t5, - clip=self.clip, - prompt=input_prompt, - ae=self.vae, - img_cond=image_ref, - target_width=width, - target_height=height, - bs=batch_size, - seed=seed, - device=device, - ) - - inp.pop("img_cond_orig") - timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell")) - def unpack_latent(x): - return unpack(x.float(), height, width) - # denoise initial noise - x = denoise(self.model, **inp, timesteps=timesteps, guidance=guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent) - if x==None: return None - # decode latents to pixel space - x = unpack_latent(x) - with torch.autocast(device_type=device, dtype=torch.bfloat16): - x = self.vae.decode(x) - - x = x.clamp(-1, 1) - x = x.transpose(0, 1) - return x - diff --git a/flux/model.py b/flux/model.py deleted file mode 100644 index 1802ae6..0000000 --- a/flux/model.py +++ /dev/null @@ -1,168 +0,0 @@ -from dataclasses import dataclass - -import torch -from torch import Tensor, nn - -from flux.modules.layers import ( - DoubleStreamBlock, - EmbedND, - LastLayer, - MLPEmbedder, - SingleStreamBlock, - timestep_embedding, -) -from flux.modules.lora import LinearLora, replace_linear_with_lora - - -@dataclass -class FluxParams: - in_channels: int - out_channels: int - vec_in_dim: int - context_in_dim: int - hidden_size: int - mlp_ratio: float - num_heads: int - depth: int - depth_single_blocks: int - axes_dim: list[int] - theta: int - qkv_bias: bool - guidance_embed: bool - - -class Flux(nn.Module): - """ - Transformer model for flow matching on sequences. 
- """ - - def __init__(self, params: FluxParams): - super().__init__() - - self.params = params - self.in_channels = params.in_channels - self.out_channels = params.out_channels - if params.hidden_size % params.num_heads != 0: - raise ValueError( - f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" - ) - pe_dim = params.hidden_size // params.num_heads - if sum(params.axes_dim) != pe_dim: - raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") - self.hidden_size = params.hidden_size - self.num_heads = params.num_heads - self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) - self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) - self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) - self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) - self.guidance_in = ( - MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() - ) - self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) - - self.double_blocks = nn.ModuleList( - [ - DoubleStreamBlock( - self.hidden_size, - self.num_heads, - mlp_ratio=params.mlp_ratio, - qkv_bias=params.qkv_bias, - ) - for _ in range(params.depth) - ] - ) - - self.single_blocks = nn.ModuleList( - [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) - for _ in range(params.depth_single_blocks) - ] - ) - - self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) - - def preprocess_loras(self, model_type, sd): - new_sd = {} - if len(sd) == 0: return sd - - first_key= next(iter(sd)) - if first_key.startswith("transformer."): - src_list = [".attn.to_q.", ".attn.to_k.", ".attn.to_v."] - tgt_list = [".linear1_attn_q.", ".linear1_attn_k.", ".linear1_attn_v."] - for k,v in sd.items(): - k = k.replace("transformer.single_transformer_blocks", "diffusion_model.single_blocks") - k = k.replace("transformer.double_transformer_blocks", "diffusion_model.double_blocks") - for src, tgt in zip(src_list, tgt_list): - k = k.replace(src, tgt) - - new_sd[k] = v - - return new_sd - - def forward( - self, - img: Tensor, - img_ids: Tensor, - txt: Tensor, - txt_ids: Tensor, - timesteps: Tensor, - y: Tensor, - guidance: Tensor | None = None, - callback= None, - pipeline =None, - - ) -> Tensor: - if img.ndim != 3 or txt.ndim != 3: - raise ValueError("Input img and txt tensors must have 3 dimensions.") - - # running on sequences img - img = self.img_in(img) - vec = self.time_in(timestep_embedding(timesteps, 256)) - if self.params.guidance_embed: - if guidance is None: - raise ValueError("Didn't get guidance strength for guidance distilled model.") - vec += self.guidance_in(timestep_embedding(guidance, 256)) - vec += self.vector_in(y) - txt = self.txt_in(txt) - - ids = torch.cat((txt_ids, img_ids), dim=1) - pe = self.pe_embedder(ids) - - for block in self.double_blocks: - if callback != None: - callback(-1, None, False, True) - if pipeline._interrupt: - return None - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) - - img = torch.cat((txt, img), 1) - for block in self.single_blocks: - img = block(img, vec=vec, pe=pe) - img = img[:, txt.shape[1] :, ...] 
- - img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) - return img - - -class FluxLoraWrapper(Flux): - def __init__( - self, - lora_rank: int = 128, - lora_scale: float = 1.0, - *args, - **kwargs, - ) -> None: - super().__init__(*args, **kwargs) - - self.lora_rank = lora_rank - - replace_linear_with_lora( - self, - max_rank=lora_rank, - scale=lora_scale, - ) - - def set_lora_scale(self, scale: float) -> None: - for module in self.modules(): - if isinstance(module, LinearLora): - module.set_scale(scale=scale) diff --git a/flux/sampling.py b/flux/sampling.py deleted file mode 100644 index 5c137f1..0000000 --- a/flux/sampling.py +++ /dev/null @@ -1,392 +0,0 @@ -import math -from typing import Callable - -import numpy as np -import torch -from einops import rearrange, repeat -from PIL import Image -from torch import Tensor - -from .model import Flux -from .modules.autoencoder import AutoEncoder -from .modules.conditioner import HFEmbedder -from .modules.image_embedders import CannyImageEncoder, DepthImageEncoder, ReduxImageEncoder -from .util import PREFERED_KONTEXT_RESOLUTIONS -from einops import rearrange, repeat - - -def get_noise( - num_samples: int, - height: int, - width: int, - device: torch.device, - dtype: torch.dtype, - seed: int, -): - return torch.randn( - num_samples, - 16, - # allow for packing - 2 * math.ceil(height / 16), - 2 * math.ceil(width / 16), - dtype=dtype, - generator=torch.Generator(device=device).manual_seed(seed), - ) - - -def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: - bs, c, h, w = img.shape - if bs == 1 and not isinstance(prompt, str): - bs = len(prompt) - - img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - if img.shape[0] == 1 and bs > 1: - img = repeat(img, "1 ... -> bs ...", bs=bs) - - img_ids = torch.zeros(h // 2, w // 2, 3) - img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] - img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) - - if isinstance(prompt, str): - prompt = [prompt] - txt = t5(prompt) - if txt.shape[0] == 1 and bs > 1: - txt = repeat(txt, "1 ... -> bs ...", bs=bs) - txt_ids = torch.zeros(bs, txt.shape[1], 3) - - vec = clip(prompt) - if vec.shape[0] == 1 and bs > 1: - vec = repeat(vec, "1 ... -> bs ...", bs=bs) - - return { - "img": img, - "img_ids": img_ids.to(img.device), - "txt": txt.to(img.device), - "txt_ids": txt_ids.to(img.device), - "vec": vec.to(img.device), - } - - -def prepare_control( - t5: HFEmbedder, - clip: HFEmbedder, - img: Tensor, - prompt: str | list[str], - ae: AutoEncoder, - encoder: DepthImageEncoder | CannyImageEncoder, - img_cond_path: str, -) -> dict[str, Tensor]: - # load and encode the conditioning image - bs, _, h, w = img.shape - if bs == 1 and not isinstance(prompt, str): - bs = len(prompt) - - img_cond = Image.open(img_cond_path).convert("RGB") - - width = w * 8 - height = h * 8 - img_cond = img_cond.resize((width, height), Image.Resampling.LANCZOS) - img_cond = np.array(img_cond) - img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0 - img_cond = rearrange(img_cond, "h w c -> 1 c h w") - - with torch.no_grad(): - img_cond = encoder(img_cond) - img_cond = ae.encode(img_cond) - - img_cond = img_cond.to(torch.bfloat16) - img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - if img_cond.shape[0] == 1 and bs > 1: - img_cond = repeat(img_cond, "1 ... 
-> bs ...", bs=bs) - - return_dict = prepare(t5, clip, img, prompt) - return_dict["img_cond"] = img_cond - return return_dict - - -def prepare_fill( - t5: HFEmbedder, - clip: HFEmbedder, - img: Tensor, - prompt: str | list[str], - ae: AutoEncoder, - img_cond_path: str, - mask_path: str, -) -> dict[str, Tensor]: - # load and encode the conditioning image and the mask - bs, _, _, _ = img.shape - if bs == 1 and not isinstance(prompt, str): - bs = len(prompt) - - img_cond = Image.open(img_cond_path).convert("RGB") - img_cond = np.array(img_cond) - img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0 - img_cond = rearrange(img_cond, "h w c -> 1 c h w") - - mask = Image.open(mask_path).convert("L") - mask = np.array(mask) - mask = torch.from_numpy(mask).float() / 255.0 - mask = rearrange(mask, "h w -> 1 1 h w") - - with torch.no_grad(): - img_cond = img_cond.to(img.device) - mask = mask.to(img.device) - img_cond = img_cond * (1 - mask) - img_cond = ae.encode(img_cond) - mask = mask[:, 0, :, :] - mask = mask.to(torch.bfloat16) - mask = rearrange( - mask, - "b (h ph) (w pw) -> b (ph pw) h w", - ph=8, - pw=8, - ) - mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - if mask.shape[0] == 1 and bs > 1: - mask = repeat(mask, "1 ... -> bs ...", bs=bs) - - img_cond = img_cond.to(torch.bfloat16) - img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - if img_cond.shape[0] == 1 and bs > 1: - img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs) - - img_cond = torch.cat((img_cond, mask), dim=-1) - - return_dict = prepare(t5, clip, img, prompt) - return_dict["img_cond"] = img_cond.to(img.device) - return return_dict - - -def prepare_redux( - t5: HFEmbedder, - clip: HFEmbedder, - img: Tensor, - prompt: str | list[str], - encoder: ReduxImageEncoder, - img_cond_path: str, -) -> dict[str, Tensor]: - bs, _, h, w = img.shape - if bs == 1 and not isinstance(prompt, str): - bs = len(prompt) - - img_cond = Image.open(img_cond_path).convert("RGB") - with torch.no_grad(): - img_cond = encoder(img_cond) - - img_cond = img_cond.to(torch.bfloat16) - if img_cond.shape[0] == 1 and bs > 1: - img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs) - - img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - if img.shape[0] == 1 and bs > 1: - img = repeat(img, "1 ... -> bs ...", bs=bs) - - img_ids = torch.zeros(h // 2, w // 2, 3) - img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] - img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) - - if isinstance(prompt, str): - prompt = [prompt] - txt = t5(prompt) - txt = torch.cat((txt, img_cond.to(txt)), dim=-2) - if txt.shape[0] == 1 and bs > 1: - txt = repeat(txt, "1 ... -> bs ...", bs=bs) - txt_ids = torch.zeros(bs, txt.shape[1], 3) - - vec = clip(prompt) - if vec.shape[0] == 1 and bs > 1: - vec = repeat(vec, "1 ... 
-> bs ...", bs=bs) - - return { - "img": img, - "img_ids": img_ids.to(img.device), - "txt": txt.to(img.device), - "txt_ids": txt_ids.to(img.device), - "vec": vec.to(img.device), - } - - -def prepare_kontext( - t5: HFEmbedder, - clip: HFEmbedder, - prompt: str | list[str], - ae: AutoEncoder, - img_cond: str, - seed: int, - device: torch.device, - target_width: int | None = None, - target_height: int | None = None, - bs: int = 1, -) -> tuple[dict[str, Tensor], int, int]: - # load and encode the conditioning image - if bs == 1 and not isinstance(prompt, str): - bs = len(prompt) - - width, height = img_cond.size - aspect_ratio = width / height - - # Kontext is trained on specific resolutions, using one of them is recommended - _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) - - width = 2 * int(width / 16) - height = 2 * int(height / 16) - - img_cond = img_cond.resize((8 * width, 8 * height), Image.Resampling.LANCZOS) - img_cond = np.array(img_cond) - img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0 - img_cond = rearrange(img_cond, "h w c -> 1 c h w") - img_cond_orig = img_cond.clone() - - with torch.no_grad(): - img_cond = ae.encode(img_cond.to(device)) - - img_cond = img_cond.to(torch.bfloat16) - img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - if img_cond.shape[0] == 1 and bs > 1: - img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs) - - # image ids are the same as base image with the first dimension set to 1 - # instead of 0 - img_cond_ids = torch.zeros(height // 2, width // 2, 3) - img_cond_ids[..., 0] = 1 - img_cond_ids[..., 1] = img_cond_ids[..., 1] + torch.arange(height // 2)[:, None] - img_cond_ids[..., 2] = img_cond_ids[..., 2] + torch.arange(width // 2)[None, :] - img_cond_ids = repeat(img_cond_ids, "h w c -> b (h w) c", b=bs) - - if target_width is None: - target_width = 8 * width - if target_height is None: - target_height = 8 * height - - img = get_noise( - bs, - target_height, - target_width, - device=device, - dtype=torch.bfloat16, - seed=seed, - ) - - return_dict = prepare(t5, clip, img, prompt) - return_dict["img_cond_seq"] = img_cond - return_dict["img_cond_seq_ids"] = img_cond_ids.to(device) - return_dict["img_cond_orig"] = img_cond_orig - return return_dict, target_height, target_width - - -def time_shift(mu: float, sigma: float, t: Tensor): - return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) - - -def get_lin_function( - x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 -) -> Callable[[float], float]: - m = (y2 - y1) / (x2 - x1) - b = y1 - m * x1 - return lambda x: m * x + b - - -def get_schedule( - num_steps: int, - image_seq_len: int, - base_shift: float = 0.5, - max_shift: float = 1.15, - shift: bool = True, -) -> list[float]: - # extra step for zero - timesteps = torch.linspace(1, 0, num_steps + 1) - - # shifting the schedule to favor high timesteps for higher signal images - if shift: - # estimate mu based on linear estimation between two points - mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) - timesteps = time_shift(mu, 1.0, timesteps) - - return timesteps.tolist() - - -def denoise( - model: Flux, - # model input - img: Tensor, - img_ids: Tensor, - txt: Tensor, - txt_ids: Tensor, - vec: Tensor, - # sampling parameters - timesteps: list[float], - guidance: float = 4.0, - # extra img tokens (channel-wise) - img_cond: Tensor | None = None, - # extra img tokens (sequence-wise) - img_cond_seq: Tensor | None = None, - 
img_cond_seq_ids: Tensor | None = None, - callback=None, - pipeline=None, - loras_slists=None, - unpack_latent = None, -): - - kwargs = {'pipeline': pipeline, 'callback': callback} - if callback != None: - callback(-1, None, True) - - updated_num_steps= len(timesteps) -1 - if callback != None: - from wgp import update_loras_slists - update_loras_slists(model, loras_slists, updated_num_steps) - callback(-1, None, True, override_num_inference_steps = updated_num_steps) - from mmgp import offload - # this is ignored for schnell - guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) - for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): - offload.set_step_no_for_lora(model, i) - if pipeline._interrupt: - return None - - t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) - img_input = img - img_input_ids = img_ids - if img_cond is not None: - img_input = torch.cat((img, img_cond), dim=-1) - if img_cond_seq is not None: - assert ( - img_cond_seq_ids is not None - ), "You need to provide either both or neither of the sequence conditioning" - img_input = torch.cat((img_input, img_cond_seq), dim=1) - img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1) - pred = model( - img=img_input, - img_ids=img_input_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, - timesteps=t_vec, - guidance=guidance_vec, - **kwargs - ) - if pred == None: return None - - if img_input_ids is not None: - pred = pred[:, : img.shape[1]] - - img += (t_prev - t_curr) * pred - if callback is not None: - preview = unpack_latent(img).transpose(0,1) - callback(i, preview, False) - - - return img - - -def unpack(x: Tensor, height: int, width: int) -> Tensor: - return rearrange( - x, - "b (h w) (c ph pw) -> b c (h ph) (w pw)", - h=math.ceil(height / 16), - w=math.ceil(width / 16), - ph=2, - pw=2, - ) diff --git a/flux/to_remove/cli.py b/flux/to_remove/cli.py deleted file mode 100644 index ed0b1c8..0000000 --- a/flux/to_remove/cli.py +++ /dev/null @@ -1,302 +0,0 @@ -import os -import re -import time -from dataclasses import dataclass -from glob import iglob - -import torch -from fire import Fire -from transformers import pipeline - -from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack -from flux.util import ( - check_onnx_access_for_trt, - configs, - load_ae, - load_clip, - load_flow_model, - load_t5, - save_image, -) - -NSFW_THRESHOLD = 0.85 - - -@dataclass -class SamplingOptions: - prompt: str - width: int - height: int - num_steps: int - guidance: float - seed: int | None - - -def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: - user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" - usage = ( - "Usage: Either write your prompt directly, leave this field empty " - "to repeat the prompt or write a command starting with a slash:\n" - "- '/w ' will set the width of the generated image\n" - "- '/h ' will set the height of the generated image\n" - "- '/s ' sets the next seed\n" - "- '/g ' sets the guidance (flux-dev only)\n" - "- '/n ' sets the number of steps\n" - "- '/q' to quit" - ) - - while (prompt := input(user_question)).startswith("/"): - if prompt.startswith("/w"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, width = prompt.split() - options.width = 16 * (int(width) // 16) - print( - f"Setting resolution to {options.width} x {options.height} " - f"({options.height * options.width / 1e6:.2f}MP)" - ) - elif 
prompt.startswith("/h"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, height = prompt.split() - options.height = 16 * (int(height) // 16) - print( - f"Setting resolution to {options.width} x {options.height} " - f"({options.height * options.width / 1e6:.2f}MP)" - ) - elif prompt.startswith("/g"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, guidance = prompt.split() - options.guidance = float(guidance) - print(f"Setting guidance to {options.guidance}") - elif prompt.startswith("/s"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, seed = prompt.split() - options.seed = int(seed) - print(f"Setting seed to {options.seed}") - elif prompt.startswith("/n"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, steps = prompt.split() - options.num_steps = int(steps) - print(f"Setting number of steps to {options.num_steps}") - elif prompt.startswith("/q"): - print("Quitting") - return None - else: - if not prompt.startswith("/h"): - print(f"Got invalid command '{prompt}'\n{usage}") - print(usage) - if prompt != "": - options.prompt = prompt - return options - - -@torch.inference_mode() -def main( - name: str = "flux-schnell", - width: int = 1360, - height: int = 768, - seed: int | None = None, - prompt: str = ( - "a photo of a forest with mist swirling around the tree trunks. The word " - '"FLUX" is painted over it in big, red brush strokes with visible texture' - ), - device: str = "cuda" if torch.cuda.is_available() else "cpu", - num_steps: int | None = None, - loop: bool = False, - guidance: float = 2.5, - offload: bool = False, - output_dir: str = "output", - add_sampling_metadata: bool = True, - trt: bool = False, - trt_transformer_precision: str = "bf16", - track_usage: bool = False, -): - """ - Sample the flux model. Either interactively (set `--loop`) or run for a - single image. 
- - Args: - name: Name of the model to load - height: height of the sample in pixels (should be a multiple of 16) - width: width of the sample in pixels (should be a multiple of 16) - seed: Set a seed for sampling - output_name: where to save the output image, `{idx}` will be replaced - by the index of the sample - prompt: Prompt used for sampling - device: Pytorch device - num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) - loop: start an interactive session and sample multiple times - guidance: guidance value used for guidance distillation - add_sampling_metadata: Add the prompt to the image Exif metadata - trt: use TensorRT backend for optimized inference - trt_transformer_precision: specify transformer precision for inference - track_usage: track usage of the model for licensing purposes - """ - - prompt = prompt.split("|") - if len(prompt) == 1: - prompt = prompt[0] - additional_prompts = None - else: - additional_prompts = prompt[1:] - prompt = prompt[0] - - assert not ( - (additional_prompts is not None) and loop - ), "Do not provide additional prompts and set loop to True" - - nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) - - if name not in configs: - available = ", ".join(configs.keys()) - raise ValueError(f"Got unknown model name: {name}, chose from {available}") - - torch_device = torch.device(device) - if num_steps is None: - num_steps = 4 if name == "flux-schnell" else 50 - - # allow for packing and conversion to latent space - height = 16 * (height // 16) - width = 16 * (width // 16) - - output_name = os.path.join(output_dir, "img_{idx}.jpg") - if not os.path.exists(output_dir): - os.makedirs(output_dir) - idx = 0 - else: - fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] - if len(fns) > 0: - idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 - else: - idx = 0 - - if not trt: - t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512) - clip = load_clip(torch_device) - model = load_flow_model(name, device="cpu" if offload else torch_device) - ae = load_ae(name, device="cpu" if offload else torch_device) - else: - # lazy import to make install optional - from flux.trt.trt_manager import ModuleName, TRTManager - - # Check if we need ONNX model access (which requires authentication for FLUX models) - onnx_dir = check_onnx_access_for_trt(name, trt_transformer_precision) - - trt_ctx_manager = TRTManager( - trt_transformer_precision=trt_transformer_precision, - trt_t5_precision=os.getenv("TRT_T5_PRECISION", "bf16"), - ) - engines = trt_ctx_manager.load_engines( - model_name=name, - module_names={ - ModuleName.CLIP, - ModuleName.TRANSFORMER, - ModuleName.T5, - ModuleName.VAE, - }, - engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"), - custom_onnx_paths=onnx_dir or os.environ.get("CUSTOM_ONNX_PATHS", ""), - trt_image_height=height, - trt_image_width=width, - trt_batch_size=1, - trt_timing_cache=os.getenv("TRT_TIMING_CACHE_FILE", None), - trt_static_batch=False, - trt_static_shape=False, - ) - - ae = engines[ModuleName.VAE].to(device="cpu" if offload else torch_device) - model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device) - clip = engines[ModuleName.CLIP].to(torch_device) - t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device) - - rng = torch.Generator(device="cpu") - opts = SamplingOptions( - prompt=prompt, - width=width, - height=height, - 
num_steps=num_steps, - guidance=guidance, - seed=seed, - ) - - if loop: - opts = parse_prompt(opts) - - while opts is not None: - if opts.seed is None: - opts.seed = rng.seed() - print(f"Generating with seed {opts.seed}:\n{opts.prompt}") - t0 = time.perf_counter() - - # prepare input - x = get_noise( - 1, - opts.height, - opts.width, - device=torch_device, - dtype=torch.bfloat16, - seed=opts.seed, - ) - opts.seed = None - if offload: - ae = ae.cpu() - torch.cuda.empty_cache() - t5, clip = t5.to(torch_device), clip.to(torch_device) - inp = prepare(t5, clip, x, prompt=opts.prompt) - timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) - - # offload TEs to CPU, load model to gpu - if offload: - t5, clip = t5.cpu(), clip.cpu() - torch.cuda.empty_cache() - model = model.to(torch_device) - - # denoise initial noise - x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) - - # offload model, load autoencoder to gpu - if offload: - model.cpu() - torch.cuda.empty_cache() - ae.decoder.to(x.device) - - # decode latents to pixel space - x = unpack(x.float(), opts.height, opts.width) - with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): - x = ae.decode(x) - - if torch.cuda.is_available(): - torch.cuda.synchronize() - t1 = time.perf_counter() - - fn = output_name.format(idx=idx) - print(f"Done in {t1 - t0:.1f}s. Saving {fn}") - - idx = save_image( - nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage - ) - - if loop: - print("-" * 80) - opts = parse_prompt(opts) - elif additional_prompts: - next_prompt = additional_prompts.pop(0) - opts.prompt = next_prompt - else: - opts = None - - if trt: - trt_ctx_manager.stop_runtime() - - -if __name__ == "__main__": - Fire(main) diff --git a/flux/to_remove/cli_control.py b/flux/to_remove/cli_control.py deleted file mode 100644 index 73a6943..0000000 --- a/flux/to_remove/cli_control.py +++ /dev/null @@ -1,390 +0,0 @@ -import os -import re -import time -from dataclasses import dataclass -from glob import iglob - -import torch -from fire import Fire -from transformers import pipeline - -from flux.modules.image_embedders import CannyImageEncoder, DepthImageEncoder -from flux.sampling import denoise, get_noise, get_schedule, prepare_control, unpack -from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image - - -@dataclass -class SamplingOptions: - prompt: str - width: int - height: int - num_steps: int - guidance: float - seed: int | None - img_cond_path: str - lora_scale: float | None - - -def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: - user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" - usage = ( - "Usage: Either write your prompt directly, leave this field empty " - "to repeat the prompt or write a command starting with a slash:\n" - "- '/w ' will set the width of the generated image\n" - "- '/h ' will set the height of the generated image\n" - "- '/s ' sets the next seed\n" - "- '/g ' sets the guidance (flux-dev only)\n" - "- '/n ' sets the number of steps\n" - "- '/q' to quit" - ) - - while (prompt := input(user_question)).startswith("/"): - if prompt.startswith("/w"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, width = prompt.split() - options.width = 16 * (int(width) // 16) - print( - f"Setting resolution to {options.width} x {options.height} " - f"({options.height * options.width / 
1e6:.2f}MP)" - ) - elif prompt.startswith("/h"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, height = prompt.split() - options.height = 16 * (int(height) // 16) - print( - f"Setting resolution to {options.width} x {options.height} " - f"({options.height * options.width / 1e6:.2f}MP)" - ) - elif prompt.startswith("/g"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, guidance = prompt.split() - options.guidance = float(guidance) - print(f"Setting guidance to {options.guidance}") - elif prompt.startswith("/s"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, seed = prompt.split() - options.seed = int(seed) - print(f"Setting seed to {options.seed}") - elif prompt.startswith("/n"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, steps = prompt.split() - options.num_steps = int(steps) - print(f"Setting number of steps to {options.num_steps}") - elif prompt.startswith("/q"): - print("Quitting") - return None - else: - if not prompt.startswith("/h"): - print(f"Got invalid command '{prompt}'\n{usage}") - print(usage) - if prompt != "": - options.prompt = prompt - return options - - -def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None: - if options is None: - return None - - user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n" - usage = ( - "Usage: Either write your prompt directly, leave this field empty " - "to repeat the conditioning image or write a command starting with a slash:\n" - "- '/q' to quit" - ) - - while True: - img_cond_path = input(user_question) - - if img_cond_path.startswith("/"): - if img_cond_path.startswith("/q"): - print("Quitting") - return None - else: - if not img_cond_path.startswith("/h"): - print(f"Got invalid command '{img_cond_path}'\n{usage}") - print(usage) - continue - - if img_cond_path == "": - break - - if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith( - (".jpg", ".jpeg", ".png", ".webp") - ): - print(f"File '{img_cond_path}' does not exist or is not a valid image file") - continue - - options.img_cond_path = img_cond_path - break - - return options - - -def parse_lora_scale(options: SamplingOptions | None) -> tuple[SamplingOptions | None, bool]: - changed = False - - if options is None: - return None, changed - - user_question = "Next lora scale (write /h for help, /q to quit and leave empty to repeat):\n" - usage = ( - "Usage: Either write your prompt directly, leave this field empty " - "to repeat the lora scale or write a command starting with a slash:\n" - "- '/q' to quit" - ) - - while (prompt := input(user_question)).startswith("/"): - if prompt.startswith("/q"): - print("Quitting") - return None, changed - else: - if not prompt.startswith("/h"): - print(f"Got invalid command '{prompt}'\n{usage}") - print(usage) - if prompt != "": - options.lora_scale = float(prompt) - changed = True - return options, changed - - -@torch.inference_mode() -def main( - name: str, - width: int = 1024, - height: int = 1024, - seed: int | None = None, - prompt: str = "a robot made out of gold", - device: str = "cuda" if torch.cuda.is_available() else "cpu", - num_steps: int = 50, - loop: bool = False, - guidance: float | None = None, - offload: bool = False, - output_dir: str = "output", - add_sampling_metadata: bool = True, - img_cond_path: str = "assets/robot.webp", - 
lora_scale: float | None = 0.85, - trt: bool = False, - trt_transformer_precision: str = "bf16", - track_usage: bool = False, - **kwargs: dict | None, -): - """ - Sample the flux model. Either interactively (set `--loop`) or run for a - single image. - - Args: - height: height of the sample in pixels (should be a multiple of 16) - width: width of the sample in pixels (should be a multiple of 16) - seed: Set a seed for sampling - output_name: where to save the output image, `{idx}` will be replaced - by the index of the sample - prompt: Prompt used for sampling - device: Pytorch device - num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) - loop: start an interactive session and sample multiple times - guidance: guidance value used for guidance distillation - add_sampling_metadata: Add the prompt to the image Exif metadata - img_cond_path: path to conditioning image (jpeg/png/webp) - trt: use TensorRT backend for optimized inference - trt_transformer_precision: specify transformer precision for inference - track_usage: track usage of the model for licensing purposes - """ - nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) - - if "lora" in name: - assert not trt, "TRT does not support LORA" - assert name in [ - "flux-dev-canny", - "flux-dev-depth", - "flux-dev-canny-lora", - "flux-dev-depth-lora", - ], f"Got unknown model name: {name}" - - if guidance is None: - if name in ["flux-dev-canny", "flux-dev-canny-lora"]: - guidance = 30.0 - elif name in ["flux-dev-depth", "flux-dev-depth-lora"]: - guidance = 10.0 - else: - raise NotImplementedError() - - if name not in configs: - available = ", ".join(configs.keys()) - raise ValueError(f"Got unknown model name: {name}, chose from {available}") - - torch_device = torch.device(device) - - output_name = os.path.join(output_dir, "img_{idx}.jpg") - if not os.path.exists(output_dir): - os.makedirs(output_dir) - idx = 0 - else: - fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] - if len(fns) > 0: - idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 - else: - idx = 0 - - if name in ["flux-dev-depth", "flux-dev-depth-lora"]: - img_embedder = DepthImageEncoder(torch_device) - elif name in ["flux-dev-canny", "flux-dev-canny-lora"]: - img_embedder = CannyImageEncoder(torch_device) - else: - raise NotImplementedError() - - if not trt: - # init all components - t5 = load_t5(torch_device, max_length=512) - clip = load_clip(torch_device) - model = load_flow_model(name, device="cpu" if offload else torch_device) - ae = load_ae(name, device="cpu" if offload else torch_device) - else: - # lazy import to make install optional - from flux.trt.trt_manager import ModuleName, TRTManager - - trt_ctx_manager = TRTManager( - trt_transformer_precision=trt_transformer_precision, - trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"), - ) - - engines = trt_ctx_manager.load_engines( - model_name=name, - module_names={ - ModuleName.CLIP, - ModuleName.TRANSFORMER, - ModuleName.T5, - ModuleName.VAE, - ModuleName.VAE_ENCODER, - }, - engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"), - custom_onnx_paths=os.environ.get("CUSTOM_ONNX_PATHS", ""), - trt_image_height=height, - trt_image_width=width, - trt_batch_size=1, - trt_static_batch=kwargs.get("static_batch", True), - trt_static_shape=kwargs.get("static_shape", True), - ) - - ae = engines[ModuleName.VAE].to(device="cpu" if offload else torch_device) - model = 
engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device) - clip = engines[ModuleName.CLIP].to(torch_device) - t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device) - - # set lora scale - if "lora" in name and lora_scale is not None: - for _, module in model.named_modules(): - if hasattr(module, "set_scale"): - module.set_scale(lora_scale) - - rng = torch.Generator(device="cpu") - opts = SamplingOptions( - prompt=prompt, - width=width, - height=height, - num_steps=num_steps, - guidance=guidance, - seed=seed, - img_cond_path=img_cond_path, - lora_scale=lora_scale, - ) - - if loop: - opts = parse_prompt(opts) - opts = parse_img_cond_path(opts) - if "lora" in name: - opts, changed = parse_lora_scale(opts) - if changed: - # update the lora scale: - for _, module in model.named_modules(): - if hasattr(module, "set_scale"): - module.set_scale(opts.lora_scale) - - while opts is not None: - if opts.seed is None: - opts.seed = rng.seed() - print(f"Generating with seed {opts.seed}:\n{opts.prompt}") - t0 = time.perf_counter() - - # prepare input - x = get_noise( - 1, - opts.height, - opts.width, - device=torch_device, - dtype=torch.bfloat16, - seed=opts.seed, - ) - opts.seed = None - if offload: - t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device) - inp = prepare_control( - t5, - clip, - x, - prompt=opts.prompt, - ae=ae, - encoder=img_embedder, - img_cond_path=opts.img_cond_path, - ) - timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) - - # offload TEs and AE to CPU, load model to gpu - if offload: - t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu() - torch.cuda.empty_cache() - model = model.to(torch_device) - - # denoise initial noise - x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) - - # offload model, load autoencoder to gpu - if offload: - model.cpu() - torch.cuda.empty_cache() - ae.decoder.to(x.device) - - # decode latents to pixel space - x = unpack(x.float(), opts.height, opts.width) - with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): - x = ae.decode(x) - - if torch.cuda.is_available(): - torch.cuda.synchronize() - t1 = time.perf_counter() - print(f"Done in {t1 - t0:.1f}s") - - idx = save_image( - nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage - ) - - if loop: - print("-" * 80) - opts = parse_prompt(opts) - opts = parse_img_cond_path(opts) - if "lora" in name: - opts, changed = parse_lora_scale(opts) - if changed: - # update the lora scale: - for _, module in model.named_modules(): - if hasattr(module, "set_scale"): - module.set_scale(opts.lora_scale) - else: - opts = None - - if trt: - trt_ctx_manager.stop_runtime() - - -if __name__ == "__main__": - Fire(main) diff --git a/flux/to_remove/cli_fill.py b/flux/to_remove/cli_fill.py deleted file mode 100644 index ab78c50..0000000 --- a/flux/to_remove/cli_fill.py +++ /dev/null @@ -1,334 +0,0 @@ -import os -import re -import time -from dataclasses import dataclass -from glob import iglob - -import torch -from fire import Fire -from PIL import Image -from transformers import pipeline - -from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack -from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image - - -@dataclass -class SamplingOptions: - prompt: str - width: int - height: int - num_steps: int - guidance: float - seed: int | None - img_cond_path: str - img_mask_path: str - - -def 
parse_prompt(options: SamplingOptions) -> SamplingOptions | None: - user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" - usage = ( - "Usage: Either write your prompt directly, leave this field empty " - "to repeat the prompt or write a command starting with a slash:\n" - "- '/s ' sets the next seed\n" - "- '/g ' sets the guidance (flux-dev only)\n" - "- '/n ' sets the number of steps\n" - "- '/q' to quit" - ) - - while (prompt := input(user_question)).startswith("/"): - if prompt.startswith("/g"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, guidance = prompt.split() - options.guidance = float(guidance) - print(f"Setting guidance to {options.guidance}") - elif prompt.startswith("/s"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, seed = prompt.split() - options.seed = int(seed) - print(f"Setting seed to {options.seed}") - elif prompt.startswith("/n"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, steps = prompt.split() - options.num_steps = int(steps) - print(f"Setting number of steps to {options.num_steps}") - elif prompt.startswith("/q"): - print("Quitting") - return None - else: - if not prompt.startswith("/h"): - print(f"Got invalid command '{prompt}'\n{usage}") - print(usage) - if prompt != "": - options.prompt = prompt - return options - - -def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None: - if options is None: - return None - - user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n" - usage = ( - "Usage: Either write your prompt directly, leave this field empty " - "to repeat the conditioning image or write a command starting with a slash:\n" - "- '/q' to quit" - ) - - while True: - img_cond_path = input(user_question) - - if img_cond_path.startswith("/"): - if img_cond_path.startswith("/q"): - print("Quitting") - return None - else: - if not img_cond_path.startswith("/h"): - print(f"Got invalid command '{img_cond_path}'\n{usage}") - print(usage) - continue - - if img_cond_path == "": - break - - if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith( - (".jpg", ".jpeg", ".png", ".webp") - ): - print(f"File '{img_cond_path}' does not exist or is not a valid image file") - continue - else: - with Image.open(img_cond_path) as img: - width, height = img.size - - if width % 32 != 0 or height % 32 != 0: - print(f"Image dimensions must be divisible by 32, got {width}x{height}") - continue - - options.img_cond_path = img_cond_path - break - - return options - - -def parse_img_mask_path(options: SamplingOptions | None) -> SamplingOptions | None: - if options is None: - return None - - user_question = "Next conditioning mask (write /h for help, /q to quit and leave empty to repeat):\n" - usage = ( - "Usage: Either write your prompt directly, leave this field empty " - "to repeat the conditioning mask or write a command starting with a slash:\n" - "- '/q' to quit" - ) - - while True: - img_mask_path = input(user_question) - - if img_mask_path.startswith("/"): - if img_mask_path.startswith("/q"): - print("Quitting") - return None - else: - if not img_mask_path.startswith("/h"): - print(f"Got invalid command '{img_mask_path}'\n{usage}") - print(usage) - continue - - if img_mask_path == "": - break - - if not os.path.isfile(img_mask_path) or not img_mask_path.lower().endswith( - (".jpg", ".jpeg", ".png", 
".webp") - ): - print(f"File '{img_mask_path}' does not exist or is not a valid image file") - continue - else: - with Image.open(img_mask_path) as img: - width, height = img.size - - if width % 32 != 0 or height % 32 != 0: - print(f"Image dimensions must be divisible by 32, got {width}x{height}") - continue - else: - with Image.open(options.img_cond_path) as img_cond: - img_cond_width, img_cond_height = img_cond.size - - if width != img_cond_width or height != img_cond_height: - print( - f"Mask dimensions must match conditioning image, got {width}x{height} and {img_cond_width}x{img_cond_height}" - ) - continue - - options.img_mask_path = img_mask_path - break - - return options - - -@torch.inference_mode() -def main( - seed: int | None = None, - prompt: str = "a white paper cup", - device: str = "cuda" if torch.cuda.is_available() else "cpu", - num_steps: int = 50, - loop: bool = False, - guidance: float = 30.0, - offload: bool = False, - output_dir: str = "output", - add_sampling_metadata: bool = True, - img_cond_path: str = "assets/cup.png", - img_mask_path: str = "assets/cup_mask.png", - track_usage: bool = False, -): - """ - Sample the flux model. Either interactively (set `--loop`) or run for a - single image. This demo assumes that the conditioning image and mask have - the same shape and that height and width are divisible by 32. - - Args: - seed: Set a seed for sampling - output_name: where to save the output image, `{idx}` will be replaced - by the index of the sample - prompt: Prompt used for sampling - device: Pytorch device - num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) - loop: start an interactive session and sample multiple times - guidance: guidance value used for guidance distillation - add_sampling_metadata: Add the prompt to the image Exif metadata - img_cond_path: path to conditioning image (jpeg/png/webp) - img_mask_path: path to conditioning mask (jpeg/png/webp) - track_usage: track usage of the model for licensing purposes - """ - nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) - - name = "flux-dev-fill" - if name not in configs: - available = ", ".join(configs.keys()) - raise ValueError(f"Got unknown model name: {name}, chose from {available}") - - torch_device = torch.device(device) - - output_name = os.path.join(output_dir, "img_{idx}.jpg") - if not os.path.exists(output_dir): - os.makedirs(output_dir) - idx = 0 - else: - fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] - if len(fns) > 0: - idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 - else: - idx = 0 - - # init all components - t5 = load_t5(torch_device, max_length=128) - clip = load_clip(torch_device) - model = load_flow_model(name, device="cpu" if offload else torch_device) - ae = load_ae(name, device="cpu" if offload else torch_device) - - rng = torch.Generator(device="cpu") - with Image.open(img_cond_path) as img: - width, height = img.size - opts = SamplingOptions( - prompt=prompt, - width=width, - height=height, - num_steps=num_steps, - guidance=guidance, - seed=seed, - img_cond_path=img_cond_path, - img_mask_path=img_mask_path, - ) - - if loop: - opts = parse_prompt(opts) - opts = parse_img_cond_path(opts) - - with Image.open(opts.img_cond_path) as img: - width, height = img.size - opts.height = height - opts.width = width - - opts = parse_img_mask_path(opts) - - while opts is not None: - if opts.seed is None: - opts.seed = rng.seed() - 
print(f"Generating with seed {opts.seed}:\n{opts.prompt}") - t0 = time.perf_counter() - - # prepare input - x = get_noise( - 1, - opts.height, - opts.width, - device=torch_device, - dtype=torch.bfloat16, - seed=opts.seed, - ) - opts.seed = None - if offload: - t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device) - inp = prepare_fill( - t5, - clip, - x, - prompt=opts.prompt, - ae=ae, - img_cond_path=opts.img_cond_path, - mask_path=opts.img_mask_path, - ) - - timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) - - # offload TEs and AE to CPU, load model to gpu - if offload: - t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu() - torch.cuda.empty_cache() - model = model.to(torch_device) - - # denoise initial noise - x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) - - # offload model, load autoencoder to gpu - if offload: - model.cpu() - torch.cuda.empty_cache() - ae.decoder.to(x.device) - - # decode latents to pixel space - x = unpack(x.float(), opts.height, opts.width) - with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): - x = ae.decode(x) - - if torch.cuda.is_available(): - torch.cuda.synchronize() - t1 = time.perf_counter() - print(f"Done in {t1 - t0:.1f}s") - - idx = save_image( - nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage - ) - - if loop: - print("-" * 80) - opts = parse_prompt(opts) - opts = parse_img_cond_path(opts) - - with Image.open(opts.img_cond_path) as img: - width, height = img.size - opts.height = height - opts.width = width - - opts = parse_img_mask_path(opts) - else: - opts = None - - -if __name__ == "__main__": - Fire(main) diff --git a/flux/to_remove/cli_kontext.py b/flux/to_remove/cli_kontext.py deleted file mode 100644 index 17ad6a1..0000000 --- a/flux/to_remove/cli_kontext.py +++ /dev/null @@ -1,368 +0,0 @@ -import os -import re -import time -from dataclasses import dataclass -from glob import iglob - -import torch -from fire import Fire - -from flux.content_filters import PixtralContentFilter -from flux.sampling import denoise, get_schedule, prepare_kontext, unpack -from flux.util import ( - aspect_ratio_to_height_width, - check_onnx_access_for_trt, - load_ae, - load_clip, - load_flow_model, - load_t5, - save_image, -) - - -@dataclass -class SamplingOptions: - prompt: str - width: int | None - height: int | None - num_steps: int - guidance: float - seed: int | None - img_cond_path: str - - -def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: - user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" - usage = ( - "Usage: Either write your prompt directly, leave this field empty " - "to repeat the prompt or write a command starting with a slash:\n" - "- '/ar :' will set the aspect ratio of the generated image\n" - "- '/s ' sets the next seed\n" - "- '/g ' sets the guidance (flux-dev only)\n" - "- '/n ' sets the number of steps\n" - "- '/q' to quit" - ) - - while (prompt := input(user_question)).startswith("/"): - if prompt.startswith("/ar"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, ratio_prompt = prompt.split() - if ratio_prompt == "auto": - options.width = None - options.height = None - print("Setting resolution to input image resolution.") - else: - options.width, options.height = aspect_ratio_to_height_width(ratio_prompt) - print(f"Setting resolution to {options.width} x {options.height}.") - elif 
prompt.startswith("/h"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, height = prompt.split() - if height == "auto": - options.height = None - else: - options.height = 16 * (int(height) // 16) - if options.height is not None and options.width is not None: - print( - f"Setting resolution to {options.width} x {options.height} " - f"({options.height * options.width / 1e6:.2f}MP)" - ) - else: - print(f"Setting resolution to {options.width} x {options.height}.") - elif prompt.startswith("/g"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, guidance = prompt.split() - options.guidance = float(guidance) - print(f"Setting guidance to {options.guidance}") - elif prompt.startswith("/s"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, seed = prompt.split() - options.seed = int(seed) - print(f"Setting seed to {options.seed}") - elif prompt.startswith("/n"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, steps = prompt.split() - options.num_steps = int(steps) - print(f"Setting number of steps to {options.num_steps}") - elif prompt.startswith("/q"): - print("Quitting") - return None - else: - if not prompt.startswith("/h"): - print(f"Got invalid command '{prompt}'\n{usage}") - print(usage) - if prompt != "": - options.prompt = prompt - return options - - -def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None: - if options is None: - return None - - user_question = "Next input image (write /h for help, /q to quit and leave empty to repeat):\n" - usage = ( - "Usage: Either write a path to an image directly, leave this field empty " - "to repeat the last input image or write a command starting with a slash:\n" - "- '/q' to quit\n\n" - "The input image will be edited by FLUX.1 Kontext creating a new image based" - "on your instruction prompt." - ) - - while True: - img_cond_path = input(user_question) - - if img_cond_path.startswith("/"): - if img_cond_path.startswith("/q"): - print("Quitting") - return None - else: - if not img_cond_path.startswith("/h"): - print(f"Got invalid command '{img_cond_path}'\n{usage}") - print(usage) - continue - - if img_cond_path == "": - break - - if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith( - (".jpg", ".jpeg", ".png", ".webp") - ): - print(f"File '{img_cond_path}' does not exist or is not a valid image file") - continue - - options.img_cond_path = img_cond_path - break - - return options - - -@torch.inference_mode() -def main( - name: str = "flux-dev-kontext", - aspect_ratio: str | None = None, - seed: int | None = None, - prompt: str = "replace the logo with the text 'Black Forest Labs'", - device: str = "cuda" if torch.cuda.is_available() else "cpu", - num_steps: int = 30, - loop: bool = False, - guidance: float = 2.5, - offload: bool = False, - output_dir: str = "output", - add_sampling_metadata: bool = True, - img_cond_path: str = "assets/cup.png", - trt: bool = False, - trt_transformer_precision: str = "bf16", - track_usage: bool = False, -): - """ - Sample the flux model. Either interactively (set `--loop`) or run for a - single image. 
- - Args: - height: height of the sample in pixels (should be a multiple of 16), None - defaults to the size of the conditioning - width: width of the sample in pixels (should be a multiple of 16), None - defaults to the size of the conditioning - seed: Set a seed for sampling - output_name: where to save the output image, `{idx}` will be replaced - by the index of the sample - prompt: Prompt used for sampling - device: Pytorch device - num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) - loop: start an interactive session and sample multiple times - guidance: guidance value used for guidance distillation - add_sampling_metadata: Add the prompt to the image Exif metadata - img_cond_path: path to conditioning image (jpeg/png/webp) - trt: use TensorRT backend for optimized inference - track_usage: track usage of the model for licensing purposes - """ - assert name == "flux-dev-kontext", f"Got unknown model name: {name}" - - torch_device = torch.device(device) - - output_name = os.path.join(output_dir, "img_{idx}.jpg") - if not os.path.exists(output_dir): - os.makedirs(output_dir) - idx = 0 - else: - fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] - if len(fns) > 0: - idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 - else: - idx = 0 - - if aspect_ratio is None: - width = None - height = None - else: - width, height = aspect_ratio_to_height_width(aspect_ratio) - - if not trt: - t5 = load_t5(torch_device, max_length=512) - clip = load_clip(torch_device) - model = load_flow_model(name, device="cpu" if offload else torch_device) - else: - # lazy import to make install optional - from flux.trt.trt_manager import ModuleName, TRTManager - - # Check if we need ONNX model access (which requires authentication for FLUX models) - onnx_dir = check_onnx_access_for_trt(name, trt_transformer_precision) - - trt_ctx_manager = TRTManager( - trt_transformer_precision=trt_transformer_precision, - trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"), - ) - engines = trt_ctx_manager.load_engines( - model_name=name, - module_names={ - ModuleName.CLIP, - ModuleName.TRANSFORMER, - ModuleName.T5, - }, - engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"), - custom_onnx_paths=onnx_dir or os.environ.get("CUSTOM_ONNX_PATHS", ""), - trt_image_height=height, - trt_image_width=width, - trt_batch_size=1, - trt_timing_cache=os.getenv("TRT_TIMING_CACHE_FILE", None), - trt_static_batch=False, - trt_static_shape=False, - ) - - model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device) - clip = engines[ModuleName.CLIP].to(torch_device) - t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device) - - ae = load_ae(name, device="cpu" if offload else torch_device) - content_filter = PixtralContentFilter(torch.device("cpu")) - - rng = torch.Generator(device="cpu") - opts = SamplingOptions( - prompt=prompt, - width=width, - height=height, - num_steps=num_steps, - guidance=guidance, - seed=seed, - img_cond_path=img_cond_path, - ) - - if loop: - opts = parse_prompt(opts) - opts = parse_img_cond_path(opts) - - while opts is not None: - if opts.seed is None: - opts.seed = rng.seed() - print(f"Generating with seed {opts.seed}:\n{opts.prompt}") - t0 = time.perf_counter() - - if content_filter.test_txt(opts.prompt): - print("Your prompt has been automatically flagged. 
Please choose another prompt.") - if loop: - print("-" * 80) - opts = parse_prompt(opts) - opts = parse_img_cond_path(opts) - else: - opts = None - continue - if content_filter.test_image(opts.img_cond_path): - print("Your input image has been automatically flagged. Please choose another image.") - if loop: - print("-" * 80) - opts = parse_prompt(opts) - opts = parse_img_cond_path(opts) - else: - opts = None - continue - - if offload: - t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device) - inp, height, width = prepare_kontext( - t5=t5, - clip=clip, - prompt=opts.prompt, - ae=ae, - img_cond_path=opts.img_cond_path, - target_width=opts.width, - target_height=opts.height, - bs=1, - seed=opts.seed, - device=torch_device, - ) - from safetensors.torch import save_file - - save_file({k: v.cpu().contiguous() for k, v in inp.items()}, "output/noise.sft") - inp.pop("img_cond_orig") - opts.seed = None - timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) - - # offload TEs and AE to CPU, load model to gpu - if offload: - t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu() - torch.cuda.empty_cache() - model = model.to(torch_device) - - # denoise initial noise - t00 = time.time() - x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) - torch.cuda.synchronize() - t01 = time.time() - print(f"Denoising took {t01 - t00:.3f}s") - - # offload model, load autoencoder to gpu - if offload: - model.cpu() - torch.cuda.empty_cache() - ae.decoder.to(x.device) - - # decode latents to pixel space - x = unpack(x.float(), height, width) - with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): - ae_dev_t0 = time.perf_counter() - x = ae.decode(x) - torch.cuda.synchronize() - ae_dev_t1 = time.perf_counter() - print(f"AE decode took {ae_dev_t1 - ae_dev_t0:.3f}s") - - if content_filter.test_image(x.cpu()): - print( - "Your output image has been automatically flagged. Choose another prompt/image or try again." 
- ) - if loop: - print("-" * 80) - opts = parse_prompt(opts) - opts = parse_img_cond_path(opts) - else: - opts = None - continue - - if torch.cuda.is_available(): - torch.cuda.synchronize() - t1 = time.perf_counter() - print(f"Done in {t1 - t0:.1f}s") - - idx = save_image( - None, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage - ) - - if loop: - print("-" * 80) - opts = parse_prompt(opts) - opts = parse_img_cond_path(opts) - else: - opts = None - - -if __name__ == "__main__": - Fire(main) diff --git a/flux/to_remove/cli_redux.py b/flux/to_remove/cli_redux.py deleted file mode 100644 index 71e59e1..0000000 --- a/flux/to_remove/cli_redux.py +++ /dev/null @@ -1,290 +0,0 @@ -import os -import re -import time -from dataclasses import dataclass -from glob import iglob - -import torch -from fire import Fire -from transformers import pipeline - -from flux.modules.image_embedders import ReduxImageEncoder -from flux.sampling import denoise, get_noise, get_schedule, prepare_redux, unpack -from flux.util import ( - get_checkpoint_path, - load_ae, - load_clip, - load_flow_model, - load_t5, - save_image, -) - - -@dataclass -class SamplingOptions: - prompt: str - width: int - height: int - num_steps: int - guidance: float - seed: int | None - img_cond_path: str - - -def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: - user_question = "Write /h for help, /q to quit and leave empty to repeat):\n" - usage = ( - "Usage: Leave this field empty to do nothing " - "or write a command starting with a slash:\n" - "- '/w ' will set the width of the generated image\n" - "- '/h ' will set the height of the generated image\n" - "- '/s ' sets the next seed\n" - "- '/g ' sets the guidance (flux-dev only)\n" - "- '/n ' sets the number of steps\n" - "- '/q' to quit" - ) - - while (prompt := input(user_question)).startswith("/"): - if prompt.startswith("/w"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, width = prompt.split() - options.width = 16 * (int(width) // 16) - print( - f"Setting resolution to {options.width} x {options.height} " - f"({options.height * options.width / 1e6:.2f}MP)" - ) - elif prompt.startswith("/h"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, height = prompt.split() - options.height = 16 * (int(height) // 16) - print( - f"Setting resolution to {options.width} x {options.height} " - f"({options.height * options.width / 1e6:.2f}MP)" - ) - elif prompt.startswith("/g"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, guidance = prompt.split() - options.guidance = float(guidance) - print(f"Setting guidance to {options.guidance}") - elif prompt.startswith("/s"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, seed = prompt.split() - options.seed = int(seed) - print(f"Setting seed to {options.seed}") - elif prompt.startswith("/n"): - if prompt.count(" ") != 1: - print(f"Got invalid command '{prompt}'\n{usage}") - continue - _, steps = prompt.split() - options.num_steps = int(steps) - print(f"Setting number of steps to {options.num_steps}") - elif prompt.startswith("/q"): - print("Quitting") - return None - else: - if not prompt.startswith("/h"): - print(f"Got invalid command '{prompt}'\n{usage}") - print(usage) - return options - - -def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None: - if options is None: - return None - 
- user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n" - usage = ( - "Usage: Either write your prompt directly, leave this field empty " - "to repeat the conditioning image or write a command starting with a slash:\n" - "- '/q' to quit" - ) - - while True: - img_cond_path = input(user_question) - - if img_cond_path.startswith("/"): - if img_cond_path.startswith("/q"): - print("Quitting") - return None - else: - if not img_cond_path.startswith("/h"): - print(f"Got invalid command '{img_cond_path}'\n{usage}") - print(usage) - continue - - if img_cond_path == "": - break - - if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith( - (".jpg", ".jpeg", ".png", ".webp") - ): - print(f"File '{img_cond_path}' does not exist or is not a valid image file") - continue - - options.img_cond_path = img_cond_path - break - - return options - - -@torch.inference_mode() -def main( - name: str = "flux-dev", - width: int = 1360, - height: int = 768, - seed: int | None = None, - device: str = "cuda" if torch.cuda.is_available() else "cpu", - num_steps: int | None = None, - loop: bool = False, - guidance: float = 2.5, - offload: bool = False, - output_dir: str = "output", - add_sampling_metadata: bool = True, - img_cond_path: str = "assets/robot.webp", - track_usage: bool = False, -): - """ - Sample the flux model. Either interactively (set `--loop`) or run for a - single image. - - Args: - name: Name of the base model to use (either 'flux-dev' or 'flux-schnell') - height: height of the sample in pixels (should be a multiple of 16) - width: width of the sample in pixels (should be a multiple of 16) - seed: Set a seed for sampling - device: Pytorch device - num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) - loop: start an interactive session and sample multiple times - guidance: guidance value used for guidance distillation - offload: offload models to CPU when not in use - output_dir: where to save the output images - add_sampling_metadata: Add the prompt to the image Exif metadata - img_cond_path: path to conditioning image (jpeg/png/webp) - track_usage: track usage of the model for licensing purposes - """ - - nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) - - if name not in (available := ["flux-dev", "flux-schnell"]): - raise ValueError(f"Got unknown model name: {name}, chose from {available}") - - torch_device = torch.device(device) - if num_steps is None: - num_steps = 4 if name == "flux-schnell" else 50 - - output_name = os.path.join(output_dir, "img_{idx}.jpg") - if not os.path.exists(output_dir): - os.makedirs(output_dir) - idx = 0 - else: - fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] - if len(fns) > 0: - idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 - else: - idx = 0 - - # init all components - t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512) - clip = load_clip(torch_device) - model = load_flow_model(name, device="cpu" if offload else torch_device) - ae = load_ae(name, device="cpu" if offload else torch_device) - - # Download and initialize the Redux adapter - redux_path = str( - get_checkpoint_path("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "FLUX_REDUX") - ) - img_embedder = ReduxImageEncoder(torch_device, redux_path=redux_path) - - rng = torch.Generator(device="cpu") - prompt = "" - opts = SamplingOptions( - 
prompt=prompt, - width=width, - height=height, - num_steps=num_steps, - guidance=guidance, - seed=seed, - img_cond_path=img_cond_path, - ) - - if loop: - opts = parse_prompt(opts) - opts = parse_img_cond_path(opts) - - while opts is not None: - if opts.seed is None: - opts.seed = rng.seed() - print(f"Generating with seed {opts.seed}:\n{opts.prompt}") - t0 = time.perf_counter() - - # prepare input - x = get_noise( - 1, - opts.height, - opts.width, - device=torch_device, - dtype=torch.bfloat16, - seed=opts.seed, - ) - opts.seed = None - if offload: - ae = ae.cpu() - torch.cuda.empty_cache() - t5, clip = t5.to(torch_device), clip.to(torch_device) - inp = prepare_redux( - t5, - clip, - x, - prompt=opts.prompt, - encoder=img_embedder, - img_cond_path=opts.img_cond_path, - ) - timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) - - # offload TEs to CPU, load model to gpu - if offload: - t5, clip = t5.cpu(), clip.cpu() - torch.cuda.empty_cache() - model = model.to(torch_device) - - # denoise initial noise - x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) - - # offload model, load autoencoder to gpu - if offload: - model.cpu() - torch.cuda.empty_cache() - ae.decoder.to(x.device) - - # decode latents to pixel space - x = unpack(x.float(), opts.height, opts.width) - with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): - x = ae.decode(x) - - if torch.cuda.is_available(): - torch.cuda.synchronize() - t1 = time.perf_counter() - print(f"Done in {t1 - t0:.1f}s") - - idx = save_image( - nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage - ) - - if loop: - print("-" * 80) - opts = parse_prompt(opts) - opts = parse_img_cond_path(opts) - else: - opts = None - - -if __name__ == "__main__": - Fire(main) diff --git a/flux/to_remove/content_filters.py b/flux/to_remove/content_filters.py deleted file mode 100644 index 8de89ed..0000000 --- a/flux/to_remove/content_filters.py +++ /dev/null @@ -1,171 +0,0 @@ -import torch -from einops import rearrange -from PIL import Image -from transformers import AutoProcessor, LlavaForConditionalGeneration, pipeline - -PROMPT_IMAGE_INTEGRITY = """ -Task: Analyze an image to identify potential copyright concerns or depictions of public figures. - -Output: Respond with only "yes" or "no" - -Criteria for "yes": -- The image contains a recognizable character from copyrighted media (movies, TV, comics, games, etc.) -- The image displays a trademarked logo or brand -- The image depicts a recognizable public figure (celebrities, politicians, athletes, influencers, historical figures, etc.) - -Criteria for "no": -- All other cases -- When you cannot identify the specific copyrighted work or named individual - -Critical Requirements: -1. You must be able to name the exact copyrighted work or specific person depicted -2. General references to demographics or characteristics are not sufficient -3. Base your decision solely on visual content, not interpretation -4. Provide only the one-word answer: "yes" or "no" -""".strip() - - -PROMPT_IMAGE_INTEGRITY_FOLLOW_UP = "Does this image have copyright concerns or includes public figures?" - -PROMPT_TEXT_INTEGRITY = """ -Task: Analyze a text prompt to identify potential copyright concerns or requests to depict living public figures. - -Output: Respond with only "yes" or "no" - -Criteria for "Yes": -- The prompt explicitly names a character from copyrighted media (movies, TV, comics, games, etc.) 
-- The prompt explicitly mentions a trademarked logo or brand -- The prompt names or describes a specific living public figure (celebrities, politicians, athletes, influencers, etc.) - -Criteria for "No": -- All other cases -- When you cannot identify the specific copyrighted work or named individual - -Critical Requirements: -1. You must be able to name the exact copyrighted work or specific person referenced -2. General demographic descriptions or characteristics are not sufficient -3. Analyze only the prompt text, not potential image outcomes -4. Provide only the one-word answer: "yes" or "no" - -The prompt to check is: ------ -{prompt} ------ - -Does this prompt have copyright concerns or includes public figures? -""".strip() - - -class PixtralContentFilter(torch.nn.Module): - def __init__( - self, - device: torch.device = torch.device("cpu"), - nsfw_threshold: float = 0.85, - ): - super().__init__() - - model_id = "mistral-community/pixtral-12b" - self.processor = AutoProcessor.from_pretrained(model_id) - self.model = LlavaForConditionalGeneration.from_pretrained(model_id, device_map=device) - - self.yes_token, self.no_token = self.processor.tokenizer.encode(["yes", "no"]) - - self.nsfw_classifier = pipeline( - "image-classification", model="Falconsai/nsfw_image_detection", device=device - ) - self.nsfw_threshold = nsfw_threshold - - def yes_no_logit_processor( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor - ) -> torch.FloatTensor: - """ - Sets all tokens but yes/no to the minimum. - """ - scores_yes_token = scores[:, self.yes_token].clone() - scores_no_token = scores[:, self.no_token].clone() - scores_min = scores.min() - scores[:, :] = scores_min - 1 - scores[:, self.yes_token] = scores_yes_token - scores[:, self.no_token] = scores_no_token - return scores - - def test_image(self, image: Image.Image | str | torch.Tensor) -> bool: - if isinstance(image, torch.Tensor): - image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c") - image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy()) - elif isinstance(image, str): - image = Image.open(image) - - classification = next(c for c in self.nsfw_classifier(image) if c["label"] == "nsfw") - if classification["score"] > self.nsfw_threshold: - return True - - # 512^2 pixels are enough for checking - w, h = image.size - f = (512**2 / (w * h)) ** 0.5 - image = image.resize((int(f * w), int(f * h))) - - chat = [ - { - "role": "user", - "content": [ - { - "type": "text", - "content": PROMPT_IMAGE_INTEGRITY, - }, - { - "type": "image", - "image": image, - }, - { - "type": "text", - "content": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP, - }, - ], - } - ] - - inputs = self.processor.apply_chat_template( - chat, - add_generation_prompt=True, - tokenize=True, - return_dict=True, - return_tensors="pt", - ).to(self.model.device) - - generate_ids = self.model.generate( - **inputs, - max_new_tokens=1, - logits_processor=[self.yes_no_logit_processor], - do_sample=False, - ) - return generate_ids[0, -1].item() == self.yes_token - - def test_txt(self, txt: str) -> bool: - chat = [ - { - "role": "user", - "content": [ - { - "type": "text", - "content": PROMPT_TEXT_INTEGRITY.format(prompt=txt), - }, - ], - } - ] - - inputs = self.processor.apply_chat_template( - chat, - add_generation_prompt=True, - tokenize=True, - return_dict=True, - return_tensors="pt", - ).to(self.model.device) - - generate_ids = self.model.generate( - **inputs, - max_new_tokens=1, - logits_processor=[self.yes_no_logit_processor], - do_sample=False, - ) - return 
generate_ids[0, -1].item() == self.yes_token diff --git a/i2v_inference.py b/i2v_inference.py deleted file mode 100644 index f833868..0000000 --- a/i2v_inference.py +++ /dev/null @@ -1,682 +0,0 @@ -import os -import time -import argparse -import json -import torch -import traceback -import gc -import random - -# These imports rely on your existing code structure -# They must match the location of your WAN code, etc. -import wan -from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS -from wan.modules.attention import get_attention_modes -from wan.utils.utils import cache_video -from mmgp import offload, safetensors2, profile_type - -try: - import triton -except ImportError: - pass - -DATA_DIR = "ckpts" - -# -------------------------------------------------- -# HELPER FUNCTIONS -# -------------------------------------------------- - -def sanitize_file_name(file_name): - """Clean up file name from special chars.""" - return ( - file_name.replace("/", "") - .replace("\\", "") - .replace(":", "") - .replace("|", "") - .replace("?", "") - .replace("<", "") - .replace(">", "") - .replace('"', "") - ) - -def extract_preset(lset_name, lora_dir, loras): - """ - Load a .lset JSON that lists the LoRA files to apply, plus multipliers - and possibly a suggested prompt prefix. - """ - lset_name = sanitize_file_name(lset_name) - if not lset_name.endswith(".lset"): - lset_name_filename = os.path.join(lora_dir, lset_name + ".lset") - else: - lset_name_filename = os.path.join(lora_dir, lset_name) - - if not os.path.isfile(lset_name_filename): - raise ValueError(f"Preset '{lset_name}' not found in {lora_dir}") - - with open(lset_name_filename, "r", encoding="utf-8") as reader: - text = reader.read() - lset = json.loads(text) - - loras_choices_files = lset["loras"] - loras_choices = [] - missing_loras = [] - for lora_file in loras_choices_files: - # Build absolute path and see if it is in loras - full_lora_path = os.path.join(lora_dir, lora_file) - if full_lora_path in loras: - idx = loras.index(full_lora_path) - loras_choices.append(str(idx)) - else: - missing_loras.append(lora_file) - - if len(missing_loras) > 0: - missing_list = ", ".join(missing_loras) - raise ValueError(f"Missing LoRA files for preset: {missing_list}") - - loras_mult_choices = lset["loras_mult"] - prompt_prefix = lset.get("prompt", "") - full_prompt = lset.get("full_prompt", False) - return loras_choices, loras_mult_choices, prompt_prefix, full_prompt - -def get_attention_mode(args_attention, installed_modes): - """ - Decide which attention mode to use: either the user choice or auto fallback. - """ - if args_attention == "auto": - for candidate in ["sage2", "sage", "sdpa"]: - if candidate in installed_modes: - return candidate - return "sdpa" # last fallback - elif args_attention in installed_modes: - return args_attention - else: - raise ValueError( - f"Requested attention mode '{args_attention}' not installed. " - f"Installed modes: {installed_modes}" - ) - -def load_i2v_model(model_filename, text_encoder_filename, is_720p): - """ - Load the i2v model with a specific size config and text encoder. 
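- Returns (wan_model, pipe) where pipe maps "transformer", "text_encoder", "text_encoder_2" and "vae" to the wrapped torch modules. 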
- """ - if is_720p: - print("Loading 14B-720p i2v model ...") - cfg = WAN_CONFIGS['i2v-14B'] - wan_model = wan.WanI2V( - config=cfg, - checkpoint_dir=DATA_DIR, - model_filename=model_filename, - text_encoder_filename=text_encoder_filename - ) - else: - print("Loading 14B-480p i2v model ...") - cfg = WAN_CONFIGS['i2v-14B'] - wan_model = wan.WanI2V( - config=cfg, - checkpoint_dir=DATA_DIR, - model_filename=model_filename, - text_encoder_filename=text_encoder_filename - ) - # Pipe structure - pipe = { - "transformer": wan_model.model, - "text_encoder": wan_model.text_encoder.model, - "text_encoder_2": wan_model.clip.model, - "vae": wan_model.vae.model - } - return wan_model, pipe - -def setup_loras(pipe, lora_dir, lora_preset, num_inference_steps): - """ - Load loras from a directory, optionally apply a preset. - """ - from pathlib import Path - import glob - - if not lora_dir or not Path(lora_dir).is_dir(): - print("No valid --lora-dir provided or directory doesn't exist, skipping LoRA setup.") - return [], [], [], "", "", False - - # Gather LoRA files - loras = sorted( - glob.glob(os.path.join(lora_dir, "*.sft")) - + glob.glob(os.path.join(lora_dir, "*.safetensors")) - ) - loras_names = [Path(x).stem for x in loras] - - # Offload them with no activation - offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False) - - # If user gave a preset, apply it - default_loras_choices = [] - default_loras_multis_str = "" - default_prompt_prefix = "" - preset_applied_full_prompt = False - if lora_preset: - loras_choices, loras_mult, prefix, full_prompt = extract_preset(lora_preset, lora_dir, loras) - default_loras_choices = loras_choices - # If user stored loras_mult as a list or string in JSON, unify that to str - if isinstance(loras_mult, list): - # Just store them in a single line - default_loras_multis_str = " ".join([str(x) for x in loras_mult]) - else: - default_loras_multis_str = str(loras_mult) - default_prompt_prefix = prefix - preset_applied_full_prompt = full_prompt - - return ( - loras, - loras_names, - default_loras_choices, - default_loras_multis_str, - default_prompt_prefix, - preset_applied_full_prompt - ) - -def parse_loras_and_activate( - transformer, - loras, - loras_choices, - loras_mult_str, - num_inference_steps -): - """ - Activate the chosen LoRAs with multipliers over the pipeline's transformer. - Supports stepwise expansions (like "0.5,0.8" for partial steps). - """ - if not loras or not loras_choices: - # no LoRAs selected - return - - # Handle multipliers - def is_float_or_comma_list(x): - """ - Example: "0.5", or "0.8,1.0", etc. is valid. 
- """ - if not x: - return False - for chunk in x.split(","): - try: - float(chunk.strip()) - except ValueError: - return False - return True - - # Convert multiline or spaced lines to a single list - lines = [ - line.strip() - for line in loras_mult_str.replace("\r", "\n").split("\n") - if line.strip() and not line.strip().startswith("#") - ] - # Now combine them by space - joined_line = " ".join(lines) # "1.0 2.0,3.0" - if not joined_line.strip(): - multipliers = [] - else: - multipliers = joined_line.split(" ") - - # Expand each item - final_multipliers = [] - for mult in multipliers: - mult = mult.strip() - if not mult: - continue - if is_float_or_comma_list(mult): - # Could be "0.7" or "0.5,0.6" - if "," in mult: - # expand over steps - chunk_vals = [float(x.strip()) for x in mult.split(",")] - expanded = expand_list_over_steps(chunk_vals, num_inference_steps) - final_multipliers.append(expanded) - else: - final_multipliers.append(float(mult)) - else: - raise ValueError(f"Invalid LoRA multiplier: '{mult}'") - - # If fewer multipliers than chosen LoRAs => pad with 1.0 - needed = len(loras_choices) - len(final_multipliers) - if needed > 0: - final_multipliers += [1.0]*needed - - # Actually activate them - offload.activate_loras(transformer, loras_choices, final_multipliers) - -def expand_list_over_steps(short_list, num_steps): - """ - If user gave (0.5, 0.8) for example, expand them over `num_steps`. - The expansion is simply linear slice across steps. - """ - result = [] - inc = len(short_list) / float(num_steps) - idxf = 0.0 - for _ in range(num_steps): - value = short_list[int(idxf)] - result.append(value) - idxf += inc - return result - -def download_models_if_needed(transformer_filename_i2v, text_encoder_filename, local_folder=DATA_DIR): - """ - Checks if all required WAN 2.1 i2v files exist locally under 'ckpts/'. - If not, downloads them from a Hugging Face Hub repo. - Adjust the 'repo_id' and needed files as appropriate. - """ - import os - from pathlib import Path - - try: - from huggingface_hub import hf_hub_download, snapshot_download - except ImportError as e: - raise ImportError( - "huggingface_hub is required for automatic model download. " - "Please install it via `pip install huggingface_hub`." - ) from e - - # Identify just the filename portion for each path - def basename(path_str): - return os.path.basename(path_str) - - repo_id = "DeepBeepMeep/Wan2.1" - target_root = local_folder - - # You can customize this list as needed for i2v usage. - # At minimum you need: - # 1) The requested i2v transformer file - # 2) The requested text encoder file - # 3) VAE file - # 4) The open-clip xlm-roberta-large weights - # - # If your i2v config references additional files, add them here. - needed_files = [ - "Wan2.1_VAE.pth", - "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", - basename(text_encoder_filename), - basename(transformer_filename_i2v), - ] - - # The original script also downloads an entire "xlm-roberta-large" folder - # via snapshot_download. If you require that for your pipeline, - # you can add it here, for example: - subfolder_name = "xlm-roberta-large" - if not Path(os.path.join(target_root, subfolder_name)).exists(): - snapshot_download(repo_id=repo_id, allow_patterns=subfolder_name + "/*", local_dir=target_root) - - for filename in needed_files: - local_path = os.path.join(target_root, filename) - if not os.path.isfile(local_path): - print(f"File '{filename}' not found locally. 
Downloading from {repo_id} ...") - hf_hub_download( - repo_id=repo_id, - filename=filename, - local_dir=target_root - ) - else: - # Already present - pass - - print("All required i2v files are present.") - - -# -------------------------------------------------- -# ARGUMENT PARSER -# -------------------------------------------------- - -def parse_args(): - parser = argparse.ArgumentParser( - description="Image-to-Video inference using WAN 2.1 i2v" - ) - # Model + Tools - parser.add_argument( - "--quantize-transformer", - action="store_true", - help="Use on-the-fly transformer quantization" - ) - parser.add_argument( - "--compile", - action="store_true", - help="Enable PyTorch 2.0 compile for the transformer" - ) - parser.add_argument( - "--attention", - type=str, - default="auto", - help="Which attention to use: auto, sdpa, sage, sage2, flash" - ) - parser.add_argument( - "--profile", - type=int, - default=4, - help="Memory usage profile number [1..5]; see original script or use 2 if you have low VRAM" - ) - parser.add_argument( - "--preload", - type=int, - default=0, - help="Megabytes of the diffusion model to preload in VRAM (only used in some profiles)" - ) - parser.add_argument( - "--verbose", - type=int, - default=1, - help="Verbosity level [0..5]" - ) - - # i2v Model - parser.add_argument( - "--transformer-file", - type=str, - default=f"{DATA_DIR}/wan2.1_image2video_480p_14B_quanto_int8.safetensors", - help="Which i2v model to load" - ) - parser.add_argument( - "--text-encoder-file", - type=str, - default=f"{DATA_DIR}/models_t5_umt5-xxl-enc-quanto_int8.safetensors", - help="Which text encoder to use" - ) - - # LoRA - parser.add_argument( - "--lora-dir", - type=str, - default="", - help="Path to a directory containing i2v LoRAs" - ) - parser.add_argument( - "--lora-preset", - type=str, - default="", - help="A .lset preset name in the lora_dir to auto-apply" - ) - - # Generation Options - parser.add_argument("--prompt", type=str, default=None, required=True, help="Prompt for generation") - parser.add_argument("--negative-prompt", type=str, default="", help="Negative prompt") - parser.add_argument("--resolution", type=str, default="832x480", help="WxH") - parser.add_argument("--frames", type=int, default=64, help="Number of frames (16=1s if fps=16). Must be multiple of 4 +/- 1 in WAN.") - parser.add_argument("--steps", type=int, default=30, help="Number of denoising steps.") - parser.add_argument("--guidance-scale", type=float, default=5.0, help="Classifier-free guidance scale") - parser.add_argument("--flow-shift", type=float, default=3.0, help="Flow shift parameter. Generally 3.0 for 480p, 5.0 for 720p.") - parser.add_argument("--riflex", action="store_true", help="Enable RIFLEx for longer videos") - parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.") - parser.add_argument("--teacache-start", type=float, default=0.1, help="Teacache start step percentage [0..100]") - parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.") - parser.add_argument("--slg-layers", type=str, default=None, help="Which layers to use for skip layer guidance") - parser.add_argument("--slg-start", type=float, default=0.0, help="Percentage in to start SLG") - parser.add_argument("--slg-end", type=float, default=1.0, help="Percentage in to end SLG") - - # LoRA usage - parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. 
Usually you only use the preset.") - parser.add_argument("--loras-mult", type=str, default="", help="Multipliers for each chosen LoRA. Example: '1.0 1.2,1.3' etc.") - - # Input - parser.add_argument( - "--input-image", - type=str, - default=None, - required=True, - help="Path to an input image (or multiple)." - ) - parser.add_argument( - "--output-file", - type=str, - default="output.mp4", - help="Where to save the resulting video." - ) - - return parser.parse_args() - -# -------------------------------------------------- -# MAIN -# -------------------------------------------------- - -def main(): - args = parse_args() - - # Setup environment - offload.default_verboseLevel = args.verbose - installed_attn_modes = get_attention_modes() - - # Decide attention - chosen_attention = get_attention_mode(args.attention, installed_attn_modes) - offload.shared_state["_attention"] = chosen_attention - - # Determine i2v resolution format - if "720" in args.transformer_file: - is_720p = True - else: - is_720p = False - - # Make sure we have the needed models locally - download_models_if_needed(args.transformer_file, args.text_encoder_file) - - # Load i2v - wan_model, pipe = load_i2v_model( - model_filename=args.transformer_file, - text_encoder_filename=args.text_encoder_file, - is_720p=is_720p - ) - wan_model._interrupt = False - - # Offload / profile - # e.g. for your script: offload.profile(pipe, profile_no=args.profile, compile=..., quantizeTransformer=...) - # pass the budgets if you want, etc. - kwargs = {} - if args.profile == 2 or args.profile == 4: - # preload is in MB - if args.preload == 0: - budgets = {"transformer": 100, "text_encoder": 100, "*": 1000} - else: - budgets = {"transformer": args.preload, "text_encoder": 100, "*": 1000} - kwargs["budgets"] = budgets - elif args.profile == 3: - kwargs["budgets"] = {"*": "70%"} - - compile_choice = "transformer" if args.compile else "" - # Create the offload object - offloadobj = offload.profile( - pipe, - profile_no=args.profile, - compile=compile_choice, - quantizeTransformer=args.quantize_transformer, - **kwargs - ) - - # If user wants to use LoRAs - ( - loras, - loras_names, - default_loras_choices, - default_loras_multis_str, - preset_prompt_prefix, - preset_full_prompt - ) = setup_loras(pipe, args.lora_dir, args.lora_preset, args.steps) - - # Combine user prompt with preset prompt if the preset indicates so - if preset_prompt_prefix: - if preset_full_prompt: - # Full override - user_prompt = preset_prompt_prefix - else: - # Just prefix - user_prompt = preset_prompt_prefix + "\n" + args.prompt - else: - user_prompt = args.prompt - - # Actually parse user LoRA choices if they did not rely purely on the preset - if args.loras_choices: - # If user gave e.g. "0,1", we treat that as new additions - lora_choice_list = [x.strip() for x in args.loras_choices.split(",")] - else: - # Use the defaults from the preset - lora_choice_list = default_loras_choices - - # Activate them - parse_loras_and_activate( - pipe["transformer"], loras, lora_choice_list, args.loras_mult or default_loras_multis_str, args.steps - ) - - # Negative prompt - negative_prompt = args.negative_prompt or "" - - # Sanity check resolution - if "*" in args.resolution.lower(): - print("ERROR: resolution must be e.g. 832x480 not '832*480'. 
Fixing it.") - resolution_str = args.resolution.lower().replace("*", "x") - else: - resolution_str = args.resolution - - try: - width, height = [int(x) for x in resolution_str.split("x")] - except: - raise ValueError(f"Invalid resolution: '{resolution_str}'") - - # Parse slg_layers from comma-separated string to a Python list of ints (or None if not provided) - if args.slg_layers: - slg_list = [int(x) for x in args.slg_layers.split(",")] - else: - slg_list = None - - # Additional checks (from your original code). - if "480p" in args.transformer_file: - # Then we cannot exceed certain area for 480p model - if width * height > 832*480: - raise ValueError("You must use the 720p i2v model to generate bigger than 832x480.") - # etc. - - # Handle random seed - if args.seed < 0: - args.seed = random.randint(0, 999999999) - print(f"Using seed={args.seed}") - - # Setup tea cache if needed - trans = wan_model.model - trans.enable_cache = (args.teacache > 0) - if trans.enable_cache: - if "480p" in args.transformer_file: - # example from your code - trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01] - elif "720p" in args.transformer_file: - trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] - else: - raise ValueError("Teacache not supported for this model variant") - - # Attempt generation - print("Starting generation ...") - start_time = time.time() - - # Read the input image - if not os.path.isfile(args.input_image): - raise ValueError(f"Input image does not exist: {args.input_image}") - - from PIL import Image - input_img = Image.open(args.input_image).convert("RGB") - - # Possibly load more than one image if you want "multiple images" – but here we'll just do single for demonstration - - # Define the generation call - # - frames => must be multiple of 4 plus 1 as per original script's note, e.g. 81, 65, ... 
- # You can correct to that if needed: - frame_count = (args.frames // 4)*4 + 1 # ensures it's 4*N+1 - # RIFLEx - enable_riflex = args.riflex - - # If teacache => reset counters - if trans.enable_cache: - trans.teacache_counter = 0 - trans.cache_multiplier = args.teacache - trans.cache_start_step = int(args.teacache_start * args.steps / 100.0) - trans.num_steps = args.steps - trans.cache_skipped_steps = 0 - trans.previous_residual_uncond = None - trans.previous_residual_cond = None - - # VAE Tiling - device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576 - if device_mem_capacity >= 28000: # 81 frames 720p requires about 28 GB VRAM - use_vae_config = 1 - elif device_mem_capacity >= 8000: - use_vae_config = 2 - else: - use_vae_config = 3 - - if use_vae_config == 1: - VAE_tile_size = 0 - elif use_vae_config == 2: - VAE_tile_size = 256 - else: - VAE_tile_size = 128 - - print('Using VAE tile size of', VAE_tile_size) - - # Actually run the i2v generation - try: - sample_frames = wan_model.generate( - input_prompt = user_prompt, - image_start = input_img, - frame_num=frame_count, - width=width, - height=height, - # max_area=MAX_AREA_CONFIGS[f"{width}*{height}"], # or you can pass your custom - shift=args.flow_shift, - sampling_steps=args.steps, - guide_scale=args.guidance_scale, - n_prompt=negative_prompt, - seed=args.seed, - offload_model=False, - callback=None, # or define your own callback if you want - enable_RIFLEx=enable_riflex, - VAE_tile_size=VAE_tile_size, - joint_pass=slg_list is None, # set if you want a small speed improvement without SLG - slg_layers=slg_list, - slg_start=args.slg_start, - slg_end=args.slg_end, - ) - except Exception as e: - offloadobj.unload_all() - gc.collect() - torch.cuda.empty_cache() - - err_str = f"Generation failed with error: {e}" - # Attempt to detect OOM errors - s = str(e).lower() - if any(keyword in s for keyword in ["memory", "cuda", "alloc"]): - raise RuntimeError("Likely out-of-VRAM or out-of-RAM error. " + err_str) - else: - traceback.print_exc() - raise RuntimeError(err_str) - - # After generation - offloadobj.unload_all() - gc.collect() - torch.cuda.empty_cache() - - if sample_frames is None: - raise RuntimeError("No frames were returned (maybe generation was aborted or failed).") - - # If teacache was used, we can see how many steps were skipped - if trans.enable_cache: - print(f"TeaCache skipped steps: {trans.teacache_skipped_steps} / {args.steps}") - - # Save result - sample_frames = sample_frames.cpu() # shape = c, t, h, w => [3, T, H, W] - os.makedirs(os.path.dirname(args.output_file) or ".", exist_ok=True) - - # Use the provided helper from your code to store the MP4 - # By default, you used cache_video(tensor=..., save_file=..., fps=16, ...) - # or you can do your own. We'll do the same for consistency: - cache_video( - tensor=sample_frames[None], # shape => [1, c, T, H, W] - save_file=args.output_file, - fps=16, - nrow=1, - normalize=True, - value_range=(-1, 1) - ) - - end_time = time.time() - elapsed_s = end_time - start_time - print(f"Done! Output written to {args.output_file}. 
Generation time: {elapsed_s:.1f} seconds.") - -if __name__ == "__main__": - main() diff --git a/loras_qwen/Readme.txt b/loras_qwen/Readme.txt new file mode 100644 index 0000000..14a70a8 --- /dev/null +++ b/loras_qwen/Readme.txt @@ -0,0 +1 @@ +LTX Video loras \ No newline at end of file diff --git a/hyvideo/__init__.py b/models/__init__.py similarity index 100% rename from hyvideo/__init__.py rename to models/__init__.py diff --git a/models/flux/__init__.py b/models/flux/__init__.py new file mode 100644 index 0000000..d0a07ae --- /dev/null +++ b/models/flux/__init__.py @@ -0,0 +1,2 @@ +from .flux_main import model_factory +from . import flux_handler diff --git a/flux/__main__.py b/models/flux/__main__.py similarity index 100% rename from flux/__main__.py rename to models/flux/__main__.py diff --git a/flux/_version.py b/models/flux/_version.py similarity index 100% rename from flux/_version.py rename to models/flux/_version.py diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py new file mode 100644 index 0000000..162ec4c --- /dev/null +++ b/models/flux/flux_handler.py @@ -0,0 +1,121 @@ +import torch + +def get_ltxv_text_encoder_filename(text_encoder_quantization): + text_encoder_filename = "ckpts/T5_xxl_1.1/T5_xxl_1.1_enc_bf16.safetensors" + if text_encoder_quantization =="int8": + text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_bf16_int8") + return text_encoder_filename + +class family_handler(): + @staticmethod + def query_model_def(base_model_type, model_def): + flux_model = model_def.get("flux-model", "flux-dev") + flux_schnell = flux_model == "flux-schnell" + flux_chroma = flux_model == "flux-chroma" + flux_uso = flux_model == "flux-dev-uso" + model_def_output = { + "image_outputs" : True, + "no_negative_prompt" : not flux_chroma, + } + if flux_chroma: + model_def_output["guidance_max_phases"] = 1 + elif not flux_schnell: + model_def_output["embedded_guidance"] = True + if flux_uso : + model_def_output["any_image_refs_relative_size"] = True + model_def_output["no_background_removal"] = True + + model_def_output["image_ref_choices"] = { + "choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "I"), + ("Up to two Images are Style Images", "IJ")], + "default": "I", + "letters_filter": "IJ", + "label": "Reference Images / Style Images" + } + + return model_def_output + + @staticmethod + def query_supported_types(): + return ["flux"] + + @staticmethod + def query_family_maps(): + return {}, {} + + @staticmethod + def get_rgb_factors(base_model_type ): + from shared.RGB_factors import get_rgb_factors + latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("flux") + return latent_rgb_factors, latent_rgb_factors_bias + + + @staticmethod + def query_model_family(): + return "flux" + + @staticmethod + def query_family_infos(): + return {"flux":(30, "Flux 1")} + + @staticmethod + def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization): + text_encoder_filename = get_ltxv_text_encoder_filename(text_encoder_quantization) + return [ + { + "repoId" : "DeepBeepMeep/Flux", + "sourceFolderList" : ["siglip-so400m-patch14-384", "",], + "fileList" : [ ["config.json", "preprocessor_config.json", "model.safetensors"], ["flux_vae.safetensors"] ] + }, + { + "repoId" : "DeepBeepMeep/LTX_Video", + "sourceFolderList" : ["T5_xxl_1.1"], + "fileList" : [ ["added_tokens.json", "special_tokens_map.json", "spiece.model", "tokenizer_config.json"] + 
computeList(text_encoder_filename) ] + }, + { + "repoId" : "DeepBeepMeep/HunyuanVideo", + "sourceFolderList" : [ "clip_vit_large_patch14", ], + "fileList" :[ + ["config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"], + ] + } + ] + + @staticmethod + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + from .flux_main import model_factory + + flux_model = model_factory( + checkpoint_dir="ckpts", + model_filename=model_filename, + model_type = model_type, + model_def = model_def, + base_model_type=base_model_type, + text_encoder_filename= get_ltxv_text_encoder_filename(text_encoder_quantization), + quantizeTransformer = quantizeTransformer, + dtype = dtype, + VAE_dtype = VAE_dtype, + mixed_precision_transformer = mixed_precision_transformer, + save_quantized = save_quantized + ) + + pipe = { "transformer": flux_model.model, "vae" : flux_model.vae, "text_encoder" : flux_model.clip, "text_encoder_2" : flux_model.t5} + + if flux_model.vision_encoder is not None: + pipe["siglip_model"] = flux_model.vision_encoder + if flux_model.feature_embedder is not None: + pipe["feature_embedder"] = flux_model.feature_embedder + return flux_model, pipe + + @staticmethod + def update_default_settings(base_model_type, model_def, ui_defaults): + flux_model = model_def.get("flux-model", "flux-dev") + flux_uso = flux_model == "flux-dev-uso" + ui_defaults.update({ + "embedded_guidance": 2.5, + }) + if model_def.get("reference_image", False): + ui_defaults.update({ + "video_prompt_type": "I" if flux_uso else "KI", + }) + diff --git a/models/flux/flux_main.py b/models/flux/flux_main.py new file mode 100644 index 0000000..55a2b91 --- /dev/null +++ b/models/flux/flux_main.py @@ -0,0 +1,221 @@ +import os +import re +import time +from dataclasses import dataclass +from glob import iglob +from mmgp import offload as offload +import torch +from shared.utils.utils import calculate_new_dimensions +from .sampling import denoise, get_schedule, prepare_kontext, prepare_prompt, prepare_multi_ip, unpack +from .modules.layers import get_linear_split_map +from transformers import SiglipVisionModel, SiglipImageProcessor + +from .util import ( + aspect_ratio_to_height_width, + load_ae, + load_clip, + load_flow_model, + load_t5, + save_image, +) + +from PIL import Image + +def stitch_images(img1, img2): + # Resize img2 to match img1's height + width1, height1 = img1.size + width2, height2 = img2.size + new_width2 = int(width2 * height1 / height2) + img2_resized = img2.resize((new_width2, height1), Image.Resampling.LANCZOS) + + stitched = Image.new('RGB', (width1 + new_width2, height1)) + stitched.paste(img1, (0, 0)) + stitched.paste(img2_resized, (width1, 0)) + return stitched + +class model_factory: + def __init__( + self, + checkpoint_dir, + model_filename = None, + model_type = None, + model_def = None, + base_model_type = None, + text_encoder_filename = None, + quantizeTransformer = False, + save_quantized = False, + dtype = torch.bfloat16, + VAE_dtype = torch.float32, + mixed_precision_transformer = False + ): + self.device = torch.device(f"cuda") + self.VAE_dtype = VAE_dtype + self.dtype = dtype + torch_device = "cpu" + self.guidance_max_phases = model_def.get("guidance_max_phases", 0) + + # model_filename = 
["c:/temp/flux1-schnell.safetensors"] + + self.t5 = load_t5(torch_device, text_encoder_filename, max_length=512) + self.clip = load_clip(torch_device) + self.name = model_def.get("flux-model", "flux-dev") + # self.name= "flux-dev-kontext" + # self.name= "flux-dev" + # self.name= "flux-schnell" + source = model_def.get("source", None) + self.model = load_flow_model(self.name, model_filename[0] if source is None else source, torch_device) + + self.vae = load_ae(self.name, device=torch_device) + + siglip_processor = siglip_model = feature_embedder = None + if self.name == 'flux-dev-uso': + siglip_path = "ckpts/siglip-so400m-patch14-384" + siglip_processor = SiglipImageProcessor.from_pretrained(siglip_path) + siglip_model = SiglipVisionModel.from_pretrained(siglip_path) + siglip_model.eval().to("cpu") + if len(model_filename) > 1: + from .modules.layers import SigLIPMultiFeatProjModel + feature_embedder = SigLIPMultiFeatProjModel( + siglip_token_nums=729, + style_token_nums=64, + siglip_token_dims=1152, + hidden_size=3072, #self.hidden_size, + context_layer_norm=True, + ) + offload.load_model_data(feature_embedder, model_filename[1]) + self.vision_encoder = siglip_model + self.vision_encoder_processor = siglip_processor + self.feature_embedder = feature_embedder + + # offload.change_dtype(self.model, dtype, True) + # offload.save_model(self.model, "flux-dev.safetensors") + + if not source is None: + from wgp import save_model + save_model(self.model, model_type, dtype, None) + + if save_quantized: + from wgp import save_quantized_model + save_quantized_model(self.model, model_type, model_filename[0], dtype, None) + + split_linear_modules_map = get_linear_split_map() + self.model.split_linear_modules_map = split_linear_modules_map + offload.split_linear_modules(self.model, split_linear_modules_map ) + + + def generate( + self, + seed: int | None = None, + input_prompt: str = "replace the logo with the text 'Black Forest Labs'", + n_prompt: str = None, + sampling_steps: int = 20, + input_ref_images = None, + width= 832, + height=480, + embedded_guidance_scale: float = 2.5, + guide_scale = 2.5, + fit_into_canvas = None, + callback = None, + loras_slists = None, + batch_size = 1, + video_prompt_type = "", + joint_pass = False, + image_refs_relative_size = 100, + **bbargs + ): + if self._interrupt: + return None + if self.guidance_max_phases < 1: guide_scale = 1 + if n_prompt is None or len(n_prompt) == 0: n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors" + device="cuda" + flux_dev_uso = self.name in ['flux-dev-uso'] + image_stiching = not self.name in ['flux-dev-uso'] + + input_ref_images = [] if input_ref_images is None else input_ref_images[:] + ref_style_imgs = [] + + if "I" in video_prompt_type and len(input_ref_images) > 0: + if flux_dev_uso : + if "J" in video_prompt_type: + ref_style_imgs = input_ref_images + input_ref_images = [] + elif len(input_ref_images) > 1 : + ref_style_imgs = input_ref_images[-1:] + input_ref_images = input_ref_images[:-1] + if image_stiching: + # image stiching method + stiched = input_ref_images[0] + if "K" in video_prompt_type : + w, h = input_ref_images[0].size + height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) + + for new_img in input_ref_images[1:]: + stiched = stitch_images(stiched, new_img) + input_ref_images = [stiched] + else: + first_ref = 0 + if "K" in video_prompt_type: + # image latents tiling method + w, h = input_ref_images[0].size + height, 
width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) + input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS) + first_ref = 1 + + for i in range(first_ref,len(input_ref_images)): + w, h = input_ref_images[i].size + image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas) + input_ref_images[0] = input_ref_images[0].resize((image_width, image_height), resample=Image.Resampling.LANCZOS) + else: + input_ref_images = None + + if flux_dev_uso : + inp, height, width = prepare_multi_ip( + ae=self.vae, + img_cond_list=input_ref_images, + target_width=width, + target_height=height, + bs=batch_size, + seed=seed, + device=device, + ) + else: + inp, height, width = prepare_kontext( + ae=self.vae, + img_cond_list=input_ref_images, + target_width=width, + target_height=height, + bs=batch_size, + seed=seed, + device=device, + ) + + inp.update(prepare_prompt(self.t5, self.clip, batch_size, input_prompt)) + if guide_scale != 1: + inp.update(prepare_prompt(self.t5, self.clip, batch_size, n_prompt, neg = True, device=device)) + + timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell")) + + ref_style_imgs = [self.vision_encoder_processor(img, return_tensors="pt").to(self.device) for img in ref_style_imgs] + if self.feature_embedder is not None and ref_style_imgs is not None and len(ref_style_imgs) > 0 and self.vision_encoder is not None: + # processing style feat into textural hidden space + siglip_embedding = [self.vision_encoder(**emb, output_hidden_states=True) for emb in ref_style_imgs] + siglip_embedding = torch.cat([self.feature_embedder(emb) for emb in siglip_embedding], dim=1) + siglip_embedding_ids = torch.zeros( siglip_embedding.shape[0], siglip_embedding.shape[1], 3 ).to(device) + inp["siglip_embedding"] = siglip_embedding + inp["siglip_embedding_ids"] = siglip_embedding_ids + + def unpack_latent(x): + return unpack(x.float(), height, width) + + # denoise initial noise + x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass) + if x==None: return None + # decode latents to pixel space + x = unpack_latent(x) + with torch.autocast(device_type=device, dtype=torch.bfloat16): + x = self.vae.decode(x) + + x = x.clamp(-1, 1) + x = x.transpose(0, 1) + return x + diff --git a/flux/math.py b/models/flux/math.py similarity index 97% rename from flux/math.py rename to models/flux/math.py index 9e8aa59..a249f19 100644 --- a/flux/math.py +++ b/models/flux/math.py @@ -1,7 +1,7 @@ import torch from einops import rearrange from torch import Tensor -from wan.modules.attention import pay_attention +from shared.attention import pay_attention def attention(qkv_list, pe: Tensor) -> Tensor: diff --git a/models/flux/model.py b/models/flux/model.py new file mode 100644 index 0000000..c4642d0 --- /dev/null +++ b/models/flux/model.py @@ -0,0 +1,296 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn + +from .modules.layers import ( + DoubleStreamBlock, + EmbedND, + LastLayer, + MLPEmbedder, + SingleStreamBlock, + timestep_embedding, + DistilledGuidance, + ChromaModulationOut, + SigLIPMultiFeatProjModel, +) +from .modules.lora import LinearLora, replace_linear_with_lora + + +@dataclass +class FluxParams: + 
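"""Flux transformer hyperparameters: channel/embedding sizes, double and single block depths, RoPE axes/theta, and variant flags (qkv_bias, guidance_embed, chroma, eso).""" +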
in_channels: int + out_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + chroma: bool = False + eso: bool = False + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0): + # This function slices up the modulations tensor which has the following layout: + # single : num_single_blocks * 3 elements + # double_img : num_double_blocks * 6 elements + # double_txt : num_double_blocks * 6 elements + # final : 2 elements + if block_type == "final": + return (tensor[:, -2:-1, :], tensor[:, -1:, :]) + single_block_count = self.params.depth_single_blocks + double_block_count = self.params.depth + offset = 3 * idx + if block_type == "single": + return ChromaModulationOut.from_offset(tensor, offset) + # Double block modulations are 6 elements so we double 3 * idx. + offset *= 2 + if block_type in {"double_img", "double_txt"}: + # Advance past the single block modulations. + offset += 3 * single_block_count + if block_type == "double_txt": + # Advance past the double block img modulations. + offset += 6 * double_block_count + return ( + ChromaModulationOut.from_offset(tensor, offset), + ChromaModulationOut.from_offset(tensor, offset + 3), + ) + raise ValueError("Bad block_type") + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = params.out_channels + self.chroma = params.chroma + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + if self.chroma: + self.distilled_guidance_layer = DistilledGuidance( + in_dim=64, + hidden_dim=5120, + out_dim=3072, + n_layers=5, + ) + else: + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + chroma_modulation = self.chroma, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, chroma_modulation = self.chroma) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, chroma_modulation = self.chroma) + + def preprocess_loras(self, model_type, sd): + new_sd = {} + if len(sd) == 0: return sd + + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + first_key= next(iter(sd)) + if 
first_key.startswith("lora_unet_"): + new_sd = {} + print("Converting Lora Safetensors format to Lora Diffusers format") + repl_list = ["linear1", "linear2", "modulation", "img_attn", "txt_attn", "img_mlp", "txt_mlp", "img_mod", "txt_mod"] + src_list = ["_" + k + "." for k in repl_list] + src_list2 = ["_" + k + "_" for k in repl_list] + tgt_list = ["." + k + "." for k in repl_list] + + for k,v in sd.items(): + k = k.replace("lora_unet_blocks_","diffusion_model.blocks.") + k = k.replace("lora_unet__blocks_","diffusion_model.blocks.") + k = k.replace("lora_unet_single_blocks_","diffusion_model.single_blocks.") + k = k.replace("lora_unet_double_blocks_","diffusion_model.double_blocks.") + + for s,s2, t in zip(src_list, src_list2, tgt_list): + k = k.replace(s,t) + k = k.replace(s2,t) + + k = k.replace("lora_up","lora_B") + k = k.replace("lora_down","lora_A") + + new_sd[k] = v + + elif first_key.startswith("transformer."): + root_src = ["time_text_embed.timestep_embedder.linear_1", "time_text_embed.timestep_embedder.linear_2", "time_text_embed.text_embedder.linear_1", "time_text_embed.text_embedder.linear_2", + "time_text_embed.guidance_embedder.linear_1", "time_text_embed.guidance_embedder.linear_2", + "x_embedder", "context_embedder", "proj_out" ] + + root_tgt = ["time_in.in_layer", "time_in.out_layer", "vector_in.in_layer", "vector_in.out_layer", + "guidance_in.in_layer", "guidance_in.out_layer", + "img_in", "txt_in", "final_layer.linear" ] + + double_src = ["norm1.linear", "norm1_context.linear", "attn.norm_q", "attn.norm_k", "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2", "attn.to_out.0" ,"attn.to_add_out", "attn.to_out", ".attn.to_", ".attn.add_q_proj.", ".attn.add_k_proj.", ".attn.add_v_proj.", ] + double_tgt = ["img_mod.lin", "txt_mod.lin", "img_attn.norm.query_norm", "img_attn.norm.key_norm", "img_mlp.0", "img_mlp.2", "txt_mlp.0", "txt_mlp.2", "img_attn.proj", "txt_attn.proj", "img_attn.proj", ".img_attn.", ".txt_attn.q.", ".txt_attn.k.", ".txt_attn.v."] + + single_src = ["norm.linear", "attn.norm_q", "attn.norm_k", "proj_out",".attn.to_q.", ".attn.to_k.", ".attn.to_v.", ".proj_mlp."] + single_tgt = ["modulation.lin","norm.query_norm", "norm.key_norm", "linear2", ".linear1_attn_q.", ".linear1_attn_k.", ".linear1_attn_v.", ".linear1_mlp."] + + + for k,v in sd.items(): + if k.startswith("transformer.single_transformer_blocks"): + k = k.replace("transformer.single_transformer_blocks", "diffusion_model.single_blocks") + for src, tgt in zip(single_src, single_tgt): + k = k.replace(src, tgt) + elif k.startswith("transformer.transformer_blocks"): + k = k.replace("transformer.transformer_blocks", "diffusion_model.double_blocks") + for src, tgt in zip(double_src, double_tgt): + k = k.replace(src, tgt) + else: + k = k.replace("transformer.", "diffusion_model.") + for src, tgt in zip(root_src, root_tgt): + k = k.replace(src, tgt) + + if "norm_out.linear" in k: + if "lora_B" in k: + v = swap_scale_shift(v) + k = k.replace("norm_out.linear", "final_layer.adaLN_modulation.1") + new_sd[k] = v + else: + new_sd = sd + return new_sd + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt_list, + txt_ids_list, + timesteps: Tensor, + y_list, + img_len = 0, + guidance: Tensor | None = None, + callback= None, + pipeline =None, + siglip_embedding = None, + siglip_embedding_ids = None, + ) -> Tensor: + + sz = len(txt_list) + # running on sequences img + img = self.img_in(img) + img_list = [img] if sz==1 else [img, img.clone()] + + if self.chroma: + mod_index_length = 344 + 
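# 344 modulation slots = 3*depth_single_blocks + 6*depth (img) + 6*depth (txt) + 2 final, i.e. 3*38 + 6*19 + 6*19 + 2 with the stock Flux depths (19 double / 38 single blocks); see the layout comment in get_modulations +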
distill_timestep = timestep_embedding(timesteps, 16).to(img.device, img.dtype) + guidance = torch.tensor([0.]* distill_timestep.shape[0]) + distil_guidance = timestep_embedding(guidance, 16).to(img.device, img.dtype) + modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype) + modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype) + timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.dtype).to(img.device, img.dtype) + input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype) + mod_vectors = self.distilled_guidance_layer(input_vec) + else: + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec += self.guidance_in(timestep_embedding(guidance, 256)) + vec_list = [ vec + self.vector_in(y) for y in y_list] + + img = None + txt_list = [self.txt_in(txt) for txt in txt_list ] + if siglip_embedding is not None: + txt_list = [torch.cat((siglip_embedding, txt) , dim=1) for txt in txt_list] + txt_ids_list = [torch.cat((siglip_embedding_ids, txt_id) , dim=1) for txt_id in txt_ids_list] + + pe_list = [self.pe_embedder(torch.cat((txt_ids, img_ids), dim=1)) for txt_ids in txt_ids_list] + + for i, block in enumerate(self.double_blocks): + if self.chroma: vec_list = [( self.get_modulations(mod_vectors, "double_img", idx=i), self.get_modulations(mod_vectors, "double_txt", idx=i))] * sz + if callback != None: + callback(-1, None, False, True) + if pipeline._interrupt: + return [None] * sz + for img, txt, pe, vec in zip(img_list, txt_list, pe_list, vec_list): + img[...], txt[...] = block(img=img, txt=txt, vec=vec, pe=pe) + img = txt = pe = vec= None + + img_list = [torch.cat((txt, img), 1) for txt, img in zip(txt_list, img_list)] + + for i, block in enumerate(self.single_blocks): + if self.chroma: vec_list= [self.get_modulations(mod_vectors, "single", idx=i)] * sz + if callback != None: + callback(-1, None, False, True) + if pipeline._interrupt: + return [None] * sz + for img, pe, vec in zip(img_list, pe_list, vec_list): + img[...]= block(x=img, vec=vec, pe=pe) + img = pe = vec = None + img_list = [ img[:, txt.shape[1] : txt.shape[1] + img_len, ...] 
for img, txt in zip(img_list, txt_list)] + + if self.chroma: vec_list = [self.get_modulations(mod_vectors, "final")] * sz + out_list = [] + for i, (img, vec) in enumerate(zip(img_list, vec_list)): + out_list.append( self.final_layer(img, vec)) # (N, T, patch_size ** 2 * out_channels) + img_list[i] = img = vec = None + return out_list + + +class FluxLoraWrapper(Flux): + def __init__( + self, + lora_rank: int = 128, + lora_scale: float = 1.0, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + + self.lora_rank = lora_rank + + replace_linear_with_lora( + self, + max_rank=lora_rank, + scale=lora_scale, + ) + + def set_lora_scale(self, scale: float) -> None: + for module in self.modules(): + if isinstance(module, LinearLora): + module.set_scale(scale=scale) diff --git a/flux/modules/autoencoder.py b/models/flux/modules/autoencoder.py similarity index 100% rename from flux/modules/autoencoder.py rename to models/flux/modules/autoencoder.py diff --git a/flux/modules/conditioner.py b/models/flux/modules/conditioner.py similarity index 100% rename from flux/modules/conditioner.py rename to models/flux/modules/conditioner.py diff --git a/flux/modules/image_embedders.py b/models/flux/modules/image_embedders.py similarity index 98% rename from flux/modules/image_embedders.py rename to models/flux/modules/image_embedders.py index aa26d9b..011f840 100644 --- a/flux/modules/image_embedders.py +++ b/models/flux/modules/image_embedders.py @@ -7,7 +7,7 @@ from safetensors.torch import load_file as load_sft from torch import nn from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel -from flux.util import print_load_warning +from ..util import print_load_warning class DepthImageEncoder: diff --git a/flux/modules/layers copy.py b/models/flux/modules/layers copy.py similarity index 100% rename from flux/modules/layers copy.py rename to models/flux/modules/layers copy.py diff --git a/flux/modules/layers.py b/models/flux/modules/layers.py similarity index 60% rename from flux/modules/layers.py rename to models/flux/modules/layers.py index 0fbe404..8cd981d 100644 --- a/flux/modules/layers.py +++ b/models/flux/modules/layers.py @@ -5,13 +5,14 @@ import torch from einops import rearrange from torch import Tensor, nn -from flux.math import attention, rope +from ..math import attention, rope def get_linear_split_map(): hidden_size = 3072 split_linear_modules_map = { "qkv" : {"mapped_modules" : ["q", "k", "v"] , "split_sizes": [hidden_size, hidden_size, hidden_size]}, - "linear1" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"] , "split_sizes": [hidden_size, hidden_size, hidden_size, 7*hidden_size- 3*hidden_size]} + "linear1" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"] , "split_sizes": [hidden_size, hidden_size, hidden_size, 7*hidden_size- 3*hidden_size]}, + "linear1_qkv" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v"] , "split_sizes": [hidden_size, hidden_size, hidden_size]}, } return split_linear_modules_map @@ -116,10 +117,20 @@ class ModulationOut: scale: Tensor gate: Tensor +class ChromaModulationOut(ModulationOut): + @classmethod + def from_offset(cls, tensor: torch.Tensor, offset: int = 0): + return cls( + shift=tensor[:, offset : offset + 1, :], + scale=tensor[:, offset + 1 : offset + 2, :], + gate=tensor[:, offset + 2 : offset + 3, :], + ) -def split_mlp(mlp, x, divide = 4): + +def split_mlp(mlp, x, divide = 8): x_shape 
= x.shape x = x.view(-1, x.shape[-1]) + chunk_size = int(x.shape[0]/divide) chunk_size = int(x_shape[1]/divide) x_chunks = torch.split(x, chunk_size) for i, x_chunk in enumerate(x_chunks): @@ -145,13 +156,15 @@ class Modulation(nn.Module): class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, chroma_modulation = False): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) self.num_heads = num_heads self.hidden_size = hidden_size - self.img_mod = Modulation(hidden_size, double=True) + self.chroma_modulation = chroma_modulation + if not chroma_modulation: + self.img_mod = Modulation(hidden_size, double=True) self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) @@ -162,7 +175,8 @@ class DoubleStreamBlock(nn.Module): nn.Linear(mlp_hidden_dim, hidden_size, bias=True), ) - self.txt_mod = Modulation(hidden_size, double=True) + if not chroma_modulation: + self.txt_mod = Modulation(hidden_size, double=True) self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) @@ -174,8 +188,11 @@ class DoubleStreamBlock(nn.Module): ) def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: - img_mod1, img_mod2 = self.img_mod(vec) - txt_mod1, txt_mod2 = self.txt_mod(vec) + if self.chroma_modulation: + (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec + else: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) # prepare image for attention img_modulated = self.img_norm1(img) @@ -249,10 +266,12 @@ class SingleStreamBlock(nn.Module): num_heads: int, mlp_ratio: float = 4.0, qk_scale: float | None = None, + chroma_modulation = False, ): super().__init__() self.hidden_dim = hidden_size self.num_heads = num_heads + self.chroma_modulation = chroma_modulation head_dim = hidden_size // num_heads self.scale = qk_scale or head_dim**-0.5 @@ -268,10 +287,14 @@ class SingleStreamBlock(nn.Module): self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.mlp_act = nn.GELU(approximate="tanh") - self.modulation = Modulation(hidden_size, double=False) + if not chroma_modulation: + self.modulation = Modulation(hidden_size, double=False) def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: - mod, _ = self.modulation(vec) + if self.chroma_modulation: + mod = vec + else: + mod, _ = self.modulation(vec) x_mod = self.pre_norm(x) x_mod.mul_(1 + mod.scale) x_mod.add_(mod.shift) @@ -315,14 +338,172 @@ class SingleStreamBlock(nn.Module): class LastLayer(nn.Module): - def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int, chroma_modulation = False): super().__init__() + self.chroma_modulation = chroma_modulation self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) - self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + if not chroma_modulation: + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) def forward(self, x: Tensor, vec: 
Tensor) -> Tensor: - shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) - x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + if self.chroma_modulation: + shift, scale = vec + shift = shift.squeeze(1) + scale = scale.squeeze(1) + else: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + # x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x)) x = self.linear(x) return x + + +class DistilledGuidance(nn.Module): + def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 5): + super().__init__() + self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True) + self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim) for x in range( n_layers)]) + self.norms = nn.ModuleList([RMSNorm(hidden_dim) for x in range( n_layers)]) + self.out_proj = nn.Linear(hidden_dim, out_dim) + + + def forward(self, x: Tensor) -> Tensor: + x = self.in_proj(x) + + for layer, norms in zip(self.layers, self.norms): + x = x + layer(norms(x)) + + x = self.out_proj(x) + + return x + + +class SigLIPMultiFeatProjModel(torch.nn.Module): + """ + SigLIP Multi-Feature Projection Model for processing style features from different layers + and projecting them into a unified hidden space. + + Args: + siglip_token_nums (int): Number of SigLIP tokens, default 257 + style_token_nums (int): Number of style tokens, default 256 + siglip_token_dims (int): Dimension of SigLIP tokens, default 1536 + hidden_size (int): Hidden layer size, default 3072 + context_layer_norm (bool): Whether to use context layer normalization, default False + """ + + def __init__( + self, + siglip_token_nums: int = 257, + style_token_nums: int = 256, + siglip_token_dims: int = 1536, + hidden_size: int = 3072, + context_layer_norm: bool = False, + ): + super().__init__() + + # High-level feature processing (layer -2) + self.high_embedding_linear = nn.Sequential( + nn.Linear(siglip_token_nums, style_token_nums), + nn.SiLU() + ) + self.high_layer_norm = ( + nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() + ) + self.high_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True) + + # Mid-level feature processing (layer -11) + self.mid_embedding_linear = nn.Sequential( + nn.Linear(siglip_token_nums, style_token_nums), + nn.SiLU() + ) + self.mid_layer_norm = ( + nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() + ) + self.mid_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True) + + # Low-level feature processing (layer -20) + self.low_embedding_linear = nn.Sequential( + nn.Linear(siglip_token_nums, style_token_nums), + nn.SiLU() + ) + self.low_layer_norm = ( + nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() + ) + self.low_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True) + + def forward(self, siglip_outputs): + """ + Forward pass function + + Args: + siglip_outputs: Output from SigLIP model, containing hidden_states + + Returns: + torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size] + """ + dtype = next(self.high_embedding_linear.parameters()).dtype + + # Process high-level features (layer -2) + high_embedding = self._process_layer_features( + siglip_outputs.hidden_states[-2], + self.high_embedding_linear, + self.high_layer_norm, + self.high_projection, + dtype + ) + + # Process mid-level features (layer -11) + mid_embedding = self._process_layer_features( + 
siglip_outputs.hidden_states[-11], + self.mid_embedding_linear, + self.mid_layer_norm, + self.mid_projection, + dtype + ) + + # Process low-level features (layer -20) + low_embedding = self._process_layer_features( + siglip_outputs.hidden_states[-20], + self.low_embedding_linear, + self.low_layer_norm, + self.low_projection, + dtype + ) + + # Concatenate features from all layers + return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1) + + def _process_layer_features( + self, + hidden_states: torch.Tensor, + embedding_linear: nn.Module, + layer_norm: nn.Module, + projection: nn.Module, + dtype: torch.dtype + ) -> torch.Tensor: + """ + Helper function to process features from a single layer + + Args: + hidden_states: Input hidden states [bs, seq_len, dim] + embedding_linear: Embedding linear layer + layer_norm: Layer normalization + projection: Projection layer + dtype: Target data type + + Returns: + torch.Tensor: Processed features [bs, style_token_nums, hidden_size] + """ + # Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim] + embedding = embedding_linear( + hidden_states.to(dtype).transpose(1, 2) + ).transpose(1, 2) + + # Apply layer normalization + embedding = layer_norm(embedding) + + # Project to target hidden space + embedding = projection(embedding) + + return embedding diff --git a/flux/modules/lora.py b/models/flux/modules/lora.py similarity index 100% rename from flux/modules/lora.py rename to models/flux/modules/lora.py diff --git a/models/flux/sampling.py b/models/flux/sampling.py new file mode 100644 index 0000000..5534e9f --- /dev/null +++ b/models/flux/sampling.py @@ -0,0 +1,429 @@ +import math +from typing import Callable + +import numpy as np +import torch +from einops import rearrange, repeat +from PIL import Image +from torch import Tensor + +from .model import Flux +from .modules.autoencoder import AutoEncoder +from .modules.conditioner import HFEmbedder +from .modules.image_embedders import CannyImageEncoder, DepthImageEncoder, ReduxImageEncoder +from .util import PREFERED_KONTEXT_RESOLUTIONS +from einops import rearrange, repeat +from typing import Literal +import torchvision.transforms.functional as TVF + + +def get_noise( + num_samples: int, + height: int, + width: int, + device: torch.device, + dtype: torch.dtype, + seed: int, +): + return torch.randn( + num_samples, + 16, + # allow for packing + 2 * math.ceil(height / 16), + 2 * math.ceil(width / 16), + dtype=dtype, + device=device, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + +def prepare_prompt(t5: HFEmbedder, clip: HFEmbedder, bs: int, prompt: str | list[str], neg: bool = False, device: str = "cuda") -> dict[str, Tensor]: + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + if isinstance(prompt, str): + prompt = [prompt] + txt = t5(prompt) + if txt.shape[0] == 1 and bs > 1: + txt = repeat(txt, "1 ... -> bs ...", bs=bs) + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + vec = clip(prompt) + if vec.shape[0] == 1 and bs > 1: + vec = repeat(vec, "1 ... -> bs ...", bs=bs) + + return { + "neg_txt" if neg else "txt": txt.to(device), + "neg_txt_ids" if neg else "txt_ids": txt_ids.to(device), + "neg_vec" if neg else "vec": vec.to(device), + } + + +def prepare_img( img: Tensor) -> dict[str, Tensor]: + bs, c, h, w = img.shape + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... 
-> bs ...", bs=bs) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + return { + "img": img, + "img_ids": img_ids.to(img.device), + } + + + + + +def prepare_redux( + t5: HFEmbedder, + clip: HFEmbedder, + img: Tensor, + prompt: str | list[str], + encoder: ReduxImageEncoder, + img_cond_path: str, +) -> dict[str, Tensor]: + bs, _, h, w = img.shape + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + img_cond = Image.open(img_cond_path).convert("RGB") + with torch.no_grad(): + img_cond = encoder(img_cond) + + img_cond = img_cond.to(torch.bfloat16) + if img_cond.shape[0] == 1 and bs > 1: + img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs) + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... -> bs ...", bs=bs) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + if isinstance(prompt, str): + prompt = [prompt] + txt = t5(prompt) + txt = torch.cat((txt, img_cond.to(txt)), dim=-2) + if txt.shape[0] == 1 and bs > 1: + txt = repeat(txt, "1 ... -> bs ...", bs=bs) + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + vec = clip(prompt) + if vec.shape[0] == 1 and bs > 1: + vec = repeat(vec, "1 ... -> bs ...", bs=bs) + + return { + "img": img, + "img_ids": img_ids.to(img.device), + "txt": txt.to(img.device), + "txt_ids": txt_ids.to(img.device), + "vec": vec.to(img.device), + } + + +def prepare_kontext( + ae: AutoEncoder, + img_cond_list: list, + seed: int, + device: torch.device, + target_width: int | None = None, + target_height: int | None = None, + bs: int = 1, + +) -> tuple[dict[str, Tensor], int, int]: + # load and encode the conditioning image + + img_cond_seq = None + img_cond_seq_ids = None + if img_cond_list == None: img_cond_list = [] + height_offset = 0 + width_offset = 0 + for cond_no, img_cond in enumerate(img_cond_list): + width, height = img_cond.size + aspect_ratio = width / height + + # Kontext is trained on specific resolutions, using one of them is recommended + _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) + + width = 2 * int(width / 16) + height = 2 * int(height / 16) + + img_cond = img_cond.resize((8 * width, 8 * height), Image.Resampling.LANCZOS) + img_cond = np.array(img_cond) + img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0 + img_cond = rearrange(img_cond, "h w c -> 1 c h w") + with torch.no_grad(): + img_cond_latents = ae.encode(img_cond.to(device)) + + img_cond_latents = img_cond_latents.to(torch.bfloat16) + img_cond_latents = rearrange(img_cond_latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img_cond.shape[0] == 1 and bs > 1: + img_cond_latents = repeat(img_cond_latents, "1 ... 
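# Illustrative sketch of the 2x2 latent packing and position ids built by
# prepare_img above: a 16-channel latent becomes (h/2 * w/2) tokens of 64
# channels, and img_ids stores the (row, col) grid coordinates in channels 1 and 2.
import torch
from einops import rearrange, repeat

bs, c, h, w = 1, 16, 64, 64                       # e.g. a 512x512 image -> 64x64 latent
img = torch.randn(bs, c, h, w)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(img.shape)                                  # [1, 1024, 64]

img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] += torch.arange(h // 2)[:, None]  # row index
img_ids[..., 2] += torch.arange(w // 2)[None, :]  # column index
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
print(img_ids.shape)                              # [1, 1024, 3]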
-> bs ...", bs=bs) + img_cond = None + + # image ids are the same as base image with the first dimension set to 1 + # instead of 0 + img_cond_ids = torch.zeros(height // 2, width // 2, 3) + img_cond_ids[..., 0] = 1 + img_cond_ids[..., 1] = img_cond_ids[..., 1] + torch.arange(height // 2)[:, None] + height_offset + img_cond_ids[..., 2] = img_cond_ids[..., 2] + torch.arange(width // 2)[None, :] + width_offset + img_cond_ids = repeat(img_cond_ids, "h w c -> b (h w) c", b=bs) + height_offset += height // 2 + width_offset += width // 2 + + if target_width is None: + target_width = 8 * width + if target_height is None: + target_height = 8 * height + img_cond_ids = img_cond_ids.to(device) + if cond_no == 0: + img_cond_seq, img_cond_seq_ids = img_cond_latents, img_cond_ids + else: + img_cond_seq, img_cond_seq_ids = torch.cat([img_cond_seq, img_cond_latents], dim=1), torch.cat([img_cond_seq_ids, img_cond_ids], dim=1) + + return_dict = { + "img_cond_seq": img_cond_seq, + "img_cond_seq_ids": img_cond_seq_ids, + } + img = get_noise( + bs, + target_height, + target_width, + device=device, + dtype=torch.bfloat16, + seed=seed, + ) + return_dict.update(prepare_img(img)) + + return return_dict, target_height, target_width + + +def time_shift(mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function( + x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # estimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: Flux, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + # sampling parameters + timesteps: list[float], + guidance: float = 4.0, + real_guidance_scale = None, + # extra img tokens (channel-wise) + neg_txt: Tensor = None, + neg_txt_ids: Tensor= None, + neg_vec: Tensor = None, + img_cond: Tensor | None = None, + # extra img tokens (sequence-wise) + img_cond_seq: Tensor | None = None, + img_cond_seq_ids: Tensor | None = None, + siglip_embedding = None, + siglip_embedding_ids = None, + callback=None, + pipeline=None, + loras_slists=None, + unpack_latent = None, + joint_pass= False, +): + + kwargs = {'pipeline': pipeline, 'callback': callback, "img_len" : img.shape[1], "siglip_embedding": siglip_embedding, "siglip_embedding_ids": siglip_embedding_ids} + + if callback != None: + callback(-1, None, True) + + updated_num_steps= len(timesteps) -1 + if callback != None: + from shared.utils.loras_mutipliers import update_loras_slists + update_loras_slists(model, loras_slists, updated_num_steps) + callback(-1, None, True, override_num_inference_steps = updated_num_steps) + from mmgp import offload + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): + offload.set_step_no_for_lora(model, i) + if pipeline._interrupt: + return None + + t_vec = 
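# Illustrative sketch of how get_schedule above shifts the timestep grid: mu is
# interpolated linearly from the packed token count, then each linear timestep is
# remapped so the schedule favors high timesteps for larger images. The two helpers
# are re-stated here only so the snippet runs standalone; the endpoints are dropped
# because this float version of time_shift would divide by zero at t = 0.
import math
import torch

def time_shift(mu, sigma, t):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

def get_lin_function(x1=256.0, y1=0.5, x2=4096.0, y2=1.15):
    m = (y2 - y1) / (x2 - x1)
    return lambda x: m * (x - x1) + y1

image_seq_len = 1024                         # 64x64 latent packed 2x2
mu = get_lin_function()(image_seq_len)
linear = torch.linspace(1, 0, 11)[1:-1]
shifted = [time_shift(mu, 1.0, float(t)) for t in linear]
print(mu, shifted)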
torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + img_input = img + img_input_ids = img_ids + if img_cond is not None: + img_input = torch.cat((img, img_cond), dim=-1) + if img_cond_seq is not None: + img_input = torch.cat((img_input, img_cond_seq), dim=1) + img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1) + if not joint_pass or real_guidance_scale == 1: + pred = model( + img=img_input, + img_ids=img_input_ids, + txt_list=[txt], + txt_ids_list=[txt_ids], + y_list=[vec], + timesteps=t_vec, + guidance=guidance_vec, + **kwargs + )[0] + if pred == None: return None + if real_guidance_scale> 1: + neg_pred = model( + img=img_input, + img_ids=img_input_ids, + txt_list=[neg_txt], + txt_ids_list=[neg_txt_ids], + y_list=[neg_vec], + timesteps=t_vec, + guidance=guidance_vec, + **kwargs + )[0] + if neg_pred == None: return None + else: + pred, neg_pred = model( + img=img_input, + img_ids=img_input_ids, + txt_list=[txt, neg_txt], + txt_ids_list=[txt_ids, neg_txt_ids], + y_list=[vec, neg_vec], + timesteps=t_vec, + guidance=guidance_vec, + **kwargs + ) + if pred == None: return None + + if real_guidance_scale > 1: + pred = neg_pred + real_guidance_scale * (pred - neg_pred) + + img += (t_prev - t_curr) * pred + if callback is not None: + preview = unpack_latent(img).transpose(0,1) + callback(i, preview, False) + + + return img + +def prepare_multi_ip( + ae: AutoEncoder, + img_cond_list: list, + seed: int, + device: torch.device, + target_width: int | None = None, + target_height: int | None = None, + bs: int = 1, + pe: Literal["d", "h", "w", "o"] = "d", +) -> dict[str, Tensor]: + ref_imgs = img_cond_list + assert pe in ["d", "h", "w", "o"] + + ref_imgs = [ + ae.encode( + (TVF.to_tensor(ref_img) * 2.0 - 1.0) + .unsqueeze(0) + .to(device, torch.float32) + ).to(torch.bfloat16) + for ref_img in img_cond_list + ] + + img = get_noise( bs, target_height, target_width, device=device, dtype=torch.bfloat16, seed=seed) + bs, c, h, w = img.shape + # tgt img + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... -> bs ...", bs=bs) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + img_cond_seq = img_cond_seq_ids = None + pe_shift_w, pe_shift_h = w // 2, h // 2 + for cond_no, ref_img in enumerate(ref_imgs): + _, _, ref_h1, ref_w1 = ref_img.shape + ref_img = rearrange( + ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2 + ) + if ref_img.shape[0] == 1 and bs > 1: + ref_img = repeat(ref_img, "1 ... 
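# Illustrative sketch of the guidance and Euler update performed inside denoise
# above: the conditional and unconditional predictions are mixed with
# real_guidance_scale, then the packed latent moves by (t_prev - t_curr) * pred
# (a negative step, since the schedule runs from 1 toward 0).
import torch

def euler_cfg_step(img, pred, neg_pred, t_curr, t_prev, real_guidance_scale=None):
    if real_guidance_scale is not None and real_guidance_scale > 1:
        pred = neg_pred + real_guidance_scale * (pred - neg_pred)
    return img + (t_prev - t_curr) * pred

img = torch.randn(1, 1024, 64)
pred, neg_pred = torch.randn_like(img), torch.randn_like(img)
img = euler_cfg_step(img, pred, neg_pred, t_curr=1.0, t_prev=0.9, real_guidance_scale=3.5)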
-> bs ...", bs=bs) + ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3) + # img id分别在宽高偏移各自最大值 + h_offset = pe_shift_h if pe in {"d", "h"} else 0 + w_offset = pe_shift_w if pe in {"d", "w"} else 0 + ref_img_ids1[..., 1] = ( + ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset + ) + ref_img_ids1[..., 2] = ( + ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset + ) + ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs) + + if target_width is None: + target_width = 8 * ref_w1 + if target_height is None: + target_height = 8 * ref_h1 + ref_img_ids1 = ref_img_ids1.to(device) + if cond_no == 0: + img_cond_seq, img_cond_seq_ids = ref_img, ref_img_ids1 + else: + img_cond_seq, img_cond_seq_ids = torch.cat([img_cond_seq, ref_img], dim=1), torch.cat([img_cond_seq_ids, ref_img_ids1], dim=1) + + + # 更新pe shift + pe_shift_h += ref_h1 // 2 + pe_shift_w += ref_w1 // 2 + + return { + "img": img, + "img_ids": img_ids.to(img.device), + "img_cond_seq": img_cond_seq, + "img_cond_seq_ids": img_cond_seq_ids, + }, target_height, target_width + + +def unpack(x: Tensor, height: int, width: int) -> Tensor: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) diff --git a/flux/util.py b/models/flux/util.py similarity index 92% rename from flux/util.py rename to models/flux/util.py index 9b477a0..0f96103 100644 --- a/flux/util.py +++ b/models/flux/util.py @@ -11,16 +11,13 @@ from huggingface_hub import hf_hub_download, login from PIL import ExifTags, Image from safetensors.torch import load_file as load_sft -from flux.model import Flux, FluxLoraWrapper, FluxParams -from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams -from flux.modules.conditioner import HFEmbedder +from .model import Flux, FluxLoraWrapper, FluxParams +from .modules.autoencoder import AutoEncoder, AutoEncoderParams +from .modules.conditioner import HFEmbedder CHECKPOINTS_DIR = Path("checkpoints") -CHECKPOINTS_DIR.mkdir(exist_ok=True) -BFL_API_KEY = os.getenv("BFL_API_KEY") -os.environ.setdefault("TRT_ENGINE_DIR", str(CHECKPOINTS_DIR / "trt_engines")) -(CHECKPOINTS_DIR / "trt_engines").mkdir(exist_ok=True) +BFL_API_KEY = os.getenv("BFL_API_KEY") def ensure_hf_auth(): @@ -358,6 +355,38 @@ configs = { shift_factor=0.1159, ), ), + "flux-chroma": ModelSpec( + repo_id="lodestones/Chroma1-HD", + repo_flow="", + repo_ae="ckpts/flux_vae.safetensors", + params=FluxParams( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + chroma=True, + ), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), "flux-dev-canny": ModelSpec( repo_id="black-forest-labs/FLUX.1-Canny-dev", repo_flow="", @@ -579,6 +608,38 @@ configs = { shift_factor=0.1159, ), ), + "flux-dev-uso": ModelSpec( + repo_id="", + repo_flow="", + repo_ae="ckpts/flux_vae.safetensors", + params=FluxParams( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + eso= True, + ), + ae_params=AutoEncoderParams( + 
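# Illustrative sketch of the position-id offsets used by prepare_multi_ip above:
# with pe="d" each reference image's token grid is shifted by the running
# (pe_shift_h, pe_shift_w) so its RoPE coordinates never overlap the target grid
# or a previously appended reference.
import torch
from einops import repeat

def ref_ids(ref_h, ref_w, h_offset, w_offset, bs=1):
    ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
    ids[..., 1] += torch.arange(ref_h // 2)[:, None] + h_offset
    ids[..., 2] += torch.arange(ref_w // 2)[None, :] + w_offset
    return repeat(ids, "h w c -> b (h w) c", b=bs)

pe_shift_h = pe_shift_w = 32              # target latent 64x64 -> grid 32x32
ids1 = ref_ids(64, 64, pe_shift_h, pe_shift_w)
pe_shift_h += 32; pe_shift_w += 32        # next reference starts past the first
ids2 = ref_ids(64, 64, pe_shift_h, pe_shift_w)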
resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), } diff --git a/models/hyvideo/__init__.py b/models/hyvideo/__init__.py new file mode 100644 index 0000000..d3a3700 --- /dev/null +++ b/models/hyvideo/__init__.py @@ -0,0 +1,2 @@ +from .hunyuan import HunyuanVideoSampler +from . import hunyuan_handler \ No newline at end of file diff --git a/hyvideo/config.py b/models/hyvideo/config.py similarity index 100% rename from hyvideo/config.py rename to models/hyvideo/config.py diff --git a/hyvideo/constants.py b/models/hyvideo/constants.py similarity index 100% rename from hyvideo/constants.py rename to models/hyvideo/constants.py diff --git a/hyvideo/data_kits/audio_dataset.py b/models/hyvideo/data_kits/audio_dataset.py similarity index 100% rename from hyvideo/data_kits/audio_dataset.py rename to models/hyvideo/data_kits/audio_dataset.py diff --git a/hyvideo/data_kits/audio_preprocessor.py b/models/hyvideo/data_kits/audio_preprocessor.py similarity index 100% rename from hyvideo/data_kits/audio_preprocessor.py rename to models/hyvideo/data_kits/audio_preprocessor.py diff --git a/hyvideo/data_kits/data_tools.py b/models/hyvideo/data_kits/data_tools.py similarity index 100% rename from hyvideo/data_kits/data_tools.py rename to models/hyvideo/data_kits/data_tools.py diff --git a/hyvideo/data_kits/face_align/__init__.py b/models/hyvideo/data_kits/face_align/__init__.py similarity index 100% rename from hyvideo/data_kits/face_align/__init__.py rename to models/hyvideo/data_kits/face_align/__init__.py diff --git a/hyvideo/data_kits/face_align/align.py b/models/hyvideo/data_kits/face_align/align.py similarity index 100% rename from hyvideo/data_kits/face_align/align.py rename to models/hyvideo/data_kits/face_align/align.py diff --git a/hyvideo/data_kits/face_align/detface.py b/models/hyvideo/data_kits/face_align/detface.py similarity index 99% rename from hyvideo/data_kits/face_align/detface.py rename to models/hyvideo/data_kits/face_align/detface.py index d04d293..4885e15 100644 --- a/hyvideo/data_kits/face_align/detface.py +++ b/models/hyvideo/data_kits/face_align/detface.py @@ -249,7 +249,7 @@ class DetFace(): for scale in [8,16,32]: ny = h1//scale nx = w1//scale - yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) + yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij") grid = torch.stack((xv, yv), 2).view((1,1,ny, nx, 2)).float() grids.append(grid.to(self.test_device)) self.grids = grids diff --git a/hyvideo/diffusion/__init__.py b/models/hyvideo/diffusion/__init__.py similarity index 100% rename from hyvideo/diffusion/__init__.py rename to models/hyvideo/diffusion/__init__.py diff --git a/hyvideo/diffusion/pipelines/__init__.py b/models/hyvideo/diffusion/pipelines/__init__.py similarity index 100% rename from hyvideo/diffusion/pipelines/__init__.py rename to models/hyvideo/diffusion/pipelines/__init__.py diff --git a/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py b/models/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py similarity index 98% rename from hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py rename to models/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py index 22f652e..ed91f9c 100644 --- a/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py +++ b/models/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py @@ -949,15 +949,18 @@ class HunyuanVideoPipeline(DiffusionPipeline): # width = width or 
self.transformer.config.sample_size * self.vae_scale_factor # to deal with lora scaling and other possible forward hooks trans = self.transformer - if trans.enable_cache == "tea": - teacache_multiplier = trans.cache_multiplier - trans.accumulated_rel_l1_distance = 0 - trans.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15 - elif trans.enable_cache == "mag": - trans.compute_magcache_threshold(trans.cache_start_step, num_inference_steps, trans.cache_multiplier) - trans.accumulated_err, trans.accumulated_steps, trans.accumulated_ratio = 0, 0, 1.0 - else: - trans.enable_cache == None + skip_steps_cache = trans.cache + if skip_steps_cache != None: + cache_type = skip_steps_cache.cache_type + if cache_type == "tea": + teacache_multiplier = skip_steps_cache.multiplier + skip_steps_cache.accumulated_rel_l1_distance = 0 + skip_steps_cache.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15 + elif cache_type== "mag": + trans.compute_magcache_threshold(skip_steps_cache.start_step, num_inference_steps, skip_steps_cache.multiplier) + skip_steps_cache.accumulated_err, skip_steps_cache.accumulated_steps, skip_steps_cache.accumulated_ratio = 0, 0, 1.0 + else: + trans.cache = None # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, @@ -1212,8 +1215,8 @@ class HunyuanVideoPipeline(DiffusionPipeline): if ip_cfg_scale>0: latent_items += 1 - if self.transformer.enable_cache: - self.transformer.previous_residual = [None] * latent_items + if skip_steps_cache != None: + skip_steps_cache.previous_residual = [None] * latent_items # if is_progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar: diff --git a/hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py b/models/hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py similarity index 95% rename from hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py rename to models/hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py index 191f9ab..c043a12 100644 --- a/hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py +++ b/models/hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py @@ -41,9 +41,9 @@ from diffusers.utils import ( from diffusers.utils.torch_utils import randn_tensor from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from hyvideo.constants import PRECISION_TO_TYPE -from hyvideo.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D -from hyvideo.text_encoder import TextEncoder +from ...constants import PRECISION_TO_TYPE +from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D +from ...text_encoder import TextEncoder from einops import rearrange from ...modules import HYVideoDiffusionTransformer @@ -934,15 +934,20 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline): transformer = self.transformer - if transformer.enable_cache == "tea": - teacache_multiplier = transformer.cache_multiplier - transformer.accumulated_rel_l1_distance = 0 - transformer.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15 - elif transformer.enable_cache == "mag": - transformer.compute_magcache_threshold(transformer.cache_start_step, num_inference_steps, transformer.cache_multiplier) - transformer.accumulated_err, transformer.accumulated_steps, transformer.accumulated_ratio = 0, 0, 1.0 - else: - transformer.enable_cache == None + skip_steps_cache = transformer.cache + cache_type = None + if skip_steps_cache != None: + cache_type = skip_steps_cache.cache_type + if cache_type == "tea": + teacache_multiplier = skip_steps_cache.multiplier + 
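# Illustrative sketch of the TeaCache skip test that rel_l1_thresh above feeds:
# the relative L1 change of the modulated input is rescaled by the fitted
# polynomial (the same coefficients used elsewhere in this patch) and accumulated;
# while the accumulator stays under the threshold the transformer pass is skipped
# and the cached residual is reused, and it resets whenever a full pass runs.
import numpy as np
import torch

coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
rescale = np.poly1d(coefficients)

def should_skip(modulated, previous, accumulated, rel_l1_thresh=0.15):
    rel_l1 = ((modulated - previous).abs().mean() / previous.abs().mean()).item()
    accumulated += rescale(rel_l1)
    return accumulated < rel_l1_thresh, accumulated

prev = torch.randn(1, 1024, 3072)
cur = prev + 0.01 * torch.randn_like(prev)
skip, acc = should_skip(cur, prev, accumulated=0.0)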
skip_steps_cache.accumulated_rel_l1_distance = 0 + skip_steps_cache.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15 + elif cache_type == "mag": + transformer.compute_magcache_threshold(skip_steps_cache.start_step, num_inference_steps, skip_steps_cache.multiplier) + skip_steps_cache.accumulated_err, skip_steps_cache.accumulated_steps, skip_steps_cache.accumulated_ratio = 0, 0, 1.0 + else: + transformer.cache = None + cache_type = None # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -1141,16 +1146,16 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline): if self._interrupt: return [None] - if transformer.enable_cache == "tea": + if cache_type == "tea": cache_size = round( infer_length / frames_per_batch ) - transformer.previous_residual = [None] * latent_items + skip_steps_cache.previous_residual = [None] * latent_items cache_all_previous_residual = [None] * latent_items cache_all_previous_modulated_input = None cache_should_calc = [True] * cache_size cache_accumulated_rel_l1_distance = [0.] * cache_size cache_teacache_skipped_steps = [0] * cache_size - elif transformer.enable_cache == "mag": - transformer.previous_residual = [None] * latent_items + elif cache_type == "mag": + skip_steps_cache.previous_residual = [None] * latent_items with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -1187,16 +1192,16 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline): img_ref_len = (latent_model_input.shape[-1] // 2) * (latent_model_input.shape[-2] // 2) * ( 1) img_all_len = (latents_all.shape[-1] // 2) * (latents_all.shape[-2] // 2) * latents_all.shape[-3] - if transformer.enable_cache == "tea" and cache_size > 1: + if cache_type == "tea" and cache_size > 1: for l in range(latent_items): if cache_all_previous_residual[l] != None: bsz = cache_all_previous_residual[l].shape[0] - transformer.previous_residual[l][:, img_ref_len:] = cache_all_previous_residual[l].reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072) + skip_steps_cache.previous_residual[l][:, img_ref_len:] = cache_all_previous_residual[l].reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072) if cache_all_previous_modulated_input != None: - transformer.previous_modulated_input[:, img_ref_len:] = cache_all_previous_modulated_input.reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072) - transformer.should_calc = cache_should_calc[cache_slot_no] - transformer.accumulated_rel_l1_distance = cache_accumulated_rel_l1_distance[cache_slot_no] - transformer.teacache_skipped_steps = cache_teacache_skipped_steps[cache_slot_no] + skip_steps_cache.previous_modulated_input[:, img_ref_len:] = cache_all_previous_modulated_input.reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072) + skip_steps_cache.should_calc = cache_should_calc[cache_slot_no] + skip_steps_cache.accumulated_rel_l1_distance = cache_accumulated_rel_l1_distance[cache_slot_no] + skip_steps_cache.teacache_skipped_steps = cache_teacache_skipped_steps[cache_slot_no] if self.do_classifier_free_guidance: @@ -1304,21 +1309,21 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline): pred_latents[:, :, p] += latents[:, :, iii] counter[:, :, p] += 1 - if transformer.enable_cache == "tea" and cache_size > 1: + if cache_type == "tea" and cache_size > 1: for l in range(latent_items): - if transformer.previous_residual[l] != None: - bsz = transformer.previous_residual[l].shape[0] + if skip_steps_cache.previous_residual[l] != None: + bsz = skip_steps_cache.previous_residual[l].shape[0] if cache_all_previous_residual[l] == None: 
- cache_all_previous_residual[l] = torch.zeros((bsz, img_all_len, 3072 ), device=transformer.previous_residual[l].device, dtype=transformer.previous_residual[l].dtype) - cache_all_previous_residual[l].reshape(bsz, -1, embedded_hw)[:, idx_list] = transformer.previous_residual[l][:, img_ref_len:].reshape(bsz, -1, embedded_hw) + cache_all_previous_residual[l] = torch.zeros((bsz, img_all_len, 3072 ), device=skip_steps_cache.previous_residual[l].device, dtype=skip_steps_cache.previous_residual[l].dtype) + cache_all_previous_residual[l].reshape(bsz, -1, embedded_hw)[:, idx_list] = skip_steps_cache.previous_residual[l][:, img_ref_len:].reshape(bsz, -1, embedded_hw) - if transformer.previous_modulated_input != None: + if skip_steps_cache.previous_modulated_input != None: if cache_all_previous_modulated_input == None: - cache_all_previous_modulated_input = torch.zeros((1, img_all_len, 3072 ), device=transformer.previous_modulated_input.device, dtype=transformer.previous_modulated_input.dtype) - cache_all_previous_modulated_input.reshape(1, -1, embedded_hw)[:, idx_list] = transformer.previous_modulated_input[:, img_ref_len:].reshape(1, -1, embedded_hw) - cache_should_calc[cache_slot_no] = transformer.should_calc - cache_accumulated_rel_l1_distance[cache_slot_no] = transformer.accumulated_rel_l1_distance - cache_teacache_skipped_steps[cache_slot_no] = transformer.teacache_skipped_steps + cache_all_previous_modulated_input = torch.zeros((1, img_all_len, 3072 ), device=skip_steps_cache.previous_modulated_input.device, dtype=skip_steps_cache.previous_modulated_input.dtype) + cache_all_previous_modulated_input.reshape(1, -1, embedded_hw)[:, idx_list] = skip_steps_cache.previous_modulated_input[:, img_ref_len:].reshape(1, -1, embedded_hw) + cache_should_calc[cache_slot_no] = skip_steps_cache.should_calc + cache_accumulated_rel_l1_distance[cache_slot_no] = skip_steps_cache.accumulated_rel_l1_distance + cache_teacache_skipped_steps[cache_slot_no] = skip_steps_cache.teacache_skipped_steps cache_slot_no += 1 diff --git a/hyvideo/diffusion/schedulers/__init__.py b/models/hyvideo/diffusion/schedulers/__init__.py similarity index 100% rename from hyvideo/diffusion/schedulers/__init__.py rename to models/hyvideo/diffusion/schedulers/__init__.py diff --git a/hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py b/models/hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py similarity index 100% rename from hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py rename to models/hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py diff --git a/hyvideo/hunyuan.py b/models/hyvideo/hunyuan.py similarity index 76% rename from hyvideo/hunyuan.py rename to models/hyvideo/hunyuan.py index 83d94ea..a38a7bd 100644 --- a/hyvideo/hunyuan.py +++ b/models/hyvideo/hunyuan.py @@ -8,24 +8,24 @@ from pathlib import Path from einops import rearrange import torch import torch.distributed as dist -from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE, NEGATIVE_PROMPT_I2V -from hyvideo.vae import load_vae -from hyvideo.modules import load_model -from hyvideo.text_encoder import TextEncoder -from hyvideo.utils.data_utils import align_to, get_closest_ratio, generate_crop_size_list -from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed, get_nd_rotary_pos_embed_new -from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler -from hyvideo.diffusion.pipelines import HunyuanVideoPipeline -from hyvideo.diffusion.pipelines import HunyuanVideoAudioPipeline +from 
.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE, NEGATIVE_PROMPT_I2V +from .vae import load_vae +from .modules import load_model +from .text_encoder import TextEncoder +from .utils.data_utils import align_to, get_closest_ratio, generate_crop_size_list +from .modules.posemb_layers import get_nd_rotary_pos_embed, get_nd_rotary_pos_embed_new +from .diffusion.schedulers import FlowMatchDiscreteScheduler +from .diffusion.pipelines import HunyuanVideoPipeline +from .diffusion.pipelines import HunyuanVideoAudioPipeline from PIL import Image import numpy as np import torchvision.transforms as transforms import cv2 -from wan.utils.utils import resize_lanczos, calculate_new_dimensions -from hyvideo.data_kits.audio_preprocessor import encode_audio, get_facemask +from shared.utils.utils import calculate_new_dimensions, convert_tensor_to_image +from .data_kits.audio_preprocessor import encode_audio, get_facemask from transformers import WhisperModel from transformers import AutoFeatureExtractor -from hyvideo.data_kits.face_align import AlignImage +from .data_kits.face_align import AlignImage import librosa def get_audio_feature(feature_extractor, audio_path, duration): @@ -66,174 +66,174 @@ def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1): -def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): - num_images, num_image_patches, embed_dim = image_features.shape - batch_size, sequence_length = input_ids.shape - left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) - # 1. Create a mask to know where special image tokens are - special_image_token_mask = input_ids == self.config.image_token_index - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) - # Compute the maximum embed dimension - max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length - batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) +# def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): +# num_images, num_image_patches, embed_dim = image_features.shape +# batch_size, sequence_length = input_ids.shape +# left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) +# # 1. Create a mask to know where special image tokens are +# special_image_token_mask = input_ids == self.config.image_token_index +# num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) +# # Compute the maximum embed dimension +# max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length +# batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) - # 2. Compute the positions where text should be written - # Calculate new positions for text tokens in merged image-text sequence. - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. - # `torch.cumsum` computes how each image token shifts subsequent text token positions. - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. 
- new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 - nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] - if left_padding: - new_token_positions += nb_image_pad[:, None] # offset for left padding - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] +# # 2. Compute the positions where text should be written +# # Calculate new positions for text tokens in merged image-text sequence. +# # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. +# # `torch.cumsum` computes how each image token shifts subsequent text token positions. +# # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. +# new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 +# nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] +# if left_padding: +# new_token_positions += nb_image_pad[:, None] # offset for left padding +# text_to_overwrite = new_token_positions[batch_indices, non_image_indices] - # 3. Create the full embedding, already padded to the maximum position - final_embedding = torch.zeros( - batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) - final_attention_mask = torch.zeros( - batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device - ) - if labels is not None: - final_labels = torch.full( - (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device - ) - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually - # set the corresponding tensors into their correct target device. - target_device = inputs_embeds.device - batch_indices, non_image_indices, text_to_overwrite = ( - batch_indices.to(target_device), - non_image_indices.to(target_device), - text_to_overwrite.to(target_device), - ) - attention_mask = attention_mask.to(target_device) +# # 3. Create the full embedding, already padded to the maximum position +# final_embedding = torch.zeros( +# batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device +# ) +# final_attention_mask = torch.zeros( +# batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device +# ) +# if labels is not None: +# final_labels = torch.full( +# (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device +# ) +# # In case the Vision model or the Language model has been offloaded to CPU, we need to manually +# # set the corresponding tensors into their correct target device. +# target_device = inputs_embeds.device +# batch_indices, non_image_indices, text_to_overwrite = ( +# batch_indices.to(target_device), +# non_image_indices.to(target_device), +# text_to_overwrite.to(target_device), +# ) +# attention_mask = attention_mask.to(target_device) - # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] - if labels is not None: - final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] +# # 4. 
Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] +# # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features +# final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] +# final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] +# if labels is not None: +# final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] - # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) - image_to_overwrite = torch.full( - (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device - ) - image_to_overwrite[batch_indices, text_to_overwrite] = False - image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) +# # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) +# image_to_overwrite = torch.full( +# (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device +# ) +# image_to_overwrite[batch_indices, text_to_overwrite] = False +# image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) - if image_to_overwrite.sum() != image_features.shape[:-1].numel(): - raise ValueError( - f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" - f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." - ) +# if image_to_overwrite.sum() != image_features.shape[:-1].numel(): +# raise ValueError( +# f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" +# f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." +# ) - final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) - final_attention_mask |= image_to_overwrite - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) +# final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) +# final_attention_mask |= image_to_overwrite +# position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) - # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. - batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) - indices_to_mask = new_token_positions[batch_indices, pad_indices] +# # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. 
+# batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) +# indices_to_mask = new_token_positions[batch_indices, pad_indices] - final_embedding[batch_indices, indices_to_mask] = 0 +# final_embedding[batch_indices, indices_to_mask] = 0 - if labels is None: - final_labels = None +# if labels is None: +# final_labels = None - return final_embedding, final_attention_mask, final_labels, position_ids +# return final_embedding, final_attention_mask, final_labels, position_ids -def patched_llava_forward( - self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - vision_feature_layer: Optional[int] = None, - vision_feature_select_strategy: Optional[str] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, -): - from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast +# def patched_llava_forward( +# self, +# input_ids: torch.LongTensor = None, +# pixel_values: torch.FloatTensor = None, +# attention_mask: Optional[torch.Tensor] = None, +# position_ids: Optional[torch.LongTensor] = None, +# past_key_values: Optional[List[torch.FloatTensor]] = None, +# inputs_embeds: Optional[torch.FloatTensor] = None, +# vision_feature_layer: Optional[int] = None, +# vision_feature_select_strategy: Optional[str] = None, +# labels: Optional[torch.LongTensor] = None, +# use_cache: Optional[bool] = None, +# output_attentions: Optional[bool] = None, +# output_hidden_states: Optional[bool] = None, +# return_dict: Optional[bool] = None, +# cache_position: Optional[torch.LongTensor] = None, +# num_logits_to_keep: int = 0, +# ): +# from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vision_feature_layer = ( - vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer - ) - vision_feature_select_strategy = ( - vision_feature_select_strategy - if vision_feature_select_strategy is not None - else self.config.vision_feature_select_strategy - ) +# output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions +# output_hidden_states = ( +# output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states +# ) +# return_dict = return_dict if return_dict is not None else self.config.use_return_dict +# vision_feature_layer = ( +# vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer +# ) +# vision_feature_select_strategy = ( +# vision_feature_select_strategy +# if vision_feature_select_strategy is not None +# else self.config.vision_feature_select_strategy +# ) - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") +# if (input_ids 
is None) ^ (inputs_embeds is not None): +# raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) +# if pixel_values is not None and inputs_embeds is not None: +# raise ValueError( +# "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" +# ) - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) +# if inputs_embeds is None: +# inputs_embeds = self.get_input_embeddings()(input_ids) - image_features = None - if pixel_values is not None: - image_features = self.get_image_features( - pixel_values=pixel_values, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - ) +# image_features = None +# if pixel_values is not None: +# image_features = self.get_image_features( +# pixel_values=pixel_values, +# vision_feature_layer=vision_feature_layer, +# vision_feature_select_strategy=vision_feature_select_strategy, +# ) - inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, attention_mask, labels - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) +# inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( +# image_features, inputs_embeds, input_ids, attention_mask, labels +# ) +# cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - outputs = self.language_model( - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, - ) +# outputs = self.language_model( +# attention_mask=attention_mask, +# position_ids=position_ids, +# past_key_values=past_key_values, +# inputs_embeds=inputs_embeds, +# use_cache=use_cache, +# output_attentions=output_attentions, +# output_hidden_states=output_hidden_states, +# return_dict=return_dict, +# cache_position=cache_position, +# num_logits_to_keep=num_logits_to_keep, +# ) - logits = outputs[0] +# logits = outputs[0] - loss = None +# loss = None - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output +# if not return_dict: +# output = (logits,) + outputs[1:] +# return (loss,) + output if loss is not None else output - return LlavaCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) +# return LlavaCausalLMOutputWithPast( +# loss=loss, +# logits=logits, +# past_key_values=outputs.past_key_values, +# hidden_states=outputs.hidden_states, +# attentions=outputs.attentions, +# image_hidden_states=image_features if pixel_values is not None else None, +# ) def adapt_model(model, audio_block_name): modules_dict= { k: m for k, m in model.named_modules()} @@ -320,8 +320,8 @@ class Inference(object): device = "cuda" import transformers - transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.forward = 
patched_llava_forward # force legacy behaviour to be able to use tansformers v>(4.47) - transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features = _merge_input_ids_with_image_features + # transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.forward = patched_llava_forward # force legacy behaviour to be able to use tansformers v>(4.47) + # transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features = _merge_input_ids_with_image_features torch.set_grad_enabled(False) text_len = 512 @@ -720,7 +720,6 @@ class HunyuanVideoSampler(Inference): embedded_guidance_scale=6.0, batch_size=1, num_videos_per_prompt=1, - i2v_resolution="720p", image_start=None, enable_RIFLEx = False, i2v_condition_type: str = "token_replace", @@ -779,7 +778,7 @@ class HunyuanVideoSampler(Inference): raise ValueError( f"Seed must be an integer, a list of integers, or None, got {seed}." ) - from wan.utils.utils import seed_everything + from shared.utils.utils import seed_everything seed_everything(seed) generator = [torch.Generator("cuda").manual_seed(seed) for seed in seeds] # generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds] @@ -846,39 +845,13 @@ class HunyuanVideoSampler(Inference): denoise_strength = 0 ip_cfg_scale = 0 if i2v_mode: - if i2v_resolution == "720p": - bucket_hw_base_size = 960 - elif i2v_resolution == "540p": - bucket_hw_base_size = 720 - elif i2v_resolution == "360p": - bucket_hw_base_size = 480 - else: - raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]") - - # semantic_images = [Image.open(i2v_image_path).convert('RGB')] - semantic_images = [image_start.convert('RGB')] # - origin_size = semantic_images[0].size - h, w = origin_size - h, w = calculate_new_dimensions(height, width, h, w, fit_into_canvas) - closest_size = (w, h) - # crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32) - # aspect_ratios = np.array([round(float(h)/float(w), 5) for h, w in crop_size_list]) - # closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list) - ref_image_transform = transforms.Compose([ - transforms.Resize(closest_size), - transforms.CenterCrop(closest_size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]) - ]) - - semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images] - semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device) - + semantic_images = convert_tensor_to_image(image_start) + semantic_image_pixel_values = image_start.unsqueeze(0).unsqueeze(2).to(self.device) with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True): img_latents = self.pipeline.vae.encode(semantic_image_pixel_values).latent_dist.mode() # B, C, F, H, W img_latents.mul_(self.pipeline.vae.config.scaling_factor) - target_height, target_width = closest_size + target_height, target_width = image_start.shape[1:] # ======================================================================== # Build Rope freqs @@ -983,7 +956,7 @@ class HunyuanVideoSampler(Inference): # out_latents= ref_latents / self.vae.config.scaling_factor # image = self.vae.decode(out_latents, return_dict=False, generator=generator)[0] # image = image.clamp(-1, 1) - # from wan.utils.utils import cache_video + # from shared.utils.utils import cache_video # cache_video( tensor=image, save_file="decode.mp4", fps=25, 
nrow=1, normalize=True, value_range=(-1, 1)) motion_pose = np.array([25] * 4) @@ -1065,3 +1038,6 @@ class HunyuanVideoSampler(Inference): samples = samples.squeeze(0) return samples + + + diff --git a/models/hyvideo/hunyuan_handler.py b/models/hyvideo/hunyuan_handler.py new file mode 100644 index 0000000..d95bd7e --- /dev/null +++ b/models/hyvideo/hunyuan_handler.py @@ -0,0 +1,177 @@ +import torch + +def get_hunyuan_text_encoder_filename(text_encoder_quantization): + if text_encoder_quantization =="int8": + text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_quanto_int8.safetensors" + else: + text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_fp16.safetensors" + + return text_encoder_filename + +class family_handler(): + + @staticmethod + def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_steps_cache): + resolution = inputs["resolution"] + width, height = resolution.split("x") + pixels = int(width) * int(height) + + if cache_type == "mag": + skip_steps_cache.update({ + "magcache_thresh" : 0, + "magcache_K" : 2, + }) + if pixels >= 1280* 720: + skip_steps_cache.def_mag_ratios = [1.0754, 1.27807, 1.11596, 1.09504, 1.05188, 1.00844, 1.05779, 1.00657, 1.04142, 1.03101, 1.00679, 1.02556, 1.00908, 1.06949, 1.05438, 1.02214, 1.02321, 1.03019, 1.00779, 1.03381, 1.01886, 1.01161, 1.02968, 1.00544, 1.02822, 1.00689, 1.02119, 1.0105, 1.01044, 1.01572, 1.02972, 1.0094, 1.02368, 1.0226, 0.98965, 1.01588, 1.02146, 1.0018, 1.01687, 0.99436, 1.00283, 1.01139, 0.97122, 0.98251, 0.94513, 0.97656, 0.90943, 0.85703, 0.75456] + else: + skip_steps_cache.def_mag_ratios = [1.06971, 1.29073, 1.11245, 1.09596, 1.05233, 1.01415, 1.05672, 1.00848, 1.03632, 1.02974, 1.00984, 1.03028, 1.00681, 1.06614, 1.05022, 1.02592, 1.01776, 1.02985, 1.00726, 1.03727, 1.01502, 1.00992, 1.03371, 0.9976, 1.02742, 1.0093, 1.01869, 1.00815, 1.01461, 1.01152, 1.03082, 1.0061, 1.02162, 1.01999, 0.99063, 1.01186, 1.0217, 0.99947, 1.01711, 0.9904, 1.00258, 1.00878, 0.97039, 0.97686, 0.94315, 0.97728, 0.91154, 0.86139, 0.76592] + else: + skip_steps_cache.coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02] + + @staticmethod + def query_model_def(base_model_type, model_def): + extra_model_def = {} + + if base_model_type in ["hunyuan_avatar", "hunyuan_custom_audio"]: + fps = 25 + elif base_model_type in ["hunyuan", "hunyuan_i2v", "hunyuan_custom_edit", "hunyuan_custom"]: + fps = 24 + else: + fps = 16 + extra_model_def["fps"] = fps + extra_model_def["frames_minimum"] = 5 + extra_model_def["frames_steps"] = 4 + extra_model_def["sliding_window"] = False + if base_model_type in ["hunyuan", "hunyuan_i2v"]: + extra_model_def["embedded_guidance"] = True + else: + extra_model_def["guidance_max_phases"] = 1 + + extra_model_def["cfg_star"] = base_model_type in [ "hunyuan_avatar", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_custom"] + extra_model_def["tea_cache"] = True + extra_model_def["mag_cache"] = True + + if base_model_type in ["hunyuan_avatar"]: extra_model_def["no_background_removal"] = True + + if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_audio", "hunyuan_avatar"]: + extra_model_def["one_image_ref_needed"] = True + + return extra_model_def + + @staticmethod + def query_supported_types(): + return ["hunyuan", "hunyuan_i2v", "hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_avatar"] + + @staticmethod + def query_family_maps(): + models_eqv_map = { + } + + 
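# Illustrative sketch of driving family_handler.set_cache_parameters above.
# The real skip-steps cache object lives elsewhere in WanGP; the attribute/dict
# hybrid below is only a stand-in for this sketch (an assumption, since the
# handler both calls .update(...) and assigns attributes on it).
from models.hyvideo.hunyuan_handler import family_handler

class CacheState(dict):
    def __getattr__(self, name): return self[name]
    def __setattr__(self, name, value): self[name] = value

cache = CacheState(cache_type="mag", multiplier=1.5, start_step=0)
inputs = {"resolution": "1280x720"}
family_handler.set_cache_parameters("mag", "hunyuan", None, inputs, cache)
print(cache.magcache_K, len(cache.def_mag_ratios))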
models_comp_map = { + "hunyuan_custom": ["hunyuan_custom_edit", "hunyuan_custom_audio"], + } + + return models_eqv_map, models_comp_map + + @staticmethod + def query_model_family(): + return "hunyuan" + + @staticmethod + def query_family_infos(): + return {"hunyuan":(20, "Hunyuan Video")} + + @staticmethod + def get_rgb_factors(base_model_type ): + from shared.RGB_factors import get_rgb_factors + latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("hunyuan") + return latent_rgb_factors, latent_rgb_factors_bias + + @staticmethod + def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization): + text_encoder_filename = get_hunyuan_text_encoder_filename(text_encoder_quantization) + return { + "repoId" : "DeepBeepMeep/HunyuanVideo", + "sourceFolderList" : [ "llava-llama-3-8b", "clip_vit_large_patch14", "whisper-tiny" , "det_align", "" ], + "fileList" :[ ["config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "preprocessor_config.json"] + computeList(text_encoder_filename) , + ["config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"], + ["config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json"], + ["detface.pt"], + [ "hunyuan_video_720_quanto_int8_map.json", "hunyuan_video_custom_VAE_fp32.safetensors", "hunyuan_video_custom_VAE_config.json", "hunyuan_video_VAE_fp32.safetensors", "hunyuan_video_VAE_config.json" , "hunyuan_video_720_quanto_int8_map.json" ] + computeList(model_filename) + ] + } + + @staticmethod + def load_model(model_filename, model_type = None, base_model_type = None, model_def = None, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + from .hunyuan import HunyuanVideoSampler + from mmgp import offload + + hunyuan_model = HunyuanVideoSampler.from_pretrained( + model_filepath = model_filename, + model_type = model_type, + base_model_type = base_model_type, + text_encoder_filepath = get_hunyuan_text_encoder_filename(text_encoder_quantization), + dtype = dtype, + quantizeTransformer = quantizeTransformer, + VAE_dtype = VAE_dtype, + mixed_precision_transformer = mixed_precision_transformer, + save_quantized = save_quantized + ) + + pipe = { "transformer" : hunyuan_model.model, "text_encoder" : hunyuan_model.text_encoder, "text_encoder_2" : hunyuan_model.text_encoder_2, "vae" : hunyuan_model.vae } + + if hunyuan_model.wav2vec != None: + pipe["wav2vec"] = hunyuan_model.wav2vec + + + # if hunyuan_model.align_instance != None: + # pipe["align_instance"] = hunyuan_model.align_instance.facedet.model + + + from .modules.models import get_linear_split_map + + split_linear_modules_map = get_linear_split_map() + hunyuan_model.model.split_linear_modules_map = split_linear_modules_map + offload.split_linear_modules(hunyuan_model.model, split_linear_modules_map ) + + + return hunyuan_model, pipe + + @staticmethod + def update_default_settings(base_model_type, model_def, ui_defaults): + ui_defaults["embedded_guidance_scale"]= 6.0 + + if base_model_type in ["hunyuan","hunyuan_i2v"]: + ui_defaults.update({ + "guidance_scale": 7.0, + }) + + elif base_model_type in ["hunyuan_custom"]: + ui_defaults.update({ + "guidance_scale": 7.5, + "flow_shift": 13, + "resolution": "1280x720", + "video_prompt_type": "I", + }) + elif base_model_type in 
["hunyuan_custom_audio"]: + ui_defaults.update({ + "guidance_scale": 7.5, + "flow_shift": 13, + "video_prompt_type": "I", + }) + elif base_model_type in ["hunyuan_custom_edit"]: + ui_defaults.update({ + "guidance_scale": 7.5, + "flow_shift": 13, + "video_prompt_type": "MVAI", + "sliding_window_size": 129, + }) + elif base_model_type in ["hunyuan_avatar"]: + ui_defaults.update({ + "guidance_scale": 7.5, + "flow_shift": 5, + "remove_background_images_ref": 0, + "skip_steps_start_step_perc": 25, + "video_length": 129, + "video_prompt_type": "KI", + }) diff --git a/hyvideo/modules/__init__.py b/models/hyvideo/modules/__init__.py similarity index 100% rename from hyvideo/modules/__init__.py rename to models/hyvideo/modules/__init__.py diff --git a/hyvideo/modules/activation_layers.py b/models/hyvideo/modules/activation_layers.py similarity index 100% rename from hyvideo/modules/activation_layers.py rename to models/hyvideo/modules/activation_layers.py diff --git a/hyvideo/modules/attenion.py b/models/hyvideo/modules/attenion.py similarity index 100% rename from hyvideo/modules/attenion.py rename to models/hyvideo/modules/attenion.py diff --git a/hyvideo/modules/audio_adapters.py b/models/hyvideo/modules/audio_adapters.py similarity index 100% rename from hyvideo/modules/audio_adapters.py rename to models/hyvideo/modules/audio_adapters.py diff --git a/hyvideo/modules/embed_layers.py b/models/hyvideo/modules/embed_layers.py similarity index 100% rename from hyvideo/modules/embed_layers.py rename to models/hyvideo/modules/embed_layers.py diff --git a/hyvideo/modules/mlp_layers.py b/models/hyvideo/modules/mlp_layers.py similarity index 100% rename from hyvideo/modules/mlp_layers.py rename to models/hyvideo/modules/mlp_layers.py diff --git a/hyvideo/modules/models.py b/models/hyvideo/modules/models.py similarity index 93% rename from hyvideo/modules/models.py rename to models/hyvideo/modules/models.py index de50efc..4cdce4a 100644 --- a/hyvideo/modules/models.py +++ b/models/hyvideo/modules/models.py @@ -18,7 +18,7 @@ from .modulate_layers import ModulateDiT, modulate, modulate_ , apply_gate, appl from .token_refiner import SingleTokenRefiner import numpy as np from mmgp import offload -from wan.modules.attention import pay_attention +from shared.attention import pay_attention from .audio_adapters import AudioProjNet2, PerceiverAttentionCA def get_linear_split_map(): @@ -28,10 +28,6 @@ def get_linear_split_map(): "linear1" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"] , "split_sizes": [hidden_size, hidden_size, hidden_size, 7*hidden_size- 3*hidden_size]} } return split_linear_modules_map -try: - from xformers.ops.fmha.attn_bias import BlockDiagonalPaddedKeysMask -except ImportError: - BlockDiagonalPaddedKeysMask = None class MMDoubleStreamBlock(nn.Module): @@ -469,7 +465,7 @@ class MMSingleStreamBlock(nn.Module): del img_mod, txt_mod x_mod_shape = x_mod.shape x_mod = x_mod.view(-1, x_mod.shape[-1]) - chunk_size = int(x_mod_shape[1]/6) + chunk_size = int(x_mod.shape[0]/6) x_chunks = torch.split(x_mod, chunk_size) attn = attn.view(-1, attn.shape[-1]) attn_chunks =torch.split(attn, chunk_size) @@ -798,6 +794,8 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): block.disable_deterministic() def compute_magcache_threshold(self, start_step, num_inference_steps = 0, speed_factor =0): + skips_step_cache = self.cache + def nearest_interp(src_array, target_length): src_length = len(src_array) if target_length == 1: @@ -805,11 +803,11 @@ class 
HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): scale = (src_length - 1) / (target_length - 1) mapped_indices = np.round(np.arange(target_length) * scale).astype(int) return src_array[mapped_indices] - - if len(self.def_mag_ratios) != num_inference_steps: - self.mag_ratios = nearest_interp(self.def_mag_ratios, num_inference_steps) + def_mag_ratios = np.array([1.0]+ skips_step_cache.def_mag_ratios) + if len(def_mag_ratios) != num_inference_steps: + skips_step_cache.mag_ratios = nearest_interp(def_mag_ratios, num_inference_steps) else: - self.mag_ratios = self.def_mag_ratios + skips_step_cache.mag_ratios = def_mag_ratios best_deltas = None best_threshold = 0.01 @@ -825,12 +823,12 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): if i<=start_step: skip = False else: - cur_mag_ratio = self.mag_ratios[i] # conditional and unconditional in one list + cur_mag_ratio = skips_step_cache.mag_ratios[i] # conditional and unconditional in one list accumulated_ratio *= cur_mag_ratio # magnitude ratio between current step and the cached step accumulated_steps += 1 # skip steps plus 1 cur_skip_err = np.abs(1-accumulated_ratio) # skip error of current steps accumulated_err += cur_skip_err # accumulated error of multiple steps - if accumulated_err best_diff: break threshold += 0.01 - self.magcache_thresh = best_threshold + skips_step_cache.magcache_thresh = best_threshold print(f"Mag Cache, best threshold found:{best_threshold:0.2f} with gain x{num_inference_steps/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}") return best_threshold @@ -973,23 +971,24 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): attn_mask = None freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None - - - if self.enable_cache: + should_calc = True + skip_steps_cache = self.cache + if skip_steps_cache is not None: + cache_type = skip_steps_cache.cache_type if x_id == 0: - self.should_calc = True - if self.enable_cache == "mag": - if step_no > self.cache_start_step: - cur_mag_ratio = self.mag_ratios[step_no] - self.accumulated_ratio = self.accumulated_ratio*cur_mag_ratio - cur_skip_err = np.abs(1-self.accumulated_ratio) - self.accumulated_err += cur_skip_err - self.accumulated_steps += 1 - if self.accumulated_err<=self.magcache_thresh and self.accumulated_steps<=self.magcache_K: - self.should_calc = False - self.cache_skipped_steps += 1 + skip_steps_cache.should_calc = True + if cache_type == "mag": + if step_no > skip_steps_cache.start_step: + cur_mag_ratio = skip_steps_cache.mag_ratios[step_no] + skip_steps_cache.accumulated_ratio = skip_steps_cache.accumulated_ratio*cur_mag_ratio + cur_skip_err = np.abs(1-skip_steps_cache.accumulated_ratio) + skip_steps_cache.accumulated_err += cur_skip_err + skip_steps_cache.accumulated_steps += 1 + if skip_steps_cache.accumulated_err<=skip_steps_cache.magcache_thresh and skip_steps_cache.accumulated_steps<=skip_steps_cache.magcache_K: + skip_steps_cache.should_calc = False + skip_steps_cache.skipped_steps += 1 else: - self.accumulated_ratio, self.accumulated_steps, self.accumulated_err = 1.0, 0, 0 + skip_steps_cache.accumulated_ratio, skip_steps_cache.accumulated_steps, skip_steps_cache.accumulated_err = 1.0, 0, 0 else: inp = img[0:1] vec_ = vec[0:1] @@ -998,26 +997,24 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): normed_inp = normed_inp.to(torch.bfloat16) modulated_inp = modulate( normed_inp, shift=img_mod1_shift, scale=img_mod1_scale ) del normed_inp, img_mod1_shift, img_mod1_scale - if step_no <= 
self.cache_start_step or step_no == self.num_steps-1: - self.accumulated_rel_l1_distance = 0 - else: - coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02] - rescale_func = np.poly1d(coefficients) - self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) - if self.accumulated_rel_l1_distance < self.rel_l1_thresh: - self.should_calc = False - self.cache_skipped_steps += 1 + if step_no <= skip_steps_cache.start_step or step_no == skip_steps_cache.num_steps-1: + skip_steps_cache.accumulated_rel_l1_distance = 0 + else: + rescale_func = np.poly1d(skip_steps_cache.coefficients) + skip_steps_cache.accumulated_rel_l1_distance += rescale_func(((modulated_inp-skip_steps_cache.previous_modulated_input).abs().mean() / skip_steps_cache.previous_modulated_input.abs().mean()).cpu().item()) + if skip_steps_cache.accumulated_rel_l1_distance < skip_steps_cache.rel_l1_thresh: + skip_steps_cache.should_calc = False + skip_steps_cache.skipped_steps += 1 else: - self.accumulated_rel_l1_distance = 0 - self.previous_modulated_input = modulated_inp - else: - self.should_calc = True + skip_steps_cache.accumulated_rel_l1_distance = 0 + skip_steps_cache.previous_modulated_input = modulated_inp + should_calc = skip_steps_cache.should_calc - if not self.should_calc: - img += self.previous_residual[x_id] + if not should_calc: + img += skip_steps_cache.previous_residual[x_id] else: - if self.enable_cache: - self.previous_residual[x_id] = None + if skip_steps_cache is not None: + skip_steps_cache.previous_residual[x_id] = None ori_img = img[0:1].clone() # --------------------- Pass through DiT blocks ------------------------ for layer_num, block in enumerate(self.double_blocks): @@ -1080,10 +1077,10 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): single_block_args = None # img = x[:, :img_seq_len, ...] - if self.enable_cache: + if skip_steps_cache is not None: if len(img) > 1: - self.previous_residual[0] = torch.empty_like(img) - for i, (x, residual) in enumerate(zip(img, self.previous_residual[0])): + skip_steps_cache.previous_residual[0] = torch.empty_like(img) + for i, (x, residual) in enumerate(zip(img, skip_steps_cache.previous_residual[0])): if i < len(img) - 1: residual[...] 
= torch.sub(x, ori_img) else: @@ -1091,8 +1088,8 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): torch.sub(x, ori_img, out=residual) x = None else: - self.previous_residual[x_id] = ori_img - torch.sub(img, ori_img, out=self.previous_residual[x_id]) + skip_steps_cache.previous_residual[x_id] = ori_img + torch.sub(img, ori_img, out=skip_steps_cache.previous_residual[x_id]) if ref_length != None: diff --git a/hyvideo/modules/modulate_layers.py b/models/hyvideo/modules/modulate_layers.py similarity index 100% rename from hyvideo/modules/modulate_layers.py rename to models/hyvideo/modules/modulate_layers.py diff --git a/hyvideo/modules/norm_layers.py b/models/hyvideo/modules/norm_layers.py similarity index 100% rename from hyvideo/modules/norm_layers.py rename to models/hyvideo/modules/norm_layers.py diff --git a/hyvideo/modules/original models.py b/models/hyvideo/modules/original models.py similarity index 100% rename from hyvideo/modules/original models.py rename to models/hyvideo/modules/original models.py diff --git a/hyvideo/modules/placement.py b/models/hyvideo/modules/placement.py similarity index 100% rename from hyvideo/modules/placement.py rename to models/hyvideo/modules/placement.py diff --git a/hyvideo/modules/posemb_layers.py b/models/hyvideo/modules/posemb_layers.py similarity index 100% rename from hyvideo/modules/posemb_layers.py rename to models/hyvideo/modules/posemb_layers.py diff --git a/hyvideo/modules/token_refiner.py b/models/hyvideo/modules/token_refiner.py similarity index 100% rename from hyvideo/modules/token_refiner.py rename to models/hyvideo/modules/token_refiner.py diff --git a/hyvideo/modules/utils.py b/models/hyvideo/modules/utils.py similarity index 100% rename from hyvideo/modules/utils.py rename to models/hyvideo/modules/utils.py diff --git a/hyvideo/prompt_rewrite.py b/models/hyvideo/prompt_rewrite.py similarity index 100% rename from hyvideo/prompt_rewrite.py rename to models/hyvideo/prompt_rewrite.py diff --git a/hyvideo/text_encoder/__init__.py b/models/hyvideo/text_encoder/__init__.py similarity index 97% rename from hyvideo/text_encoder/__init__.py rename to models/hyvideo/text_encoder/__init__.py index 1376718..9bd47d4 100644 --- a/hyvideo/text_encoder/__init__.py +++ b/models/hyvideo/text_encoder/__init__.py @@ -15,6 +15,7 @@ from transformers.utils import ModelOutput from ..constants import TEXT_ENCODER_PATH, TOKENIZER_PATH from ..constants import PRECISION_TO_TYPE +from .llava.modeling_llava import LlavaForConditionalGeneration def use_default(value, default): @@ -188,10 +189,16 @@ class TextEncoder(nn.Module): if "llm" in text_encoder_type: from mmgp import offload - forcedConfigPath= None if "i2v" in text_encoder_type else "ckpts/llava-llama-3-8b/config.json" - self.model= offload.fast_load_transformers_model(self.model_path, modelPrefix="language_model" if forcedConfigPath != None else None, forcedConfigPath=forcedConfigPath) - if forcedConfigPath != None: + # forcedConfigPath= None if "i2v" in text_encoder_type else "ckpts/llava-llama-3-8b/config.json" + # self.model= offload.fast_load_transformers_model(self.model_path, modelPrefix="language_model" if forcedConfigPath != None else None, forcedConfigPath=forcedConfigPath) + + if "i2v" in text_encoder_type: + self.model= offload.fast_load_transformers_model(self.model_path, modelClass= LlavaForConditionalGeneration) + else: + self.model= offload.fast_load_transformers_model(self.model_path, modelPrefix="language_model", forcedConfigPath = "ckpts/llava-llama-3-8b/config.json") 
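
Editor's note on the `models.py` hunk above: the refactor moves the MagCache/TeaCache bookkeeping off the transformer's own attributes and onto a dedicated skip-steps cache object. Below is a minimal, self-contained sketch of the magnitude-ratio skip rule that the hunk implements; `SkipStepsCache`, its default thresholds and the ratio values are illustrative assumptions, not WanGP's actual class.

```python
# Sketch (assumed names/defaults): skip a denoising step while the accumulated
# magnitude-ratio error stays under magcache_thresh and at most magcache_K
# consecutive steps have been skipped; otherwise reset and run the full step.
import numpy as np

class SkipStepsCache:
    def __init__(self, mag_ratios, magcache_thresh=0.12, magcache_K=3, start_step=3):
        self.mag_ratios = np.asarray(mag_ratios)
        self.magcache_thresh = magcache_thresh
        self.magcache_K = magcache_K
        self.start_step = start_step
        self.accumulated_ratio, self.accumulated_steps, self.accumulated_err = 1.0, 0, 0.0
        self.skipped_steps = 0

    def should_calc(self, step_no: int) -> bool:
        if step_no <= self.start_step:
            return True                      # always compute the early steps
        self.accumulated_ratio *= self.mag_ratios[step_no]
        self.accumulated_steps += 1
        self.accumulated_err += abs(1.0 - self.accumulated_ratio)
        if self.accumulated_err <= self.magcache_thresh and self.accumulated_steps <= self.magcache_K:
            self.skipped_steps += 1
            return False                     # reuse the cached residual instead
        # Error grew too large: reset the accumulators and compute this step.
        self.accumulated_ratio, self.accumulated_steps, self.accumulated_err = 1.0, 0, 0.0
        return True

cache = SkipStepsCache(mag_ratios=np.linspace(1.0, 0.98, 30))
plan = [cache.should_calc(i) for i in range(30)]   # True = full step, False = skipped
```
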
self.model.final_layer_norm = self.model.model.norm + + else: self.model, self.model_path = load_text_encoder( diff --git a/models/hyvideo/text_encoder/llava/__init__.py b/models/hyvideo/text_encoder/llava/__init__.py new file mode 100644 index 0000000..e6d2f52 --- /dev/null +++ b/models/hyvideo/text_encoder/llava/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from typing import TYPE_CHECKING + +# from ...utils import _LazyModule +# from ...utils.import_utils import define_import_structure + + +# if TYPE_CHECKING: +# from .configuration_llava import * +# from .image_processing_llava_fast import * +# from .modeling_llava import * +# from .processing_llava import * +# else: +# import sys + +# _file = globals()["__file__"] + # sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/models/hyvideo/text_encoder/llava/configuration_llava.py b/models/hyvideo/text_encoder/llava/configuration_llava.py new file mode 100644 index 0000000..9c30798 --- /dev/null +++ b/models/hyvideo/text_encoder/llava/configuration_llava.py @@ -0,0 +1,137 @@ +# coding=utf-8 +# Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Llava model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging +from transformers.models.auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +class LlavaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate an + Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Llava-9B. + + e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`): + The config object or dictionary of the vision backbone. 
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): + The config object or dictionary of the text backbone. + image_token_index (`int`, *optional*, defaults to 32000): + The image token index to encode the image prompt. + projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function used by the multimodal projector. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. + vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -2): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + image_seq_length (`int`, *optional*, defaults to 576): + Sequence length of one image embedding. + multimodal_projector_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the multimodal projector. + + Example: + + ```python + >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig + + >>> # Initializing a CLIP-vision config + >>> vision_config = CLIPVisionConfig() + + >>> # Initializing a Llama config + >>> text_config = LlamaConfig() + + >>> # Initializing a Llava llava-1.5-7b style configuration + >>> configuration = LlavaConfig(vision_config, text_config) + + >>> # Initializing a model from the llava-1.5-7b style configuration + >>> model = LlavaForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "llava" + sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} + is_composition = True + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=32000, + projector_hidden_act="gelu", + vision_feature_select_strategy="default", + vision_feature_layer=-2, + image_seq_length=576, + multimodal_projector_bias=True, + **kwargs, + ): + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + self.image_seq_length = image_seq_length + + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError( + "vision_feature_select_strategy should be one of 'default', 'full'." 
+ f"Got: {vision_feature_select_strategy}" + ) + + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + + if isinstance(vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" + ) + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=336, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + + self.vision_config = vision_config + + if isinstance(text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["llama"]() + + self.text_config = text_config + self.multimodal_projector_bias = multimodal_projector_bias + + super().__init__(**kwargs) + + +__all__ = ["LlavaConfig"] diff --git a/models/hyvideo/text_encoder/llava/image_processing_llava.py b/models/hyvideo/text_encoder/llava/image_processing_llava.py new file mode 100644 index 0000000..37ef079 --- /dev/null +++ b/models/hyvideo/text_encoder/llava/image_processing_llava.py @@ -0,0 +1,436 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for LLaVa.""" + +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_kwargs, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_vision_available, logging + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + + +class LlavaImageProcessor(BaseImageProcessor): + r""" + Constructs a LLaVa image processor. + + Args: + do_pad (`bool`, *optional*, defaults to `False`): + Whether to pad the image to a square based on the longest edge. + The padding value is determined by the `image_mean` parameter. + Can be overridden by `do_pad` in the `preprocess` method. + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. 
The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`Dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. 
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_pad: bool = False, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_center_crop: bool = True, + crop_size: Dict[str, int] = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 224} + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} + crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + + self.do_pad = do_pad + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + self._valid_processor_keys = [ + "images", + "do_pad", + "do_resize", + "size", + "resample", + "do_center_crop", + "crop_size", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "do_convert_rgb", + "return_tensors", + "data_format", + "input_data_format", + ] + + def pad_to_square( + self, + image: np.ndarray, + background_color: Union[int, Tuple[int, int, int]] = 0, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.array: + """ + Pads an image to a square based on the longest edge. + + Args: + image (`np.ndarray`): + The image to pad. + background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0): + The color to use for the padding. Can be an integer for single channel or a + tuple of integers representing for multi-channel images. If passed as integer + in mutli-channel mode, it will default to `0` in subsequent channels. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + + Returns: + `np.ndarray`: The padded image. 
+ """ + height, width = get_image_size(image, input_data_format) + num_channels = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1] + + if height == width: + image = ( + to_channel_dimension_format(image, data_format, input_data_format) + if data_format is not None + else image + ) + return image + + max_dim = max(height, width) + + # Ensure background_color is the correct shape + if isinstance(background_color, int): + background_color = [background_color] + elif len(background_color) != num_channels: + raise ValueError( + f"background_color must have no more than {num_channels} elements to match the number of channels" + ) + + if input_data_format == ChannelDimension.FIRST: + result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype) + for i, color in enumerate(background_color): + result[i, :, :] = color + if width > height: + start = (max_dim - height) // 2 + result[:, start : start + height, :] = image + else: + start = (max_dim - width) // 2 + result[:, :, start : start + width] = image + else: + result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype) + for i, color in enumerate(background_color): + result[:, :, i] = color + if width > height: + start = (max_dim - height) // 2 + result[start : start + height, :, :] = image + else: + start = (max_dim - width) // 2 + result[:, start : start + width, :] = image + + image = ( + to_channel_dimension_format(result, data_format, input_data_format) if data_format is not None else result + ) + return image + + # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. 
+ """ + default_to_square = True + if "shortest_edge" in size: + size = size["shortest_edge"] + default_to_square = False + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") + + output_size = get_resize_output_image_size( + image, + size=size, + default_to_square=default_to_square, + input_data_format=input_data_format, + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + do_pad: Optional[bool] = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image to a square based on the longest edge. + The padding value is determined by the `image_mean` parameter. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. 
+ return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_pad = do_pad if do_pad is not None else self.do_pad + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + # we don't pass `do_pad` here since LLaVa uses a custom padding to a square + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." 
+ ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + processed_images = [] + for image in images: + if do_pad: + image = self.pad_to_square( + image=image, + background_color=tuple(int(x * 255) for x in self.image_mean), + input_data_format=input_data_format, + ) + + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + processed_images.append(image) + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + +__all__ = ["LlavaImageProcessor"] diff --git a/models/hyvideo/text_encoder/llava/image_processing_llava_fast.py b/models/hyvideo/text_encoder/llava/image_processing_llava_fast.py new file mode 100644 index 0000000..d85eb89 --- /dev/null +++ b/models/hyvideo/text_encoder/llava/image_processing_llava_fast.py @@ -0,0 +1,201 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for LLaVa.""" + +from typing import List, Optional, Tuple, Union + +from ...image_processing_utils import ( + BatchFeature, +) +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, + get_image_size, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + add_start_docstrings, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + is_vision_available, +) + + +if is_vision_available(): + from ...image_utils import PILImageResampling + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +class LlavaFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + do_pad: Optional[bool] + + +@add_start_docstrings( + "Constructs a fast Llava image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + """ + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image to a square based on the longest edge. 
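
For orientation, the slow processor's `preprocess` above chains pad → resize → center-crop → rescale → normalize. The vendored copy keeps `transformers`-style relative imports, so the hedged sketch below drives the equivalent upstream `CLIPImageProcessor` pipeline instead (same chain, minus the optional square padding); shapes and values are illustrative.

```python
# Hedged usage sketch: the upstream CLIPImageProcessor performs the same
# resize -> center-crop -> rescale -> normalize chain as the vendored
# LlavaImageProcessor shown in the hunk above (it has no do_pad option).
import numpy as np
from transformers import CLIPImageProcessor

processor = CLIPImageProcessor(size={"shortest_edge": 224},
                               crop_size={"height": 224, "width": 224})
image = np.random.randint(0, 256, (336, 448, 3), dtype=np.uint8)   # HWC uint8
batch = processor.preprocess(image, return_tensors="pt")
print(batch["pixel_values"].shape)   # torch.Size([1, 3, 224, 224])
```
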
Can be overridden by the `do_pad` parameter + """, +) +class LlavaImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + size = {"shortest_edge": 224} + default_to_square = False + crop_size = {"height": 224, "width": 224} + do_pad = False + do_resize = True + do_center_crop = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + valid_kwargs = LlavaFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[LlavaFastImageProcessorKwargs]) -> None: + super().__init__(**kwargs) + + @add_start_docstrings( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + """ + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image to a square based on the longest edge. Can be overridden by the `do_pad` parameter + """, + ) + def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaFastImageProcessorKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def pad_to_square( + self, + images: "torch.Tensor", + background_color: Union[int, Tuple[int, int, int]] = 0, + ) -> "torch.Tensor": + """ + Pads an image to a square based on the longest edge. + + Args: + images (`np.ndarray`): + The images to pad. + background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0): + The color to use for the padding. Can be an integer for single channel or a + tuple of integers representing for multi-channel images. If passed as integer + in mutli-channel mode, it will default to `0` in subsequent channels. + Returns: + `torch.Tensor`: The padded images. + """ + height, width = get_image_size(images, ChannelDimension.FIRST) + + if height == width: + return images + + num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0] + if isinstance(background_color, int): + background_color = [background_color] + [0] * (num_channels - 1) + elif len(background_color) != num_channels: + raise ValueError( + f"background_color must have no more than {num_channels} elements to match the number of channels" + ) + + max_dim = max(height, width) + paste_x_left = (max_dim - width) // 2 + paste_y_left = (max_dim - height) // 2 + paste_x_right = max_dim - width - paste_x_left + paste_y_right = max_dim - height - paste_y_left + padded_images = F.pad( + images, padding=[paste_x_left, paste_y_left, paste_x_right, paste_y_right], fill=background_color + ) + + return padded_images + + def _preprocess( + self, + images: List["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_pad: bool, + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, List[float]]], + image_std: Optional[Union[float, List[float]]], + return_tensors: Optional[Union[str, TensorType]], + ) -> BatchFeature: + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_pad: + stacked_images = self.pad_to_square( + images=stacked_images, background_color=tuple(int(x * 255) for x in self.image_mean) + ) + resized_images_grouped[shape] = stacked_images + padded_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for batched resizing + # Needed in case do_pad is False, or padding returns images with different sizes + grouped_images, 
grouped_images_index = group_images_by_shape(padded_images) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + +__all__ = ["LlavaImageProcessorFast"] diff --git a/models/hyvideo/text_encoder/llava/modeling_llava.py b/models/hyvideo/text_encoder/llava/modeling_llava.py new file mode 100644 index 0000000..f4ae058 --- /dev/null +++ b/models/hyvideo/text_encoder/llava/modeling_llava.py @@ -0,0 +1,531 @@ +# coding=utf-8 +# Copyright 2023 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Llava model.""" + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import ModelOutput +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg +from transformers.models.auto import AutoModel, AutoModelForCausalLM +from .configuration_llava import LlavaConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlavaConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "llava-hf/llava-1.5-7b-hf" + + +@dataclass +class LlavaCausalLMOutputWithPast(ModelOutput): + """ + Base class for Llava causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). 
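
The fast `_preprocess` above repeatedly groups images by shape so each transform can run batched on torchvision tensors, then restores the original order. A minimal sketch of that group/process/reorder pattern follows; the helper names are illustrative, not the `transformers` internals the file imports.

```python
# Group-by-shape batching sketch: stack same-shaped images, apply a batched op,
# then put results back in the caller's original order.
from collections import defaultdict
import torch

def group_by_shape(images):
    groups, index = defaultdict(list), []
    for img in images:
        shape = tuple(img.shape)
        index.append((shape, len(groups[shape])))   # remember where each image lands
        groups[shape].append(img)
    return {s: torch.stack(g) for s, g in groups.items()}, index

def reorder(processed, index):
    return [processed[shape][pos] for shape, pos in index]

images = [torch.rand(3, 224, 224), torch.rand(3, 336, 336), torch.rand(3, 224, 224)]
grouped, index = group_by_shape(images)
processed = {shape: batch * 2.0 for shape, batch in grouped.items()}   # stand-in for resize/normalize
restored = reorder(processed, index)
assert [tuple(t.shape) for t in restored] == [tuple(i.shape) for i in images]
```
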
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +class LlavaMultiModalProjector(nn.Module): + def __init__(self, config: LlavaConfig): + super().__init__() + # We have hidden_size * the number of vision feature layers + num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * num_feature_layers, + config.text_config.hidden_size, + bias=config.multimodal_projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +LLAVA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. 
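
The `LlavaMultiModalProjector` defined above is a two-layer MLP that maps CLIP patch features (optionally several hidden layers concatenated) into the language model's hidden size. A shape-contract sketch with illustrative dimensions, using an `nn.Sequential` stand-in rather than the actual class:

```python
# Projector shape contract: (num_images, patches, vision_dim * layers) -> (num_images, patches, text_dim)
import torch
from torch import nn

vision_hidden, text_hidden, num_layers_cat = 1024, 4096, 1   # illustrative sizes

projector = nn.Sequential(
    nn.Linear(vision_hidden * num_layers_cat, text_hidden),
    nn.GELU(),                                               # stands in for ACT2FN[projector_hidden_act]
    nn.Linear(text_hidden, text_hidden),
)
patch_features = torch.randn(2, 576, vision_hidden * num_layers_cat)
print(projector(patch_features).shape)   # torch.Size([2, 576, 4096])
```
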
+ + Parameters: + config ([`LlavaConfig`] or [`LlavaVisionConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAVA_START_DOCSTRING, +) +class LlavaPreTrainedModel(PreTrainedModel): + config_class = LlavaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlavaVisionAttention"] + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + # important: this ported version of Llava isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed - the original codebase + # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAVA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses + [`CLIPImageProcessor`] for processing images). + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. 
+""" + + +@add_start_docstrings( + """The LLAVA model which consists of a vision backbone and a language model.""", + LLAVA_START_DOCSTRING, +) +class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): + def __init__(self, config: LlavaConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.vocab_size = config.text_config.vocab_size + self.language_model = AutoModelForCausalLM.from_config(config.text_config) + + if self.language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + vision_feature_layer: Union[int, List[int]], + vision_feature_select_strategy: str, + **kwargs, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + vision_feature_layer (`Union[int, List[int]]`): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") + + kwargs = {k: v for k, v in kwargs.items() if v is not None} + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states. 
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) + + # If we have one vision feature layer, return the corresponding hidden states, + # otherwise, select the hidden states of each feature layer and concatenate them + if isinstance(vision_feature_layer, int): + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + else: + hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer] + # For default; crop CLS from each hidden state in the hidden state pool + if vision_feature_select_strategy == "default": + hs_pool = [hs[:, 1:] for hs in hs_pool] + selected_image_feature = torch.cat(hs_pool, dim=-1) + + image_features = self.multi_modal_projector(selected_image_feature) + return image_features + + + def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): + num_images, num_image_patches, embed_dim = image_features.shape + batch_size, sequence_length = input_ids.shape + left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) + # 1. Create a mask to know where special image tokens are + special_image_token_mask = input_ids == self.config.image_token_index + num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) + # Compute the maximum embed dimension + max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length + batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) + + # 2. Compute the positions where text should be written + # Calculate new positions for text tokens in merged image-text sequence. + # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. + # `torch.cumsum` computes how each image token shifts subsequent text token positions. + # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. + new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 + nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] + if left_padding: + new_token_positions += nb_image_pad[:, None] # offset for left padding + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] + + # 3. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + final_attention_mask = torch.zeros( + batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device + ) + if labels is not None: + final_labels = torch.full( + (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device + ) + # In case the Vision model or the Language model has been offloaded to CPU, we need to manually + # set the corresponding tensors into their correct target device. + target_device = inputs_embeds.device + batch_indices, non_image_indices, text_to_overwrite = ( + batch_indices.to(target_device), + non_image_indices.to(target_device), + text_to_overwrite.to(target_device), + ) + attention_mask = attention_mask.to(target_device) + + # 4. Fill the embeddings based on the mask. 
If we have ["hey" "", "how", "are"] + # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features + final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] + final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] + if labels is not None: + final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] + + # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) + image_to_overwrite = torch.full( + (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device + ) + image_to_overwrite[batch_indices, text_to_overwrite] = False + image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) + + if image_to_overwrite.sum() != image_features.shape[:-1].numel(): + raise ValueError( + f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" + f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." + ) + + final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) + final_attention_mask |= image_to_overwrite + position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) + + # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. + batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) + indices_to_mask = new_token_positions[batch_indices, pad_indices] + + final_embedding[batch_indices, indices_to_mask] = 0 + + if labels is None: + final_labels = None + + return final_embedding, final_attention_mask, final_labels, position_ids + + # @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + # @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING) + # @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ): + from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast + + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else 
self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_features = None + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + + inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels + ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) + + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + + logits = outputs[0] + + loss = None + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return LlavaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + attention_mask=None, + cache_position=None, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + + return model_inputs + + +__all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel"] diff --git a/models/hyvideo/text_encoder/llava/processing_llava.py b/models/hyvideo/text_encoder/llava/processing_llava.py new file mode 100644 index 0000000..6253e19 --- /dev/null +++ b/models/hyvideo/text_encoder/llava/processing_llava.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
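Before the new processor file, a small sketch (illustrative only, dummy tensors) of the `cache_position[0] == 0` check in `prepare_inputs_for_generation` above: pixel values are only forwarded on the prefill pass, because later cached-decoding steps no longer contain the expanded image tokens.

```python
import torch

def should_forward_pixel_values(cache_position: torch.Tensor) -> bool:
    # Prefill (cache_position starts at 0): the prompt still holds the expanded <image>
    # tokens, so the vision features are needed. Cached decoding steps only process the
    # newly generated text token and can skip the vision tower entirely.
    return bool(cache_position[0] == 0)

print(should_forward_pixel_values(torch.arange(0, 600)))    # True  -> pass pixel_values
print(should_forward_pixel_values(torch.arange(600, 601)))  # False -> pixel_values is not attached
```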
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Llava. +""" + +from typing import List, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, get_image_size, to_numpy_array +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class LlavaProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": {}, + } + + +class LlavaProcessor(ProcessorMixin): + r""" + Constructs a LLaVa processor which wraps a LLaVa image processor and a LLaMa tokenizer into a single processor. + + [`LlavaProcessor`] offers all the functionalities of [`LlavaImageProcessor`] and [`LlamaTokenizerFast`]. See the + [`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information. + + Args: + image_processor ([`LlavaImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`LlamaTokenizerFast`], *optional*): + The tokenizer is a required input. + patch_size (`int`, *optional*): + Patch size from the vision tower. + vision_feature_select_strategy (`str`, *optional*): + The feature selection strategy used to select the vision feature from the vision backbone. + Shoudl be same as in model's config + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + image_token (`str`, *optional*, defaults to `""`): + Special token used to denote image location. + num_additional_image_tokens (`int`, *optional*, defaults to 0): + Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other + extra tokens appended, no need to set this arg. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = [ + "chat_template", + "patch_size", + "vision_feature_select_strategy", + "image_token", + "num_additional_image_tokens", + ] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor=None, + tokenizer=None, + patch_size=None, + vision_feature_select_strategy=None, + chat_template=None, + image_token="", # set the default and let users change if they have peculiar special tokens in rare cases + num_additional_image_tokens=0, + **kwargs, + ): + self.patch_size = patch_size + self.num_additional_image_tokens = num_additional_image_tokens + self.vision_feature_select_strategy = vision_feature_select_strategy + self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token + self.image_token_id = ( + tokenizer.image_token_id + if getattr(tokenizer, "image_token_id", None) + else tokenizer.convert_tokens_to_ids(self.image_token) + ) + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[LlavaProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). 
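To make the `patch_size`, `num_additional_image_tokens`, and `vision_feature_select_strategy` arguments concrete, here is the token-count arithmetic the processor performs further down in `__call__` (a hedged sketch; the 336px / patch-14 numbers are just an example):

```python
def expected_image_tokens(height, width, patch_size, num_additional_image_tokens=1,
                          vision_feature_select_strategy="default"):
    # One placeholder token per vision patch, plus any extra tokens (e.g. CLS)...
    num_image_tokens = (height // patch_size) * (width // patch_size) + num_additional_image_tokens
    # ...minus one when the model crops the CLS token ("default" strategy).
    if vision_feature_select_strategy == "default":
        num_image_tokens -= 1
    return num_image_tokens

print(expected_image_tokens(336, 336, 14))  # 24*24 + 1 - 1 = 576 placeholder tokens per image
```

In `__call__` below, each image token in the prompt is then repeated this many times before tokenization, so the text sequence lines up with the projected vision features.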
This method forwards the `text` + and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring + of the above two methods for more information. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + if images is None and text is None: + raise ValueError("You have to specify at least one of `images` or `text`.") + + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) + + output_kwargs = self._merge_kwargs( + LlavaProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) + else: + image_inputs = {} + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. 
Please provide a string, or a list of strings") + + # try to expand inputs in processing if we have the necessary parts + prompt_strings = text + if image_inputs.get("pixel_values") is not None: + # Replace the image token with the expanded image token sequence + pixel_values = image_inputs["pixel_values"] + height, width = get_image_size(to_numpy_array(pixel_values[0])) + num_image_tokens = (height // self.patch_size) * ( + width // self.patch_size + ) + self.num_additional_image_tokens + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + + prompt_strings = [] + for sample in text: + sample = sample.replace(self.image_token, self.image_token * num_image_tokens) + prompt_strings.append(sample) + + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) + return BatchFeature(data={**text_inputs, **image_inputs}) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["LlavaProcessor"] diff --git a/hyvideo/utils/__init__.py b/models/hyvideo/utils/__init__.py similarity index 100% rename from hyvideo/utils/__init__.py rename to models/hyvideo/utils/__init__.py diff --git a/hyvideo/utils/data_utils.py b/models/hyvideo/utils/data_utils.py similarity index 100% rename from hyvideo/utils/data_utils.py rename to models/hyvideo/utils/data_utils.py diff --git a/hyvideo/utils/file_utils.py b/models/hyvideo/utils/file_utils.py similarity index 100% rename from hyvideo/utils/file_utils.py rename to models/hyvideo/utils/file_utils.py diff --git a/hyvideo/utils/helpers.py b/models/hyvideo/utils/helpers.py similarity index 100% rename from hyvideo/utils/helpers.py rename to models/hyvideo/utils/helpers.py diff --git a/hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py b/models/hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py similarity index 100% rename from hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py rename to models/hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py diff --git a/hyvideo/vae/__init__.py b/models/hyvideo/vae/__init__.py similarity index 100% rename from hyvideo/vae/__init__.py rename to models/hyvideo/vae/__init__.py diff --git a/hyvideo/vae/autoencoder_kl_causal_3d.py b/models/hyvideo/vae/autoencoder_kl_causal_3d.py similarity index 100% rename from hyvideo/vae/autoencoder_kl_causal_3d.py rename to models/hyvideo/vae/autoencoder_kl_causal_3d.py diff --git a/hyvideo/vae/unet_causal_3d_blocks.py b/models/hyvideo/vae/unet_causal_3d_blocks.py similarity index 100% rename from 
hyvideo/vae/unet_causal_3d_blocks.py rename to models/hyvideo/vae/unet_causal_3d_blocks.py diff --git a/hyvideo/vae/vae.py b/models/hyvideo/vae/vae.py similarity index 100% rename from hyvideo/vae/vae.py rename to models/hyvideo/vae/vae.py diff --git a/models/ltx_video/__init__.py b/models/ltx_video/__init__.py new file mode 100644 index 0000000..3a3898e --- /dev/null +++ b/models/ltx_video/__init__.py @@ -0,0 +1,2 @@ +from .ltxv import LTXV +from . import ltxv_handler \ No newline at end of file diff --git a/ltx_video/configs/ltxv-13b-0.9.7-dev.original.yaml b/models/ltx_video/configs/ltxv-13b-0.9.7-dev.original.yaml similarity index 100% rename from ltx_video/configs/ltxv-13b-0.9.7-dev.original.yaml rename to models/ltx_video/configs/ltxv-13b-0.9.7-dev.original.yaml diff --git a/ltx_video/configs/ltxv-13b-0.9.7-dev.yaml b/models/ltx_video/configs/ltxv-13b-0.9.7-dev.yaml similarity index 100% rename from ltx_video/configs/ltxv-13b-0.9.7-dev.yaml rename to models/ltx_video/configs/ltxv-13b-0.9.7-dev.yaml diff --git a/ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml b/models/ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml similarity index 100% rename from ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml rename to models/ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml diff --git a/models/ltx_video/configs/ltxv-13b-0.9.8-dev.yaml b/models/ltx_video/configs/ltxv-13b-0.9.8-dev.yaml new file mode 100644 index 0000000..0c22e9e --- /dev/null +++ b/models/ltx_video/configs/ltxv-13b-0.9.8-dev.yaml @@ -0,0 +1,34 @@ +pipeline_type: multi-scale +checkpoint_path: "ltxv-13b-0.9.8-dev.safetensors" +downscale_factor: 0.6666666 +spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors" +stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "bfloat16" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: false + +first_pass: + guidance_scale: [1, 1, 6, 8, 6, 1, 1] + stg_scale: [0, 0, 4, 4, 4, 2, 1] + rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1] + guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180] + skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]] + num_inference_steps: 30 + skip_final_inference_steps: 3 + cfg_star_rescale: true + +second_pass: + guidance_scale: [1] + stg_scale: [1] + rescaling_scale: [1] + guidance_timesteps: [1.0] + skip_block_list: [27] + num_inference_steps: 30 + skip_initial_inference_steps: 17 + cfg_star_rescale: true \ No newline at end of file diff --git a/models/ltx_video/configs/ltxv-13b-0.9.8-distilled.yaml b/models/ltx_video/configs/ltxv-13b-0.9.8-distilled.yaml new file mode 100644 index 0000000..a1ac723 --- /dev/null +++ b/models/ltx_video/configs/ltxv-13b-0.9.8-distilled.yaml @@ -0,0 +1,29 @@ +pipeline_type: multi-scale +checkpoint_path: "ltxv-13b-0.9.8-distilled.safetensors" +downscale_factor: 0.6666666 +spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors" +stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +text_encoder_model_name_or_path: 
"PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "bfloat16" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: false + +first_pass: + timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250] + guidance_scale: 1 + stg_scale: 0 + rescaling_scale: 1 + skip_block_list: [42] + +second_pass: + timesteps: [0.9094, 0.7250, 0.4219] + guidance_scale: 1 + stg_scale: 0 + rescaling_scale: 1 + skip_block_list: [42] + tone_map_compression_ratio: 0.6 diff --git a/ltx_video/configs/ltxv-2b-0.9.6-dev.yaml b/models/ltx_video/configs/ltxv-2b-0.9.6-dev.yaml similarity index 100% rename from ltx_video/configs/ltxv-2b-0.9.6-dev.yaml rename to models/ltx_video/configs/ltxv-2b-0.9.6-dev.yaml diff --git a/ltx_video/configs/ltxv-2b-0.9.6-distilled.yaml b/models/ltx_video/configs/ltxv-2b-0.9.6-distilled.yaml similarity index 100% rename from ltx_video/configs/ltxv-2b-0.9.6-distilled.yaml rename to models/ltx_video/configs/ltxv-2b-0.9.6-distilled.yaml diff --git a/ltx_video/ltxv.py b/models/ltx_video/ltxv.py similarity index 86% rename from ltx_video/ltxv.py rename to models/ltx_video/ltxv.py index 6b43c38..e71ac4f 100644 --- a/ltx_video/ltxv.py +++ b/models/ltx_video/ltxv.py @@ -7,7 +7,7 @@ from pathlib import Path from diffusers.utils import logging from typing import Optional, List, Union import yaml -from wan.utils.utils import calculate_new_dimensions +from shared.utils.utils import calculate_new_dimensions import imageio import json import numpy as np @@ -149,6 +149,7 @@ class LTXV: self, model_filepath: str, text_encoder_filepath: str, + model_type, base_model_type, model_def, dtype = torch.bfloat16, VAE_dtype = torch.bfloat16, @@ -159,24 +160,31 @@ class LTXV: dtype = torch.bfloat16 self.mixed_precision_transformer = mixed_precision_transformer self.model_def = model_def - self.pipeline_config = model_def["LTXV_config"] + self.model_type = model_type + self.pipeline_config = model_def["LTXV_config"] + # ckpt_path ="c:/temp/ltxv-13b-0.9.8-dev.safetensors" # with safe_open(ckpt_path, framework="pt") as f: # metadata = f.metadata() # config_str = metadata.get("config") # configs = json.loads(config_str) - # allowed_inference_steps = configs.get("allowed_inference_steps", None) + # allowed_inference_steps = configs.get("allowed_inference_steps", None) + # with open("c:/temp/ltxv_config.json", "w", encoding="utf-8") as writer: + # writer.write(json.dumps(configs["transformer"])) + # with open("c:/temp/vae_config.json", "w", encoding="utf-8") as writer: + # writer.write(json.dumps(configs["vae"])) # transformer = Transformer3DModel.from_pretrained(ckpt_path) - # transformer = offload.fast_load_transformers_model("c:/temp/ltxdistilled/diffusion_pytorch_model-00001-of-00006.safetensors", forcedConfigPath="c:/temp/ltxdistilled/config.json") - + # offload.save_model(transformer, "ckpts/ltxv_0.9.8_13B_bf16.safetensors", config_file_path= "c:/temp/ltxv_config.json") + # offload.save_model(transformer, "ckpts/ltxv_0.9.8_13B_quanto_bf16_int8.safetensors", do_quantize= True, config_file_path= "c:/temp/ltxv_config.json") + # vae = CausalVideoAutoencoder.from_pretrained(ckpt_path) vae = offload.fast_load_transformers_model("ckpts/ltxv_0.9.7_VAE.safetensors", modelClass=CausalVideoAutoencoder) + # vae = 
offload.fast_load_transformers_model("ckpts/ltxv_0.9.8_VAE.safetensors", modelClass=CausalVideoAutoencoder) # if VAE_dtype == torch.float16: VAE_dtype = torch.bfloat16 vae = vae.to(VAE_dtype) vae._model_dtype = VAE_dtype - # vae = offload.fast_load_transformers_model("vae.safetensors", modelClass=CausalVideoAutoencoder, modelPrefix= "vae", forcedConfigPath="config_vae.json") - # offload.save_model(vae, "vae.safetensors", config_file_path="config_vae.json") + # offload.save_model(vae, "vae.safetensors", config_file_path="c:/temp/config_vae.json") # model_filepath = "c:/temp/ltxd/ltxv-13b-0.9.7-distilled.safetensors" transformer = offload.fast_load_transformers_model(model_filepath, modelClass=Transformer3DModel) @@ -193,6 +201,7 @@ class LTXV: # offload.save_model(transformer, "ltx_13B_quanto_bf16_int8.safetensors", do_quantize= True, config_file_path="config_transformer.json") latent_upsampler = LatentUpsampler.from_pretrained("ckpts/ltxv_0.9.7_spatial_upscaler.safetensors").to("cpu").eval() + # latent_upsampler = LatentUpsampler.from_pretrained("ckpts/ltxv_0.9.8_spatial_upscaler.safetensors").to("cpu").eval() latent_upsampler.to(VAE_dtype) latent_upsampler._model_dtype = VAE_dtype @@ -259,6 +268,7 @@ class LTXV: image_start = None, image_end = None, input_video = None, + input_frames = None, sampling_steps = 50, image_cond_noise_scale: float = 0.15, input_media_path: Optional[str] = None, @@ -272,6 +282,7 @@ class LTXV: callback=None, device: Optional[str] = None, VAE_tile_size = None, + apg_switch = 0, **kwargs, ): @@ -280,21 +291,34 @@ class LTXV: conditioning_strengths = None conditioning_media_paths = [] conditioning_start_frames = [] - - + conditioning_control_frames = [] + prefix_size = 0 if input_video != None: conditioning_media_paths.append(input_video) conditioning_start_frames.append(0) - height, width = input_video.shape[-2:] + conditioning_control_frames.append(False) + prefix_size, height, width = input_video.shape[-3:] else: if image_start != None: frame_width, frame_height = image_start.size - height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas, 32) - conditioning_media_paths.append(image_start) + if fit_into_canvas != None: + height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas, 32) + conditioning_media_paths.append(image_start.unsqueeze(1)) conditioning_start_frames.append(0) - if image_end != None: - conditioning_media_paths.append(image_end) - conditioning_start_frames.append(frame_num-1) + conditioning_control_frames.append(False) + prefix_size = 1 + + if image_end != None: + conditioning_media_paths.append(image_end.unsqueeze(1)) + conditioning_start_frames.append(frame_num-1) + conditioning_control_frames.append(False) + + if input_frames!= None: + conditioning_media_paths.append(input_frames) + conditioning_start_frames.append(prefix_size) + conditioning_control_frames.append(True) + height, width = input_frames.shape[-2:] + fit_into_canvas = None if len(conditioning_media_paths) == 0: conditioning_media_paths = None @@ -380,6 +404,7 @@ class LTXV: conditioning_media_paths=conditioning_media_paths, conditioning_strengths=conditioning_strengths, conditioning_start_frames=conditioning_start_frames, + conditioning_control_frames=conditioning_control_frames, height=height, width=width, num_frames=frame_num, @@ -435,6 +460,7 @@ class LTXV: mixed_precision=pipeline_config.get("mixed", self.mixed_precision_transformer), callback=callback, VAE_tile_size = VAE_tile_size, + 
apg_switch = apg_switch, device=device, # enhance_prompt=enhance_prompt, ) @@ -453,11 +479,29 @@ class LTXV: images = images.sub_(0.5).mul_(2).squeeze(0) return images + def get_loras_transformer(self, get_model_recursive_prop, video_prompt_type, **kwargs): + map = { + "P" : "pose", + "D" : "depth", + "E" : "canny", + } + loras = [] + preloadURLs = get_model_recursive_prop(self.model_type, "preload_URLs") + lora_file_name = "" + for letter, signature in map.items(): + if letter in video_prompt_type: + for file_name in preloadURLs: + if signature in file_name: + loras += [ os.path.join("ckpts", os.path.basename(file_name))] + break + loras_mult = [1.] * len(loras) + return loras, loras_mult def prepare_conditioning( conditioning_media_paths: List[str], conditioning_strengths: List[float], conditioning_start_frames: List[int], + conditioning_control_frames: List[int], height: int, width: int, num_frames: int, @@ -480,8 +524,8 @@ def prepare_conditioning( A list of ConditioningItem objects. """ conditioning_items = [] - for path, strength, start_frame in zip( - conditioning_media_paths, conditioning_strengths, conditioning_start_frames + for path, strength, start_frame, conditioning_control_frames in zip( + conditioning_media_paths, conditioning_strengths, conditioning_start_frames, conditioning_control_frames ): if isinstance(path, Image.Image): num_input_frames = orig_num_input_frames = 1 @@ -506,7 +550,7 @@ def prepare_conditioning( padding=padding, just_crop=True, ) - conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength)) + conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength, conditioning_control_frames)) return conditioning_items @@ -561,3 +605,4 @@ def load_media_file( raise Exception("video format not supported") return media_tensor + diff --git a/models/ltx_video/ltxv_handler.py b/models/ltx_video/ltxv_handler.py new file mode 100644 index 0000000..d35bcd4 --- /dev/null +++ b/models/ltx_video/ltxv_handler.py @@ -0,0 +1,90 @@ +import torch + + +def get_ltxv_text_encoder_filename(text_encoder_quantization): + text_encoder_filename = "ckpts/T5_xxl_1.1/T5_xxl_1.1_enc_bf16.safetensors" + if text_encoder_quantization =="int8": + text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_bf16_int8") + return text_encoder_filename + +class family_handler(): + @staticmethod + def query_model_def(base_model_type, model_def): + LTXV_config = model_def.get("LTXV_config", "") + distilled= "distilled" in LTXV_config + extra_model_def = {} + if distilled: + extra_model_def.update({ + "lock_inference_steps": True, + "no_negative_prompt" : True, + }) + + + extra_model_def["fps"] = 30 + extra_model_def["frames_minimum"] = 17 + extra_model_def["frames_steps"] = 8 + extra_model_def["sliding_window"] = True + + return extra_model_def + + @staticmethod + def query_supported_types(): + return ["ltxv_13B"] + + @staticmethod + def query_family_maps(): + return {}, {} + + @staticmethod + def get_rgb_factors(base_model_type ): + from shared.RGB_factors import get_rgb_factors + latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("ltxv") + return latent_rgb_factors, latent_rgb_factors_bias + + @staticmethod + def query_model_family(): + return "ltxv" + + @staticmethod + def query_family_infos(): + return {"ltxv":(10, "LTX Video")} + + @staticmethod + def get_vae_block_size(base_model_type): + return 32 + + @staticmethod + def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization): + text_encoder_filename 
= get_ltxv_text_encoder_filename(text_encoder_quantization) + return { + "repoId" : "DeepBeepMeep/LTX_Video", + "sourceFolderList" : ["T5_xxl_1.1", "" ], + "fileList" : [ ["added_tokens.json", "special_tokens_map.json", "spiece.model", "tokenizer_config.json"] + computeList(text_encoder_filename), ["ltxv_0.9.7_VAE.safetensors", "ltxv_0.9.7_spatial_upscaler.safetensors", "ltxv_scheduler.json"] + computeList(model_filename) ] + } + + + @staticmethod + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + from .ltxv import LTXV + + ltxv_model = LTXV( + model_filepath = model_filename, + text_encoder_filepath = get_ltxv_text_encoder_filename(text_encoder_quantization), + model_type = model_type, + base_model_type = base_model_type, + model_def = model_def, + dtype = dtype, + # quantizeTransformer = quantizeTransformer, + VAE_dtype = VAE_dtype, + mixed_precision_transformer = mixed_precision_transformer + ) + + pipeline = ltxv_model.pipeline + pipe = {"transformer" : pipeline.video_pipeline.transformer, "vae" : pipeline.vae, "text_encoder" : pipeline.video_pipeline.text_encoder, "latent_upsampler" : pipeline.latent_upsampler} + + return ltxv_model, pipe + + @staticmethod + def update_default_settings(base_model_type, model_def, ui_defaults): + pass + \ No newline at end of file diff --git a/ltx_video/__init__.py b/models/ltx_video/models/__init__.py similarity index 100% rename from ltx_video/__init__.py rename to models/ltx_video/models/__init__.py diff --git a/ltx_video/models/__init__.py b/models/ltx_video/models/autoencoders/__init__.py similarity index 100% rename from ltx_video/models/__init__.py rename to models/ltx_video/models/autoencoders/__init__.py diff --git a/ltx_video/models/autoencoders/causal_conv3d.py b/models/ltx_video/models/autoencoders/causal_conv3d.py similarity index 100% rename from ltx_video/models/autoencoders/causal_conv3d.py rename to models/ltx_video/models/autoencoders/causal_conv3d.py diff --git a/ltx_video/models/autoencoders/causal_video_autoencoder.py b/models/ltx_video/models/autoencoders/causal_video_autoencoder.py similarity index 97% rename from ltx_video/models/autoencoders/causal_video_autoencoder.py rename to models/ltx_video/models/autoencoders/causal_video_autoencoder.py index 0edfe6a..daed704 100644 --- a/ltx_video/models/autoencoders/causal_video_autoencoder.py +++ b/models/ltx_video/models/autoencoders/causal_video_autoencoder.py @@ -15,12 +15,12 @@ from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbedding from safetensors import safe_open -from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd -from ltx_video.models.autoencoders.pixel_norm import PixelNorm -from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND -from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper -from ltx_video.models.transformers.attention import Attention -from ltx_video.utils.diffusers_config_mapping import ( +from ..autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd +from ...models.autoencoders.pixel_norm import PixelNorm +from ...models.autoencoders.pixel_shuffle import PixelShuffleND +from ...models.autoencoders.vae import AutoencoderKLWrapper +from ...models.transformers.attention import Attention +from ...utils.diffusers_config_mapping import ( diffusers_and_ours_config_mapping, 
make_hashable_key, VAE_KEYS_RENAME_DICT, @@ -253,10 +253,12 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper): if key.startswith("vae.") } + + stats_keys_to_keep = ["per_channel_statistics.std-of-means", "per_channel_statistics.mean-of-means"] ckpt_state_dict = { key: value for key, value in state_dict.items() - if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX) + if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX) or key in stats_keys_to_keep } model_keys = set(name for name, _ in self.named_modules()) @@ -280,21 +282,26 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper): converted_state_dict[key] = value + # data_dict = { + # key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value + # for key, value in state_dict.items() + # if key in stats_keys_to_keep + # } + for key in stats_keys_to_keep: + if key in converted_state_dict: # happens only in the original vae sd + v = converted_state_dict.pop(key) + converted_state_dict[key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX).replace("-", "_")] = v + a,b = super().load_state_dict(converted_state_dict, strict=strict, assign=assign) - data_dict = { - key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value - for key, value in state_dict.items() - if key.startswith(PER_CHANNEL_STATISTICS_PREFIX) - } - if len(data_dict) > 0: - self.register_buffer("std_of_means", data_dict["std-of-means"],) - self.register_buffer( - "mean_of_means", - data_dict.get( - "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) - ), - ) + # if len(data_dict) > 0: + # self.register_buffer("std_of_means", data_dict["std-of-means"],) + # self.register_buffer( + # "mean_of_means", + # data_dict.get( + # "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) + # ), + # ) return a, b def last_layer(self): diff --git a/ltx_video/models/autoencoders/conv_nd_factory.py b/models/ltx_video/models/autoencoders/conv_nd_factory.py similarity index 94% rename from ltx_video/models/autoencoders/conv_nd_factory.py rename to models/ltx_video/models/autoencoders/conv_nd_factory.py index 718c69b..59a3fc0 100644 --- a/ltx_video/models/autoencoders/conv_nd_factory.py +++ b/models/ltx_video/models/autoencoders/conv_nd_factory.py @@ -2,8 +2,8 @@ from typing import Tuple, Union import torch -from ltx_video.models.autoencoders.dual_conv3d import DualConv3d -from ltx_video.models.autoencoders.causal_conv3d import CausalConv3d +from ..autoencoders.dual_conv3d import DualConv3d +from ..autoencoders.causal_conv3d import CausalConv3d def make_conv_nd( diff --git a/ltx_video/models/autoencoders/dual_conv3d.py b/models/ltx_video/models/autoencoders/dual_conv3d.py similarity index 100% rename from ltx_video/models/autoencoders/dual_conv3d.py rename to models/ltx_video/models/autoencoders/dual_conv3d.py diff --git a/ltx_video/models/autoencoders/latent_upsampler.py b/models/ltx_video/models/autoencoders/latent_upsampler.py similarity index 98% rename from ltx_video/models/autoencoders/latent_upsampler.py rename to models/ltx_video/models/autoencoders/latent_upsampler.py index 4a76bc2..f666d2f 100644 --- a/ltx_video/models/autoencoders/latent_upsampler.py +++ b/models/ltx_video/models/autoencoders/latent_upsampler.py @@ -9,7 +9,7 @@ from einops import rearrange from diffusers import ConfigMixin, ModelMixin from safetensors.torch import safe_open -from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND +from ...models.autoencoders.pixel_shuffle import PixelShuffleND class ResBlock(nn.Module): diff --git a/ltx_video/models/autoencoders/pixel_norm.py 
b/models/ltx_video/models/autoencoders/pixel_norm.py similarity index 100% rename from ltx_video/models/autoencoders/pixel_norm.py rename to models/ltx_video/models/autoencoders/pixel_norm.py diff --git a/ltx_video/models/autoencoders/pixel_shuffle.py b/models/ltx_video/models/autoencoders/pixel_shuffle.py similarity index 100% rename from ltx_video/models/autoencoders/pixel_shuffle.py rename to models/ltx_video/models/autoencoders/pixel_shuffle.py diff --git a/ltx_video/models/autoencoders/vae.py b/models/ltx_video/models/autoencoders/vae.py similarity index 97% rename from ltx_video/models/autoencoders/vae.py rename to models/ltx_video/models/autoencoders/vae.py index 5b19ba4..a0ce1c4 100644 --- a/ltx_video/models/autoencoders/vae.py +++ b/models/ltx_video/models/autoencoders/vae.py @@ -10,7 +10,7 @@ from diffusers.models.autoencoders.vae import ( DiagonalGaussianDistribution, ) from diffusers.models.modeling_outputs import AutoencoderKLOutput -from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd +from ...models.autoencoders.conv_nd_factory import make_conv_nd class AutoencoderKLWrapper(ModelMixin, ConfigMixin): @@ -44,12 +44,17 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin): self.per_channel_statistics = nn.Module() std_of_means = torch.zeros( (128,), dtype= torch.bfloat16) - self.per_channel_statistics.register_buffer("std-of-means", std_of_means) - self.per_channel_statistics.register_buffer( - "mean-of-means", + # self.per_channel_statistics.register_buffer("std-of-means", std_of_means) + # self.per_channel_statistics.register_buffer( + # "mean-of-means", + # torch.zeros_like(std_of_means) + # ) + + self.register_buffer("std_of_means", std_of_means) + self.register_buffer( + "mean_of_means", torch.zeros_like(std_of_means) ) - # pass init params to Encoder diff --git a/ltx_video/models/autoencoders/vae_encode.py b/models/ltx_video/models/autoencoders/vae_encode.py similarity index 98% rename from ltx_video/models/autoencoders/vae_encode.py rename to models/ltx_video/models/autoencoders/vae_encode.py index b7d2476..4b6a5c4 100644 --- a/ltx_video/models/autoencoders/vae_encode.py +++ b/models/ltx_video/models/autoencoders/vae_encode.py @@ -5,10 +5,10 @@ from einops import rearrange from torch import Tensor -from ltx_video.models.autoencoders.causal_video_autoencoder import ( +from ...models.autoencoders.causal_video_autoencoder import ( CausalVideoAutoencoder, ) -from ltx_video.models.autoencoders.video_autoencoder import ( +from ...models.autoencoders.video_autoencoder import ( Downsample3D, VideoAutoencoder, ) diff --git a/ltx_video/models/autoencoders/video_autoencoder.py b/models/ltx_video/models/autoencoders/video_autoencoder.py similarity index 99% rename from ltx_video/models/autoencoders/video_autoencoder.py rename to models/ltx_video/models/autoencoders/video_autoencoder.py index 3c7926c..dbb2bcd 100644 --- a/ltx_video/models/autoencoders/video_autoencoder.py +++ b/models/ltx_video/models/autoencoders/video_autoencoder.py @@ -11,10 +11,10 @@ from torch.nn import functional from diffusers.utils import logging -from ltx_video.utils.torch_utils import Identity -from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd -from ltx_video.models.autoencoders.pixel_norm import PixelNorm -from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper +from ...utils.torch_utils import Identity +from ...models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd +from ...models.autoencoders.pixel_norm import PixelNorm 
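A one-liner recap (illustrative, outside the diff) of the per-channel-statistics handling above: the checkpoint keys `per_channel_statistics.std-of-means` / `mean-of-means` are now kept, stripped of their prefix and de-hyphenated so they load straight into the `std_of_means` / `mean_of_means` buffers registered on `AutoencoderKLWrapper`.

```python
PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics."
stats_keys_to_keep = ["per_channel_statistics.std-of-means", "per_channel_statistics.mean-of-means"]

for key in stats_keys_to_keep:
    print(key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX).replace("-", "_"))
# -> std_of_means
# -> mean_of_means
```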
+from ...models.autoencoders.vae import AutoencoderKLWrapper logger = logging.get_logger(__name__) diff --git a/ltx_video/models/autoencoders/__init__.py b/models/ltx_video/models/transformers/__init__.py similarity index 100% rename from ltx_video/models/autoencoders/__init__.py rename to models/ltx_video/models/transformers/__init__.py diff --git a/ltx_video/models/transformers/attention.py b/models/ltx_video/models/transformers/attention.py similarity index 99% rename from ltx_video/models/transformers/attention.py rename to models/ltx_video/models/transformers/attention.py index a7b4555..a87a8a0 100644 --- a/ltx_video/models/transformers/attention.py +++ b/models/ltx_video/models/transformers/attention.py @@ -19,15 +19,9 @@ from diffusers.utils import deprecate, logging from diffusers.utils.torch_utils import maybe_allow_in_graph from einops import rearrange from torch import nn -from wan.modules.attention import pay_attention -from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy +from shared.attention import pay_attention +from ...utils.skip_layer_strategy import SkipLayerStrategy -try: - from torch_xla.experimental.custom_kernel import flash_attention -except ImportError: - # workaround for automatic tests. Currently this function is manually patched - # to the torch_xla lib on setup of container - pass # code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py diff --git a/ltx_video/models/transformers/embeddings.py b/models/ltx_video/models/transformers/embeddings.py similarity index 100% rename from ltx_video/models/transformers/embeddings.py rename to models/ltx_video/models/transformers/embeddings.py diff --git a/ltx_video/models/transformers/symmetric_patchifier.py b/models/ltx_video/models/transformers/symmetric_patchifier.py similarity index 100% rename from ltx_video/models/transformers/symmetric_patchifier.py rename to models/ltx_video/models/transformers/symmetric_patchifier.py diff --git a/ltx_video/models/transformers/transformer3d.py b/models/ltx_video/models/transformers/transformer3d.py similarity index 98% rename from ltx_video/models/transformers/transformer3d.py rename to models/ltx_video/models/transformers/transformer3d.py index e182f21..c90baeb 100644 --- a/ltx_video/models/transformers/transformer3d.py +++ b/models/ltx_video/models/transformers/transformer3d.py @@ -16,10 +16,10 @@ from diffusers.utils import BaseOutput, is_torch_version from diffusers.utils import logging from torch import nn from safetensors import safe_open -from ltx_video.models.transformers.attention import BasicTransformerBlock, reshape_hidden_states, restore_hidden_states_shape -from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy +from .attention import BasicTransformerBlock, reshape_hidden_states, restore_hidden_states_shape +from ...utils.skip_layer_strategy import SkipLayerStrategy -from ltx_video.utils.diffusers_config_mapping import ( +from ...utils.diffusers_config_mapping import ( diffusers_and_ours_config_mapping, make_hashable_key, TRANSFORMER_KEYS_RENAME_DICT, diff --git a/ltx_video/models/transformers/__init__.py b/models/ltx_video/pipelines/__init__.py similarity index 100% rename from ltx_video/models/transformers/__init__.py rename to models/ltx_video/pipelines/__init__.py diff --git a/ltx_video/pipelines/crf_compressor.py b/models/ltx_video/pipelines/crf_compressor.py similarity index 100% rename from ltx_video/pipelines/crf_compressor.py rename to models/ltx_video/pipelines/crf_compressor.py diff 
--git a/ltx_video/pipelines/pipeline_ltx_video.py b/models/ltx_video/pipelines/pipeline_ltx_video.py similarity index 90% rename from ltx_video/pipelines/pipeline_ltx_video.py rename to models/ltx_video/pipelines/pipeline_ltx_video.py index 38ff702..f98eb13 100644 --- a/ltx_video/pipelines/pipeline_ltx_video.py +++ b/models/ltx_video/pipelines/pipeline_ltx_video.py @@ -24,22 +24,22 @@ from transformers import ( AutoTokenizer, ) -from ltx_video.models.autoencoders.causal_video_autoencoder import ( +from ..models.autoencoders.causal_video_autoencoder import ( CausalVideoAutoencoder, ) -from ltx_video.models.autoencoders.vae_encode import ( +from ..models.autoencoders.vae_encode import ( get_vae_size_scale_factor, latent_to_pixel_coords, vae_decode, vae_encode, ) -from ltx_video.models.transformers.symmetric_patchifier import Patchifier -from ltx_video.models.transformers.transformer3d import Transformer3DModel -from ltx_video.schedulers.rf import TimestepShifter -from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy -from ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt -from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler -from ltx_video.models.autoencoders.vae_encode import ( +from ..models.transformers.symmetric_patchifier import Patchifier +from ..models.transformers.transformer3d import Transformer3DModel +from ..schedulers.rf import TimestepShifter +from ..utils.skip_layer_strategy import SkipLayerStrategy +from ..utils.prompt_enhance_utils import generate_cinematic_prompt +from ..models.autoencoders.latent_upsampler import LatentUpsampler +from ..models.autoencoders.vae_encode import ( un_normalize_latents, normalize_latents, ) @@ -120,6 +120,48 @@ ASPECT_RATIO_512_BIN = { "4.0": [1024.0, 256.0], } +class MomentumBuffer: + def __init__(self, momentum: float): + self.momentum = momentum + self.running_average = 0 + + def update(self, update_value: torch.Tensor): + new_average = self.momentum * self.running_average + self.running_average = update_value + new_average + + + +def project( + v0: torch.Tensor, # [B, C, T, H, W] + v1: torch.Tensor, # [B, C, T, H, W] + ): + dtype = v0.dtype + v0, v1 = v0.double(), v1.double() + v1 = torch.nn.functional.normalize(v1, dim=[-2, -1]) + v0_parallel = (v0 * v1).sum(dim=[-2, -1], keepdim=True) * v1 + v0_orthogonal = v0 - v0_parallel + return v0_parallel.to(dtype), v0_orthogonal.to(dtype) + + +def adaptive_projected_guidance( + diff: torch.Tensor, # [B, C, T, H, W] + pred_cond: torch.Tensor, # [B, C, T, H, W] + momentum_buffer: MomentumBuffer = None, + eta: float = 0.0, + norm_threshold: float = 55, + ): + if momentum_buffer is not None: + momentum_buffer.update(diff) + diff = momentum_buffer.running_average + if norm_threshold > 0: + ones = torch.ones_like(diff) + diff_norm = diff.norm(p=2, dim=[-2, -1], keepdim=True) + print(f"diff_norm: {diff_norm}") + scale_factor = torch.minimum(ones, norm_threshold / diff_norm) + diff = diff * scale_factor + diff_parallel, diff_orthogonal = project(diff, pred_cond) + normalized_update = diff_orthogonal + eta * diff_parallel + return normalized_update # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( @@ -215,6 +257,7 @@ class ConditioningItem: media_item: torch.Tensor media_frame_number: int conditioning_strength: float + control_frames: bool = False media_x: Optional[int] = None media_y: Optional[int] = None @@ -796,6 +839,7 @@ class LTXVideoPipeline(DiffusionPipeline): 
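A usage sketch of the APG helpers introduced above (the function and class names come from this file; the tensors and guidance scale are dummies): when `apg_switch != 0`, the pipeline replaces the plain CFG update with an adaptive-projected-guidance update built from the conditional/unconditional predictions, a `MomentumBuffer`, and a norm threshold.

```python
import torch

# Dummy noise predictions shaped like the documented [B, C, T, H, W] latents.
noise_pred_text = torch.randn(1, 128, 8, 16, 16)
noise_pred_uncond = torch.randn_like(noise_pred_text)
guidance_scale = 6.0

momentum = MomentumBuffer(momentum=-0.75)          # same default the pipeline uses
update = adaptive_projected_guidance(
    noise_pred_text - noise_pred_uncond,           # classic CFG difference
    noise_pred_text,                               # reference used for the projection
    momentum_buffer=momentum,                      # running average across denoising steps
    norm_threshold=55,                             # clamps overly large updates
)
noise_pred = noise_pred_text + (guidance_scale - 1) * update
```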
text_encoder_max_tokens: int = 256, stochastic_sampling: bool = False, media_items: Optional[torch.Tensor] = None, + tone_map_compression_ratio: float = 0.0, strength: Optional[float] = 1.0, skip_initial_inference_steps: int = 0, skip_final_inference_steps: int = 0, @@ -803,6 +847,7 @@ class LTXVideoPipeline(DiffusionPipeline): pass_no: int = -1, ltxv_model = None, callback=None, + apg_switch = 0, **kwargs, ) -> Union[ImagePipelineOutput, Tuple]: """ @@ -876,6 +921,8 @@ class LTXVideoPipeline(DiffusionPipeline): media_items ('torch.Tensor', *optional*): The input media item used for image-to-image / video-to-video. When provided, they will be noised according to 'strength' and then fully denoised. + tone_map_compression_ratio: compression ratio for tone mapping, defaults to 0.0. + If set to 0.0, no tone mapping is applied. If set to 1.0 - full compression is applied. strength ('floaty', *optional* defaults to 1.0): The editing level in image-to-image / video-to-video. The provided input will be noised to this level. @@ -1077,7 +1124,10 @@ class LTXVideoPipeline(DiffusionPipeline): ) ) init_latents = latents.clone() # Used for image_cond_noise_update - + if conditioning_items is not None and len(conditioning_items) > 0 and not conditioning_items[0].control_frames and conditioning_items[0].media_frame_number == 0: + prefix_latent_frames = (conditioning_items[0].media_item.shape[2] - 1)// 8 + 1 + else: + prefix_latent_frames = 0 # pixel_coords = torch.cat([pixel_coords] * num_conds) orig_conditioning_mask = conditioning_mask if conditioning_mask is not None and is_video: @@ -1096,6 +1146,12 @@ class LTXVideoPipeline(DiffusionPipeline): ) cfg_star_rescale = True + if apg_switch != 0: + apg_momentum = -0.75 + apg_norm_threshold = 55 + text_momentumbuffer = MomentumBuffer(apg_momentum) + audio_momentumbuffer = MomentumBuffer(apg_momentum) + if callback != None: callback(-1, None, True, override_num_inference_steps = num_inference_steps, pass_no =pass_no) @@ -1186,22 +1242,30 @@ class LTXVideoPipeline(DiffusionPipeline): )[-2:] if do_classifier_free_guidance and guidance_scale[i] !=0 and guidance_scale[i] !=1 : noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_conds)[:2] - if cfg_star_rescale: - batch_size = noise_pred_text.shape[0] - positive_flat = noise_pred_text.view(batch_size, -1) - negative_flat = noise_pred_uncond.view(batch_size, -1) - dot_product = torch.sum( - positive_flat * negative_flat, dim=1, keepdim=True + if apg_switch != 0: + noise_pred = noise_pred_text + (guidance_scale[i] - 1) * adaptive_projected_guidance(noise_pred_text - noise_pred_uncond, + noise_pred_text, + momentum_buffer=text_momentumbuffer, + norm_threshold=apg_norm_threshold) + + else: + if cfg_star_rescale: + batch_size = noise_pred_text.shape[0] + + positive_flat = noise_pred_text.view(batch_size, -1) + negative_flat = noise_pred_uncond.view(batch_size, -1) + dot_product = torch.sum( + positive_flat * negative_flat, dim=1, keepdim=True + ) + squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8 + alpha = dot_product / squared_norm + noise_pred_uncond = alpha * noise_pred_uncond + + + noise_pred = noise_pred_uncond + guidance_scale[i] * ( + noise_pred_text - noise_pred_uncond ) - squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8 - alpha = dot_product / squared_norm - noise_pred_uncond = alpha * noise_pred_uncond - - - noise_pred = noise_pred_uncond + guidance_scale[i] * ( - noise_pred_text - noise_pred_uncond - ) elif do_spatio_temporal_guidance: noise_pred = 
noise_pred_text if do_spatio_temporal_guidance: @@ -1242,7 +1306,7 @@ class LTXVideoPipeline(DiffusionPipeline): if callback is not None: # callback(i, None, False, pass_no =pass_no) - preview_latents= latents.squeeze(0).transpose(0, 1) + preview_latents= latents[:, num_cond_latents:].squeeze(0).transpose(0, 1) preview_latents= preview_latents.reshape(preview_latents.shape[0], latent_num_frames, latent_height, latent_width) callback(i, preview_latents, False, pass_no =pass_no) preview_latents = None @@ -1285,8 +1349,9 @@ class LTXVideoPipeline(DiffusionPipeline): ) else: decode_timestep = None - torch.save(latents, "lala.pt") + # torch.save(latents, "lala.pt") # latents = torch.load("lala.pt") + latents = self.tone_map_latents(latents, tone_map_compression_ratio, start = prefix_latent_frames) image = vae_decode( latents, self.vae, @@ -1306,6 +1371,57 @@ class LTXVideoPipeline(DiffusionPipeline): return image + @staticmethod + def tone_map_latents( + latents: torch.Tensor, + compression: float, + start: int = 0 + ) -> torch.Tensor: + """ + Applies a non-linear tone-mapping function to latent values to reduce their dynamic range + in a perceptually smooth way using a sigmoid-based compression. + + This is useful for regularizing high-variance latents or for conditioning outputs + during generation, especially when controlling dynamic behavior with a `compression` factor. + + Parameters: + ---------- + latents : torch.Tensor + Input latent tensor with arbitrary shape. Expected to be roughly in [-1, 1] or [0, 1] range. + compression : float + Compression strength in the range [0, 1]. + - 0.0: No tone-mapping (identity transform) + - 1.0: Full compression effect + + Returns: + ------- + torch.Tensor + The tone-mapped latent tensor of the same shape as input. 
+ """ + if compression ==0: + return latents + if not (0 <= compression <= 1): + raise ValueError("Compression must be in the range [0, 1]") + + # Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot + scale_factor = compression * 0.75 + abs_latents = torch.abs(latents) + + # Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0 + # When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect + sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0)) + # DeepBeepMeep special touch to allow a smooth transition with tone mapping + if start > 0: + gradient_tensor = torch.linspace(0, 1, latents.shape[2],dtype= sigmoid_term.dtype, device=sigmoid_term.device) + gradient_tensor = gradient_tensor ** 0.5 + gradient_tensor = gradient_tensor[ None, None, :, None, None ] + sigmoid_term *= gradient_tensor + scales = 1.0 - 0.8 * scale_factor * sigmoid_term + + + filtered = latents * scales + return filtered + def denoising_step( self, latents: torch.Tensor, @@ -1405,18 +1521,18 @@ class LTXVideoPipeline(DiffusionPipeline): media_item = conditioning_item.media_item media_frame_number = conditioning_item.media_frame_number strength = conditioning_item.conditioning_strength + control_frames = conditioning_item.control_frames assert media_item.ndim == 5 # (b, c, f, h, w) b, c, n_frames, h, w = media_item.shape assert ( height == h and width == w ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0" - assert n_frames % 8 == 1 - assert ( - media_frame_number >= 0 - and media_frame_number + n_frames <= num_frames - ) + # assert n_frames % 8 == 1 + # assert ( + # media_frame_number >= 0 + # and media_frame_number + n_frames <= num_frames + # ) - # Encode the provided conditioning media item media_item_latents = vae_encode( media_item.to(dtype=self.vae.dtype, device=self.vae.device), self.vae, @@ -1424,7 +1540,33 @@ class LTXVideoPipeline(DiffusionPipeline): ).to(dtype=init_latents.dtype) # Handle the different conditioning cases - if media_frame_number == 0: + if control_frames: + #control frames sequence is assumed to start one frame before the actual location so that we can properly insert the prefix latent + if media_frame_number > 0: + media_frame_number = media_frame_number -1 + media_item_latents, media_latent_coords = self.patchifier.patchify( + latents=media_item_latents + ) + media_pixel_coords = latent_to_pixel_coords( + media_latent_coords, + self.vae, + causal_fix=self.transformer.config.causal_temporal_positioning, + ) + + media_conditioning_mask = torch.full( + media_item_latents.shape[:2], + strength, + dtype=torch.float32, + device=init_latents.device, + ) + + # Update the frame numbers to match the target frame number + media_pixel_coords[:, 0] += media_frame_number + extra_conditioning_num_latents += media_item_latents.shape[1] + extra_conditioning_latents.append(media_item_latents) + extra_conditioning_pixel_coords.append(media_pixel_coords) + extra_conditioning_mask.append(media_conditioning_mask) + elif media_frame_number == 0: # Get the target spatial position of the latent conditioning item media_item_latents, l_x, l_y = self._get_latent_spatial_position( media_item_latents, diff --git a/ltx_video/pipelines/__init__.py b/models/ltx_video/schedulers/__init__.py similarity index 100% rename from ltx_video/pipelines/__init__.py rename to models/ltx_video/schedulers/__init__.py diff --git a/ltx_video/schedulers/rf.py 
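A numeric sketch (illustrative, outside the diff) of the curve implemented in `tone_map_latents` above: the per-element scale is `1 - 0.8*s*sigmoid(4*s*(|z| - 1))` with `s = 0.75 * compression`, so large-magnitude latents are attenuated the most and `compression = 0` leaves the latents untouched.

```python
import torch

def tone_map_scale(latents: torch.Tensor, compression: float) -> torch.Tensor:
    s = compression * 0.75
    return 1.0 - 0.8 * s * torch.sigmoid(4.0 * s * (latents.abs() - 1.0))

z = torch.tensor([0.25, 1.0, 2.0, 4.0])
print(tone_map_scale(z, 0.6))  # ~[0.93, 0.82, 0.69, 0.64] -> larger magnitudes compressed harder
print(tone_map_scale(z, 0.0))  # all ones: no tone mapping, matching the early return above
```

The `start` / `gradient_tensor` branch in the diff additionally ramps this effect in over the first latent frames so that prefix conditioning frames blend smoothly into the tone-mapped remainder.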
b/models/ltx_video/schedulers/rf.py similarity index 99% rename from ltx_video/schedulers/rf.py rename to models/ltx_video/schedulers/rf.py index 2cf99da..bced26a 100644 --- a/ltx_video/schedulers/rf.py +++ b/models/ltx_video/schedulers/rf.py @@ -14,9 +14,9 @@ from torch import Tensor from safetensors import safe_open -from ltx_video.utils.torch_utils import append_dims +from ..utils.torch_utils import append_dims -from ltx_video.utils.diffusers_config_mapping import ( +from ..utils.diffusers_config_mapping import ( diffusers_and_ours_config_mapping, make_hashable_key, ) diff --git a/ltx_video/schedulers/__init__.py b/models/ltx_video/utils/__init__.py similarity index 100% rename from ltx_video/schedulers/__init__.py rename to models/ltx_video/utils/__init__.py diff --git a/ltx_video/utils/diffusers_config_mapping.py b/models/ltx_video/utils/diffusers_config_mapping.py similarity index 100% rename from ltx_video/utils/diffusers_config_mapping.py rename to models/ltx_video/utils/diffusers_config_mapping.py diff --git a/ltx_video/utils/prompt_enhance_utils.py b/models/ltx_video/utils/prompt_enhance_utils.py similarity index 76% rename from ltx_video/utils/prompt_enhance_utils.py rename to models/ltx_video/utils/prompt_enhance_utils.py index 449fa74..dbfe0c1 100644 --- a/ltx_video/utils/prompt_enhance_utils.py +++ b/models/ltx_video/utils/prompt_enhance_utils.py @@ -23,6 +23,23 @@ Note any changes or sudden events Do not exceed the 150 word limit! Output the enhanced prompt only. """ +T2I_VISUAL_PROMPT = """You are an expert visual artist and photographer with award-winning compositions. When writing prompts based on the user input, focus on detailed, precise descriptions of visual elements and composition. +Include specific poses, appearances, framing, and environmental details - all in a single flowing paragraph. +Start directly with the main subject, and keep descriptions literal and precise. +Think like a photographer describing the perfect shot. +Do not change the user input intent, just enhance it. +Keep within 150 words. +For best results, build your prompts using this structure: +Start with main subject and pose in a single sentence +Add specific details about expressions and positioning +Describe character/object appearances precisely +Include background and environment details +Specify framing, composition and perspective +Describe lighting, colors, and mood +Note any atmospheric or stylistic elements +Do not exceed the 150 word limit! +Output the enhanced prompt only. +""" I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. @@ -43,6 +60,24 @@ Do not exceed the 150 word limit! Output the enhanced prompt only. """ +I2I_VISUAL_PROMPT = """You are an expert visual artist and photographer with award-winning compositions. When writing prompts based on the user input, focus on detailed, precise descriptions of visual elements and composition. +Include specific poses, appearances, framing, and environmental details - all in a single flowing paragraph. +Start directly with the main subject, and keep descriptions literal and precise. +Think like a photographer describing the perfect shot. +Do not change the user input intent, just enhance it. +Keep within 150 words. 
+For best results, build your prompts using this structure: +Start with main subject and pose in a single sentence +Add specific details about expressions and positioning +Describe character/object appearances precisely +Include background and environment details +Specify framing, composition and perspective +Describe lighting, colors, and mood +Note any atmospheric or stylistic elements +Do not exceed the 150 word limit! +Output the enhanced prompt only. +""" + def tensor_to_pil(tensor): # Ensure tensor is in range [-1, 1] @@ -68,6 +103,7 @@ def generate_cinematic_prompt( prompt_enhancer_tokenizer, prompt: Union[str, List[str]], images: Optional[List] = None, + video_prompt= True, max_new_tokens: int = 256, ) -> List[str]: prompts = [prompt] if isinstance(prompt, str) else prompt @@ -78,7 +114,7 @@ def generate_cinematic_prompt( prompt_enhancer_tokenizer, prompts, max_new_tokens, - T2V_CINEMATIC_PROMPT, + T2V_CINEMATIC_PROMPT if video_prompt else T2I_VISUAL_PROMPT, ) else: @@ -90,7 +126,7 @@ def generate_cinematic_prompt( prompts, images, max_new_tokens, - I2V_CINEMATIC_PROMPT, + I2V_CINEMATIC_PROMPT if video_prompt else I2I_VISUAL_PROMPT, ) return prompts diff --git a/ltx_video/utils/skip_layer_strategy.py b/models/ltx_video/utils/skip_layer_strategy.py similarity index 100% rename from ltx_video/utils/skip_layer_strategy.py rename to models/ltx_video/utils/skip_layer_strategy.py diff --git a/ltx_video/utils/torch_utils.py b/models/ltx_video/utils/torch_utils.py similarity index 100% rename from ltx_video/utils/torch_utils.py rename to models/ltx_video/utils/torch_utils.py diff --git a/models/qwen/autoencoder_kl_qwenimage.py b/models/qwen/autoencoder_kl_qwenimage.py new file mode 100644 index 0000000..d144284 --- /dev/null +++ b/models/qwen/autoencoder_kl_qwenimage.py @@ -0,0 +1,1096 @@ +# Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We gratefully acknowledge the Wan Team for their outstanding contributions. +# QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance. 
+# For more information about the Wan VAE, please refer to: +# - GitHub: https://github.com/Wan-Video/Wan2.1 +# - arXiv: https://arxiv.org/abs/2503.20314 + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin +from diffusers.utils import logging +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.activations import get_activation +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +CACHE_T = 2 + + +class QwenImageCausalConv3d(nn.Conv3d): + r""" + A custom 3D causal convolution layer with feature caching support. + + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + return super().forward(x) + + +class QwenImageRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class QwenImageUpsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. 
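As a quick sanity check of the causal convolution defined above (a sketch that assumes the classes in this file are importable): the time axis is padded only with past context, so the frame count is preserved and no frame sees the future.

```python
>>> import torch
>>> conv = QwenImageCausalConv3d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
>>> x = torch.randn(1, 3, 5, 32, 32)    # (b, c, t, h, w)
>>> conv(x).shape                       # two padded past frames, zero future frames
torch.Size([1, 8, 5, 32, 32])
```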
+ + Args: + x (torch.Tensor): Input tensor to be upsampled. + + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. + """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class QwenImageResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. + """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class QwenImageResidualBlock(nn.Module): + r""" + A custom residual block module. + + Args: + in_dim (int): Number of input channels. + out_dim (int): Number of output channels. + dropout (float, optional): Dropout rate for the dropout layer. 
Default is 0.0. + non_linearity (str, optional): Type of non-linearity to use. Default is "silu". + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = get_activation(non_linearity) + + # layers + self.norm1 = QwenImageRMS_norm(in_dim, images=False) + self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = QwenImageRMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + # Add residual connection + return x + h + + +class QwenImageAttentionBlock(nn.Module): + r""" + Causal self-attention with a single head. + + Args: + dim (int): The number of channels in the input tensor. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = QwenImageRMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + # compute query, key, value + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention(q, k, v) + + x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) + + # output projection + x = self.proj(x) + + # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w] + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + +class QwenImageMidBlock(nn.Module): + """ + Middle block for QwenImageVAE encoder and decoder. + + Args: + dim (int): Number of input/output channels. + dropout (float): Dropout rate. + non_linearity (str): Type of non-linearity to use. 
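For the single-head attention block above, a small shape sketch (assumes the class is in scope): every frame is flattened into an h*w token sequence, attended over independently, and folded back, so the tensor layout is untouched.

```python
>>> import torch
>>> attn = QwenImageAttentionBlock(dim=64)
>>> x = torch.randn(2, 64, 4, 16, 16)   # (b, c, t, h, w)
>>> attn(x).shape                       # residual output, same layout as the input
torch.Size([2, 64, 4, 16, 16])
```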
+ """ + + def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + super().__init__() + self.dim = dim + + # Create the components + resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(QwenImageAttentionBlock(dim)) + resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # First residual block + x = self.resnets[0](x, feat_cache, feat_idx) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + + x = resnet(x, feat_cache, feat_idx) + + return x + + +class QwenImageEncoder3d(nn.Module): + r""" + A 3D encoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(QwenImageAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(QwenImageResample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1) + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## downsamples + for layer in self.down_blocks: + if feat_cache is not None: + x = 
layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class QwenImageUpBlock(nn.Module): + """ + A block that handles upsampling for the QwenImageVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d') + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)]) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache, feat_idx) + else: + x = self.upsamplers[0](x) + return x + + +class QwenImageDecoder3d(nn.Module): + r""" + A 3D decoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. 
+ """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0: + in_dim = in_dim // 2 + + # Determine if we need upsampling + upsample_mode = None + if i != len(dim_mult) - 1: + upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" + + # Create and add the upsampling block + up_block = QwenImageUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # Update scale for next iteration + if upsample_mode is not None: + scale *= 2.0 + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). 
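A minimal round-trip sketch for this VAE (random weights and a single frame, purely to illustrate the shapes; real use loads a checkpoint): the default config compresses 8x spatially and yields 16 latent channels.

```python
import torch

vae = AutoencoderKLQwenImage()                 # default config: z_dim=16, 8x spatial compression
image = torch.randn(1, 3, 1, 256, 256)         # (b, c, t, h, w), one frame
with torch.no_grad():
    z = vae.encode(image).latent_dist.mode()   # -> (1, 16, 1, 32, 32)
    recon = vae.decode(z).sample               # -> (1, 3, 1, 256, 256), clamped to [-1, 1]
```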
+ """ + + _supports_gradient_checkpointing = False + + @staticmethod + def get_VAE_tile_size(vae_config, device_mem_capacity, mixed_precision): + + # VAE Tiling + if vae_config == 0: + if device_mem_capacity >= 24000: + use_vae_config = 1 + elif device_mem_capacity >= 8000: + use_vae_config = 2 + else: + use_vae_config = 3 + else: + use_vae_config = vae_config + + use_tiling = False + tile_sample_min_width = 256 + + if use_vae_config == 1: + use_tiling = False + elif use_vae_config == 2: + use_tiling = True + tile_sample_min_width = 256 + + return (use_tiling, tile_sample_min_width) + + + # fmt: off + @register_to_config + def __init__( + self, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temperal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921], + latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160], + ) -> None: + # fmt: on + super().__init__() + + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + self.encoder = QwenImageEncoder3d( + base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout + ) + self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1) + + self.decoder = QwenImageDecoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout + ) + + self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup + self._cached_conv_counts = { + "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules()) + if self.decoder is not None + else 0, + "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules()) + if self.encoder is not None + else 0, + } + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. 
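Illustrative use of the tiling switch with its default 256-pixel tiles and 192-pixel stride (a sketch on random weights; the latent size is chosen so the tiled path is actually taken):

```python
import torch

vae = AutoencoderKLQwenImage()
vae.enable_tiling()                            # keep the default tile size and stride
z = torch.randn(1, 16, 1, 64, 64)              # 64 > 256 // 8, so decoding is split into tiles
with torch.no_grad():
    frames = vae.decode(z).sample              # -> (1, 3, 1, 512, 512), seams hidden by blending
```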
+ + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def clear_cache(self): + def _count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, QwenImageCausalConv3d): + count += 1 + return count + + self._conv_num = _count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + def _encode(self, x: torch.Tensor): + _, _, num_frame, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + self.clear_cache() + iter_ = 1 + (num_frame - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + + enc = self.quant_conv(out) + self.clear_cache() + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
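The chunked `_encode` above walks the time axis causally; a worked example of its arithmetic (frame count assumed, and only relevant when the VAE is fed more than one frame):

```python
# The first pixel frame is encoded on its own, then frames are consumed 4 at a
# time, each pass contributing exactly one latent frame.
num_frame = 17
passes = 1 + (num_frame - 1) // 4    # 5 encoder passes
latent_frames = passes               # 5 latent frames: 4x temporal compression plus the causal first frame
```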
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + _, _, num_frame, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + self.clear_cache() + x = self.post_quant_conv(z) + for i in range(num_frame): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + + out = torch.clamp(out, min=-1.0, max=1.0) + self.clear_cache() + if not return_dict: + return (out,) + + return DecoderOutput(sample=out) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. 
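`blend_h` (and likewise `blend_v`) above is a plain linear cross-fade over the band two neighbouring tiles share; a self-contained sketch of the weighting with assumed toy sizes:

```python
import torch

left, right = torch.ones(1, 3, 1, 4, 8), torch.zeros(1, 3, 1, 4, 8)
blend_extent = 4
for x in range(blend_extent):        # same update as blend_h
    right[:, :, :, :, x] = (left[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent)
                            + right[:, :, :, :, x] * (x / blend_extent))
print(right[0, 0, 0, 0, :blend_extent])   # tensor([1.0000, 0.7500, 0.5000, 0.2500])
```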
+ """ + _, _, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + self.clear_cache() + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
+ rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + self.clear_cache() + time = [] + for k in range(num_frames): + self._conv_idx = [0] + tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx) + time.append(decoded) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + """ + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py new file mode 100644 index 0000000..07bdbd4 --- /dev/null +++ b/models/qwen/pipeline_qwenimage.py @@ -0,0 +1,897 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
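To summarise the tile geometry the tiled encode/decode paths above rely on (class defaults assumed):

```python
tile, stride, ratio = 256, 192, 8      # tile size, stride and spatial compression (defaults)
overlap_px = tile - stride             # 64 shared pixels between neighbouring tiles
overlap_latent = overlap_px // ratio   # 8 latent rows/cols: the encode-side blend extent
# tiled_decode blends the same 64-pixel band after decoding, then crops every
# tile back to its 192-pixel stride before stitching the rows together.
```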
+ +from mmgp import offload +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch, json +import math +from diffusers.image_processor import VaeImageProcessor +from .transformer_qwenimage import QwenImageTransformer2DModel + +from diffusers.utils import logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, AutoTokenizer +from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage +from diffusers import FlowMatchEulerDiscreteScheduler +from PIL import Image +from shared.utils.utils import calculate_new_dimensions + +XLA_AVAILABLE = False + +PREFERRED_QWENIMAGE_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import QwenImagePipeline + + >>> pipe = QwenImagePipeline.from_pretrained("Qwen/QwenImage-20B", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image.save("qwenimage.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. 
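`calculate_shift` above interpolates the scheduler shift `mu` linearly in the packed sequence length; a worked example with its default constants (image size assumed):

```python
# A 1024x1024 image -> 128x128 latent grid -> packed into 2x2 patches.
image_seq_len = (1024 // 8 // 2) ** 2        # 4096 tokens
mu = calculate_shift(image_seq_len)          # reaches max_seq_len, so mu == max_shift == 1.15
```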
+ """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +class QwenImagePipeline(): #DiffusionPipeline + r""" + The QwenImage pipeline for text-to-image generation. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + vae, + text_encoder, + tokenizer, + transformer, + processor, + ): + + self.vae=vae + self.text_encoder=text_encoder + self.tokenizer=tokenizer + self.transformer=transformer + self.processor = processor + + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + # QwenImage latents are turned into 2x2 patches and packed. 
This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 1024 + if processor is not None: + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 64 + else: + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 + self.default_sample_size = 128 + + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + + if self.processor is not None and image is not None: + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + else: + txt_tokens = self.tokenizer( + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ).to(device) + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, 
encoder_attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + image (`torch.Tensor`, *optional*): + image to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. 
Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. 
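A shape sketch of the 2x2 patch packing used by this pipeline (sizes assumed; both helpers are the static methods defined above):

```python
>>> import torch
>>> latents = torch.randn(1, 16, 64, 64)     # a 512x512 image after 8x VAE compression
>>> packed = QwenImagePipeline._pack_latents(latents, 1, 16, 64, 64)
>>> packed.shape                             # (h/2)*(w/2) tokens, each with c*4 channels
torch.Size([1, 1024, 64])
>>> QwenImagePipeline._unpack_latents(packed, 512, 512, vae_scale_factor=8).shape
torch.Size([1, 16, 1, 64, 64])
```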
+ """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + image, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + image_latents = None + if image is not None: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 1.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + image = None, + callback=None, + pipeline=None, + loras_slists=None, + joint_pass= True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. 
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. 
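Example (editor's sketch, not from the upstream docstring): assuming `pipe` is an already-initialized instance of this WanGP variant of the pipeline, with `pipe.scheduler` set to a `FlowMatchEulerDiscreteScheduler` the way `model_factory.generate()` in `qwen_main.py` does, a minimal call might look like this:

```python
import torch

# Hypothetical invocation; prompt, sizes and seed are illustrative only.
generator = torch.Generator(device="cuda").manual_seed(42)

image = pipe(
    prompt="a watercolor painting of a lighthouse at dusk",
    negative_prompt="text, watermark, blurry",
    true_cfg_scale=4.0,        # > 1 together with a negative prompt enables true CFG
    height=1024,
    width=1024,
    num_inference_steps=20,
    generator=generator,
    pipeline=pipe,             # the transformer polls `pipeline._interrupt` for cancellation
)
```

Note that this modified `__call__` returns the decoded image tensor directly (or `None` if interrupted) rather than a `QwenImagePipelineOutput`.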
+ """ + + kwargs = {'pipeline': pipeline, 'callback': callback} + if callback != None: + callback(-1, None, True) + + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = "cuda" + + prompt_image = None + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = image[0] if isinstance(image, list) else image + image_height, image_width = self.image_processor.get_default_height_width(image) + aspect_ratio = image_width / image_height + if False : + _, image_width, image_height = min( + (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS + ) + image_width = image_width // multiple_of * multiple_of + image_height = image_height // multiple_of * multiple_of + ref_height, ref_width = 1568, 672 + if height * width < ref_height * ref_width: ref_height , ref_width = height , width + if image_height * image_width > ref_height * ref_width: + image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of) + + image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS) + prompt_image = image + image = self.image_processor.preprocess(image, image_height, image_width) + image = image.unsqueeze(2) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=prompt_image, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=prompt_image, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + dtype = torch.bfloat16 + prompt_embeds = prompt_embeds.to(dtype) + if do_true_cfg: + negative_prompt_embeds = negative_prompt_embeds.to(dtype) + + # 4. 
Prepare latent variables + num_channels_latents = self.transformer.in_channels // 4 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + if image is not None: + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2), + ] + ] * batch_size + else: + img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + + # 6. 
Denoising loop + self.scheduler.set_begin_index(0) + updated_num_steps= len(timesteps) + if callback != None: + from shared.utils.loras_mutipliers import update_loras_slists + update_loras_slists(self.transformer, loras_slists, updated_num_steps) + callback(-1, None, True, override_num_inference_steps = updated_num_steps) + + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + if do_true_cfg and joint_pass: + noise_pred, neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask], + encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds], + img_shapes=img_shapes, + txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens], + attention_kwargs=self.attention_kwargs, + **kwargs + ) + if noise_pred == None: return None + noise_pred = noise_pred[:, : latents.size(1)] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + else: + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask_list=[prompt_embeds_mask], + encoder_hidden_states_list=[prompt_embeds], + img_shapes=img_shapes, + txt_seq_lens_list=[txt_seq_lens], + attention_kwargs=self.attention_kwargs, + **kwargs + )[0] + if noise_pred == None: return None + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask_list=[negative_prompt_embeds_mask], + encoder_hidden_states_list=[negative_prompt_embeds], + img_shapes=img_shapes, + txt_seq_lens_list=[negative_txt_seq_lens], + attention_kwargs=self.attention_kwargs, + **kwargs + )[0] + if neg_noise_pred == None: return None + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + + if do_true_cfg: + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + if comb_pred == None: return None + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + neg_noise_pred = None + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback is not None: + # preview = unpack_latent(img).transpose(0,1) + callback(i, None, False) + + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + + + return image diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py new file mode 100644 index 0000000..c6004e1 --- /dev/null +++ b/models/qwen/qwen_handler.py @@ -0,0 +1,87 @@ +import torch + +def get_qwen_text_encoder_filename(text_encoder_quantization): + text_encoder_filename = "ckpts/Qwen2.5-VL-7B-Instruct/Qwen2.5-VL-7B-Instruct_bf16.safetensors" + if text_encoder_quantization =="int8": + text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_bf16_int8") + return text_encoder_filename + +class family_handler(): + @staticmethod + def query_model_def(base_model_type, model_def): + model_def_output = { + "image_outputs" : True, + "sample_solvers":[ + ("Default", "default"), + ("Lightning", "lightning")], + "guidance_max_phases" : 1, + } + + + return model_def_output + + @staticmethod + def query_supported_types(): + return ["qwen_image_20B", "qwen_image_edit_20B"] + + @staticmethod + def query_family_maps(): + return {}, {} + + @staticmethod + def query_model_family(): + return "qwen" + + @staticmethod + def query_family_infos(): + return {"qwen":(40, "Qwen")} + + @staticmethod + def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization): + text_encoder_filename = get_qwen_text_encoder_filename(text_encoder_quantization) + return { + "repoId" : "DeepBeepMeep/Qwen_image", + "sourceFolderList" : ["", "Qwen2.5-VL-7B-Instruct"], + "fileList" : [ ["qwen_vae.safetensors", "qwen_vae_config.json"], ["merges.txt", "tokenizer_config.json", "config.json", "vocab.json", "video_preprocessor_config.json", "preprocessor_config.json"] + computeList(text_encoder_filename) ] + } + + @staticmethod + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + from .qwen_main import model_factory + from mmgp import offload + + pipe_processor = model_factory( + checkpoint_dir="ckpts", + model_filename=model_filename, + model_type = model_type, + model_def = model_def, + base_model_type=base_model_type, + text_encoder_filename= get_qwen_text_encoder_filename(text_encoder_quantization), + quantizeTransformer = quantizeTransformer, + dtype = dtype, + VAE_dtype = VAE_dtype, + mixed_precision_transformer = mixed_precision_transformer, + save_quantized = save_quantized + ) + + pipe = {"tokenizer" : pipe_processor.tokenizer, "transformer" : pipe_processor.transformer, "text_encoder" : pipe_processor.text_encoder, "vae" : pipe_processor.vae} + + return pipe_processor, pipe + + + @staticmethod + def fix_settings(base_model_type, 
settings_version, model_def, ui_defaults): + if ui_defaults.get("sample_solver", "") == "": + ui_defaults["sample_solver"] = "default" + + @staticmethod + def update_default_settings(base_model_type, model_def, ui_defaults): + ui_defaults.update({ + "guidance_scale": 4, + "sample_solver": "default", + }) + if model_def.get("reference_image", False): + ui_defaults.update({ + "video_prompt_type": "KI", + }) + diff --git a/models/qwen/qwen_main.py b/models/qwen/qwen_main.py new file mode 100644 index 0000000..156eeed --- /dev/null +++ b/models/qwen/qwen_main.py @@ -0,0 +1,206 @@ + +from mmgp import offload +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch, json, os +import math + +from diffusers.image_processor import VaeImageProcessor +from .transformer_qwenimage import QwenImageTransformer2DModel + +from diffusers.utils import logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, AutoTokenizer, Qwen2VLProcessor +from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage +from diffusers import FlowMatchEulerDiscreteScheduler +from .pipeline_qwenimage import QwenImagePipeline +from PIL import Image +from shared.utils.utils import calculate_new_dimensions + +def stitch_images(img1, img2): + # Resize img2 to match img1's height + width1, height1 = img1.size + width2, height2 = img2.size + new_width2 = int(width2 * height1 / height2) + img2_resized = img2.resize((new_width2, height1), Image.Resampling.LANCZOS) + + stitched = Image.new('RGB', (width1 + new_width2, height1)) + stitched.paste(img1, (0, 0)) + stitched.paste(img2_resized, (width1, 0)) + return stitched + +class model_factory(): + def __init__( + self, + checkpoint_dir, + model_filename = None, + model_type = None, + model_def = None, + base_model_type = None, + text_encoder_filename = None, + quantizeTransformer = False, + save_quantized = False, + dtype = torch.bfloat16, + VAE_dtype = torch.float32, + mixed_precision_transformer = False + ): + + + transformer_filename = model_filename[0] + processor = None + tokenizer = None + if base_model_type == "qwen_image_edit_20B": + processor = Qwen2VLProcessor.from_pretrained(os.path.join(checkpoint_dir,"Qwen2.5-VL-7B-Instruct")) + tokenizer = AutoTokenizer.from_pretrained(os.path.join(checkpoint_dir,"Qwen2.5-VL-7B-Instruct")) + + + base_config_file = "configs/qwen_image_20B.json" + with open(base_config_file, 'r', encoding='utf-8') as f: + transformer_config = json.load(f) + transformer_config.pop("_diffusers_version") + transformer_config.pop("_class_name") + transformer_config.pop("pooled_projection_dim") + + from accelerate import init_empty_weights + with init_empty_weights(): + transformer = QwenImageTransformer2DModel(**transformer_config) + source = model_def.get("source", None) + + if source is not None: + offload.load_model_data(transformer, source) + else: + offload.load_model_data(transformer, transformer_filename) + # transformer = offload.fast_load_transformers_model("transformer_quanto.safetensors", writable_tensors= True , modelClass=QwenImageTransformer2DModel, defaultConfigPath="transformer_config.json") + + if not source is None: + from wgp import save_model + save_model(transformer, model_type, dtype, None) + + if save_quantized: + from wgp import save_quantized_model + save_quantized_model(transformer, model_type, model_filename[0], dtype, base_config_file) + + text_encoder = 
offload.fast_load_transformers_model(text_encoder_filename, writable_tensors= True , modelClass=Qwen2_5_VLForConditionalGeneration, defaultConfigPath= os.path.join(checkpoint_dir, "Qwen2.5-VL-7B-Instruct", "config.json")) + # text_encoder = offload.fast_load_transformers_model(text_encoder_filename, do_quantize=True, writable_tensors= True , modelClass=Qwen2_5_VLForConditionalGeneration, defaultConfigPath="text_encoder_config.json", verboseLevel=2) + # text_encoder.to(torch.float16) + # offload.save_model(text_encoder, "text_encoder_quanto_fp16.safetensors", do_quantize= True) + + vae = offload.fast_load_transformers_model( os.path.join(checkpoint_dir,"qwen_vae.safetensors"), writable_tensors= True , modelClass=AutoencoderKLQwenImage, defaultConfigPath=os.path.join(checkpoint_dir,"qwen_vae_config.json")) + + self.pipeline = QwenImagePipeline(vae, text_encoder, tokenizer, transformer, processor) + self.vae=vae + self.text_encoder=text_encoder + self.tokenizer=tokenizer + self.transformer=transformer + self.processor = processor + + def generate( + self, + seed: int | None = None, + input_prompt: str = "replace the logo with the text 'Black Forest Labs'", + n_prompt = None, + sampling_steps: int = 20, + input_ref_images = None, + width= 832, + height=480, + guide_scale: float = 4, + fit_into_canvas = None, + callback = None, + loras_slists = None, + batch_size = 1, + video_prompt_type = "", + VAE_tile_size = None, + joint_pass = True, + sample_solver='default', + **bbargs + ): + # Generate with different aspect ratios + aspect_ratios = { + "1:1": (1328, 1328), + "16:9": (1664, 928), + "9:16": (928, 1664), + "4:3": (1472, 1140), + "3:4": (1140, 1472) + } + + + if sample_solver =='lightning': + scheduler_config = { + "base_image_seq_len": 256, + "base_shift": math.log(3), # We use shift=3 in distillation + "invert_sigmas": False, + "max_image_seq_len": 8192, + "max_shift": math.log(3), # We use shift=3 in distillation + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": None, # set shift_terminal to None + "stochastic_sampling": False, + "time_shift_type": "exponential", + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False, + } + else: + scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.5, + "invert_sigmas": False, + "max_image_seq_len": 8192, + "max_shift": 0.9, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": 0.02, + "stochastic_sampling": False, + "time_shift_type": "exponential", + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False + } + + self.scheduler=FlowMatchEulerDiscreteScheduler(**scheduler_config) + self.pipeline.scheduler = self.scheduler + if VAE_tile_size is not None: + self.vae.use_tiling = VAE_tile_size[0] + self.vae.tile_latent_min_height = VAE_tile_size[1] + self.vae.tile_latent_min_width = VAE_tile_size[1] + + + self.vae.enable_slicing() + # width, height = aspect_ratios["16:9"] + + if n_prompt is None or len(n_prompt) == 0: + n_prompt= "text, watermark, copyright, blurry, low resolution" + + if input_ref_images is not None: + # image stiching method + stiched = input_ref_images[0] + if "K" in video_prompt_type : + w, h = input_ref_images[0].size + height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) + + for new_img in input_ref_images[1:]: + stiched = stitch_images(stiched, new_img) + input_ref_images = [stiched] + + image = self.pipeline( + 
prompt=input_prompt, + negative_prompt=n_prompt, + image = input_ref_images, + width=width, + height=height, + num_inference_steps=sampling_steps, + num_images_per_prompt = batch_size, + true_cfg_scale=guide_scale, + callback = callback, + pipeline=self, + loras_slists=loras_slists, + joint_pass = joint_pass, + generator=torch.Generator(device="cuda").manual_seed(seed) + ) + if image is None: return None + return image.transpose(0, 1) + diff --git a/models/qwen/transformer_qwenimage.py b/models/qwen/transformer_qwenimage.py new file mode 100644 index 0000000..6d90806 --- /dev/null +++ b/models/qwen/transformer_qwenimage.py @@ -0,0 +1,635 @@ +# Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import Attention +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm +from shared.attention import pay_attention +import functools + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. 
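Example (editor's note): with the default arguments (`flip_sin_to_cos=False`, `downscale_freq_shift=1`, `scale=1`), the helper above reduces to the following standalone computation:

```python
import math
import torch

# Re-derivation of get_timestep_embedding for N=3 timesteps and dim=8: half of the
# channels carry sin(t * f_i), the other half cos(t * f_i), where the frequencies f_i
# decay geometrically from 1 down to 1/max_period.
timesteps = torch.tensor([0.0, 250.0, 999.0])
embedding_dim, max_period = 8, 10000

half_dim = embedding_dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / (half_dim - 1))
args = timesteps[:, None] * freqs[None, :]
emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
print(emb.shape)  # torch.Size([3, 8])
```

In the transformer itself, `QwenTimestepProjEmbeddings` uses the diffusers `Timesteps` layer with `flip_sin_to_cos=True`, `downscale_freq_shift=0` and `scale=1000`, i.e. the same construction with the two halves swapped and the timesteps rescaled.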
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent).to(timesteps.dtype) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def apply_rotary_emb_qwen( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(1) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + + +class QwenTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep, hidden_states): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D) + + conditioning = timesteps_emb + + return conditioning + + +class QwenEmbedRope(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + 
self.pos_freqs = torch.cat(
+            [
+                self.rope_params(pos_index, self.axes_dim[0], self.theta),
+                self.rope_params(pos_index, self.axes_dim[1], self.theta),
+                self.rope_params(pos_index, self.axes_dim[2], self.theta),
+            ],
+            dim=1,
+        )
+        self.neg_freqs = torch.cat(
+            [
+                self.rope_params(neg_index, self.axes_dim[0], self.theta),
+                self.rope_params(neg_index, self.axes_dim[1], self.theta),
+                self.rope_params(neg_index, self.axes_dim[2], self.theta),
+            ],
+            dim=1,
+        )
+        self.rope_cache = {}
+
+        # DO NOT USE register_buffer HERE: IT WILL CAUSE THE COMPLEX NUMBERS TO LOSE THEIR IMAGINARY PART
+        self.scale_rope = scale_rope
+
+    def rope_params(self, index, dim, theta=10000):
+        """
+        Args:
+            index: 1D Tensor of position indices, e.g. [0, 1, 2, 3]
+        """
+        assert dim % 2 == 0
+        freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
+        freqs = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs
+
+    def forward(self, video_fhw, txt_seq_lens, device):
+        """
+        Args:
+            video_fhw: a list of (frame, height, width) tuples giving the shape of each video or image latent grid
+            txt_seq_lens: a list of integers (one per batch element) giving the length of each text sequence
+        """
+        if self.pos_freqs.device != device:
+            self.pos_freqs = self.pos_freqs.to(device)
+            self.neg_freqs = self.neg_freqs.to(device)
+
+        if isinstance(video_fhw, list):
+            video_fhw = video_fhw[0]
+        if not isinstance(video_fhw, list):
+            video_fhw = [video_fhw]
+
+        vid_freqs = []
+        max_vid_index = 0
+        for idx, fhw in enumerate(video_fhw):
+            frame, height, width = fhw
+            rope_key = f"{idx}_{height}_{width}"
+
+            if not torch.compiler.is_compiling():
+                if rope_key not in self.rope_cache:
+                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
+                video_freq = self.rope_cache[rope_key]
+            else:
+                video_freq = self._compute_video_freqs(frame, height, width, idx)
+            video_freq = video_freq.to(device)
+            vid_freqs.append(video_freq)
+
+            if self.scale_rope:
+                max_vid_index = max(height // 2, width // 2, max_vid_index)
+            else:
+                max_vid_index = max(height, width, max_vid_index)
+
+        max_len = max(txt_seq_lens)
+        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
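+        # Editor's note (illustrative, not part of the upstream code): with scale_rope=True
+        # the spatial frequencies built in _compute_video_freqs below are centred around zero
+        # (negative indices feed the first half of each axis), and text tokens are placed
+        # *after* the largest spatial index so their positions never collide with image
+        # positions. For example, a single 1x64x64 latent grid gives
+        # max_vid_index = max(64 // 2, 64 // 2) = 32, so a 77-token prompt reads its rotary
+        # frequencies from self.pos_freqs[32:32 + 77].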
+ vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + @functools.lru_cache(maxsize=None) + def _compute_video_freqs(self, frame, height, width, idx=0): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + +class QwenDoubleStreamAttnProcessor2_0: + """ + Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor + implements joint attention computation where text and image streams are processed together. + """ + + _attention_backend = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, # Image stream + encoder_hidden_states: torch.FloatTensor = None, # Text stream + encoder_hidden_states_mask: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + if encoder_hidden_states is None: + raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") + + seq_txt = encoder_hidden_states.shape[1] + + # Compute QKV for image stream (sample projections) + img_query = attn.to_q(hidden_states) + img_key = attn.to_k(hidden_states) + img_value = attn.to_v(hidden_states) + + # Compute QKV for text stream (context projections) + txt_query = attn.add_q_proj(encoder_hidden_states) + txt_key = attn.add_k_proj(encoder_hidden_states) + txt_value = attn.add_v_proj(encoder_hidden_states) + + # Reshape for multi-head attention + img_query = img_query.unflatten(-1, (attn.heads, -1)) + img_key = img_key.unflatten(-1, (attn.heads, -1)) + img_value = img_value.unflatten(-1, (attn.heads, -1)) + + txt_query = txt_query.unflatten(-1, (attn.heads, -1)) + txt_key = txt_key.unflatten(-1, (attn.heads, -1)) + txt_value = txt_value.unflatten(-1, (attn.heads, -1)) + + # Apply QK normalization + if attn.norm_q is not None: + img_query = attn.norm_q(img_query) + if attn.norm_k is not None: + img_key = attn.norm_k(img_key) + if attn.norm_added_q is not None: + txt_query = attn.norm_added_q(txt_query) + if attn.norm_added_k is not None: + txt_key = attn.norm_added_k(txt_key) + + # Apply RoPE + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False) + img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False) + txt_query = 
apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False) + txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False) + + # Concatenate for joint attention + # Order: [text, image] + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) + + # Compute joint attention + dtype = joint_query.dtype + qkv_list = [joint_query, joint_key, joint_value ] + joint_query = joint_key = joint_value = None + joint_hidden_states = pay_attention(qkv_list) + + # Reshape back + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(dtype) + + # Split attention outputs back + txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part + img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part + + # Apply output projections + img_attn_output = attn.to_out[0](img_attn_output) + if len(attn.to_out) > 1: + img_attn_output = attn.to_out[1](img_attn_output) # dropout + + txt_attn_output = attn.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +class QwenImageTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + # Image processing modules + self.img_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 + ) + self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, # Enable cross attention for joint computation + added_kv_proj_dim=dim, # Enable added KV projections for text stream + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=QwenDoubleStreamAttnProcessor2_0(), + qk_norm=qk_norm, + eps=eps, + ) + self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + # Text processing modules + self.txt_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 + ) + self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + # Text doesn't need separate attention - it's handled by img_attn joint computation + self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def _modulate(self, x, mod_params): + """Apply modulation to input tensor""" + shift, scale, gate = mod_params.chunk(3, dim=-1) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_mask: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Get modulation parameters for both streams + img_mod_params = self.img_mod(temb) # [B, 6*dim] + txt_mod_params = self.txt_mod(temb) # [B, 6*dim] + + # Split modulation parameters for norm1 and norm2 + img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + txt_mod1, txt_mod2 = 
txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + + # Process image stream - norm1 + modulation + img_normed = self.img_norm1(hidden_states) + img_modulated, img_gate1 = self._modulate(img_normed, img_mod1) + + # Process text stream - norm1 + modulation + txt_normed = self.txt_norm1(encoder_hidden_states) + txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) + + # Use QwenAttnProcessor2_0 for joint attention computation + # This directly implements the DoubleStreamLayerMegatron logic: + # 1. Computes QKV for both streams + # 2. Applies QK normalization and RoPE + # 3. Concatenates and runs joint attention + # 4. Splits results back to separate streams + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=img_modulated, # Image stream (will be processed as "sample") + encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context") + encoder_hidden_states_mask=encoder_hidden_states_mask, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided + img_attn_output, txt_attn_output = attn_output + + # Apply attention gates and add residual (like in Megatron) + hidden_states = hidden_states + img_gate1 * img_attn_output + encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output + + # Process image stream - norm2 + MLP + img_normed2 = self.img_norm2(hidden_states) + img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) + img_mlp_output = self.img_mlp(img_modulated2) + hidden_states = hidden_states + img_gate2 * img_mlp_output + + # Process text stream - norm2 + MLP + txt_normed2 = self.txt_norm2(encoder_hidden_states) + txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) + txt_mlp_output = self.txt_mlp(txt_modulated2) + encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output + + # Clip to prevent overflow for fp16 + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class QwenImageTransformer2DModel(nn.Module): + """ + The Transformer model introduced in Qwen. + + Args: + patch_size (`int`, defaults to `2`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `64`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `60`): + The number of layers of dual stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `3584`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + guidance_embeds (`bool`, defaults to `False`): + Whether to use guidance embeddings for guidance-distilled variant of the model. + axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions to use for the rotary positional embeddings. 
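Example (editor's sketch): a scaled-down smoke test of this class. The sizes below are made up for illustration (the released Qwen-Image checkpoint uses `num_layers=60` and 24 heads of dim 128), and the snippet assumes the WanGP repo is on the import path so that `shared.attention.pay_attention` resolves.

```python
import torch

# Tiny configuration; axes_dims_rope must sum to attention_head_dim.
model = QwenImageTransformer2DModel(
    patch_size=2,
    in_channels=64,
    out_channels=16,
    num_layers=2,
    attention_head_dim=32,
    num_attention_heads=4,
    joint_attention_dim=128,
    axes_dims_rope=(8, 12, 12),
).eval()

class _Pipe:                      # minimal stand-in for the `pipeline` hook
    _interrupt = False

batch, txt_len = 1, 7
hidden_states = torch.randn(batch, 8 * 8, 64)        # packed latents for a 1x8x8 patch grid
encoder_hidden_states = torch.randn(batch, txt_len, 128)
mask = torch.ones(batch, txt_len, dtype=torch.long)

with torch.no_grad():
    out = model(
        hidden_states=hidden_states,
        encoder_hidden_states_list=[encoder_hidden_states],
        encoder_hidden_states_mask_list=[mask],
        timestep=torch.tensor([500.0]),
        img_shapes=[(1, 8, 8)],
        txt_seq_lens_list=[[txt_len]],
        pipeline=_Pipe(),
    )
print(out[0].shape)   # (1, 64, patch_size**2 * out_channels) == (1, 64, 64)
```

The `pipeline` argument is required by this WanGP fork's forward pass, which polls `pipeline._interrupt` at every block to support cooperative cancellation.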
+ """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["QwenImageTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + + + def preprocess_loras(self, model_type, sd): + + first = next(iter(sd), None) + if first == None: + return sd + + new_sd = {} + for k,v in sd.items(): + k = k.replace(".lora.", ".lora_") + new_sd[k] = v + sd = new_sd + + prefix_list = ["lora_unet_transformer_blocks"] + for prefix in prefix_list: + if first.startswith(prefix): + repl_list = ["attn", "img_mlp", "txt_mlp", "img_mod", "txt_mod"] + src_list = ["_" + k + "_" for k in repl_list] + tgt_list = ["." + k + "." for k in repl_list] + src_list2 = ["_0_", "_0.", "_1.", "_2."] + tgt_list2 = [".0.", ".0.", ".1.", ".2."] + new_sd = {} + for k,v in sd.items(): + k = "diffusion_model.transformer_blocks." + k[len(prefix)+1:] + for s,t in zip(src_list, tgt_list): + k = k.replace(s,t) + for s,t in zip(src_list2, tgt_list2): + k = k.replace(s,t) + new_sd[k] = v + sd = new_sd + return sd + + prefix_list = ["transformer_blocks"] + for prefix in prefix_list: + if first.startswith(prefix): + new_sd = {} + for k,v in sd.items(): + if k.startswith(prefix): + k = "diffusion_model." + k + new_sd[k] = v + sd = new_sd + return sd + return sd + + def __init__( + self, + patch_size: int = 2, + in_channels: int = 64, + out_channels: Optional[int] = 16, + num_layers: int = 60, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 3584, + guidance_embeds: bool = False, # TODO: this should probably be removed + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.in_channels = in_channels + self.guidance_embeds = guidance_embeds + self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) + + self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim) + + self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6) + + self.img_in = nn.Linear(in_channels, self.inner_dim) + self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + QwenImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states_list = None, + encoder_hidden_states_mask_list = None, + timestep: torch.LongTensor = None, + img_shapes: Optional[List[Tuple[int, int, int]]] = None, + txt_seq_lens_list = None, + guidance: torch.Tensor = None, # TODO: this should probably be removed + attention_kwargs: Optional[Dict[str, Any]] = None, + callback= None, + pipeline =None, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + + + hidden_states = self.img_in(hidden_states) + timestep = timestep.to(hidden_states.dtype) + hidden_states_list = [hidden_states if i == 0 else hidden_states.clone() for i, _ in enumerate(encoder_hidden_states_list)] + + new_encoder_hidden_states_list = [] + for encoder_hidden_states in encoder_hidden_states_list: + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = 
self.txt_in(encoder_hidden_states) + new_encoder_hidden_states_list.append(encoder_hidden_states) + encoder_hidden_states_list = new_encoder_hidden_states_list + new_encoder_hidden_states_list = encoder_hidden_states = None + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + + image_rotary_emb_list = [ self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) for txt_seq_lens in txt_seq_lens_list] + + hidden_states = None + + for index_block, block in enumerate(self.transformer_blocks): + if callback != None: + callback(-1, None, False, True) + if pipeline._interrupt: + return [None] * len(hidden_states_list) + for hidden_states, encoder_hidden_states, encoder_hidden_states_mask, image_rotary_emb in zip(hidden_states_list, encoder_hidden_states_list, encoder_hidden_states_mask_list, image_rotary_emb_list): + encoder_hidden_states[...], hidden_states[...] = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + ) + + # Use only the image part (hidden_states) from the dual-stream blocks + output_list = [] + for i in range(len(hidden_states_list)): + hidden_states = self.norm_out(hidden_states_list[i], temb) + hidden_states_list[i] = None + output_list.append(self.proj_out(hidden_states)) + + return output_list diff --git a/wan/__init__.py b/models/wan/__init__.py similarity index 50% rename from wan/__init__.py rename to models/wan/__init__.py index 1688425..fe3be71 100644 --- a/wan/__init__.py +++ b/models/wan/__init__.py @@ -1,3 +1,4 @@ from . import configs, distributed, modules from .any2video import WanAny2V -from .diffusion_forcing import DTT2V \ No newline at end of file +from .diffusion_forcing import DTT2V +from . 
import wan_handler, df_handler diff --git a/wan/any2video.py b/models/wan/any2video.py similarity index 62% rename from wan/any2video.py rename to models/wan/any2video.py index 67b1564..abe5249 100644 --- a/wan/any2video.py +++ b/models/wan/any2video.py @@ -19,18 +19,21 @@ from PIL import Image import torchvision.transforms.functional as TF import torch.nn.functional as F from .distributed.fsdp import shard_model -from .modules.model import WanModel +from .modules.model import WanModel, clear_caches from .modules.t5 import T5EncoderModel from .modules.vae import WanVAE +from .modules.vae2_2 import Wan2_2_VAE + from .modules.clip import CLIPModel -from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, +from shared.utils.fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler -from wan.modules.posemb_layers import get_rotary_pos_embed -from .utils.vace_preprocessor import VaceVideoProcessor -from wan.utils.basic_flowmatch import FlowMatchScheduler -from wan.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions -from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance +from shared.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from .modules.posemb_layers import get_rotary_pos_embed, get_nd_rotary_pos_embed +from shared.utils.vace_preprocessor import VaceVideoProcessor +from shared.utils.basic_flowmatch import FlowMatchScheduler +from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor +from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance, match_and_blend_colors, match_and_blend_colors_with_mask +from mmgp import safetensors2 def optimized_scale(positive_flat, negative_flat): @@ -77,59 +80,97 @@ class WanAny2V: self.num_train_timesteps = config.num_train_timesteps self.param_dtype = config.param_dtype self.model_def = model_def - self.image_outputs = model_def.get("image_outputs", False) + self.model2 = None + self.transformer_switch = model_def.get("URLs2", None) is not None self.text_encoder = T5EncoderModel( text_len=config.text_len, dtype=config.t5_dtype, device=torch.device('cpu'), checkpoint_path=text_encoder_filename, - tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + tokenizer_path=os.path.join(checkpoint_dir, "umt5-xxl"), shard_fn= None) - if hasattr(config, "clip_checkpoint"): + # base_model_type = "i2v2_2" + if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"]: self.clip = CLIPModel( dtype=config.clip_dtype, device=self.device, checkpoint_path=os.path.join(checkpoint_dir , config.clip_checkpoint), - tokenizer_path=os.path.join(checkpoint_dir , config.clip_tokenizer)) + tokenizer_path=os.path.join(checkpoint_dir , "xlm-roberta-large")) - self.vae_stride = config.vae_stride + + if base_model_type in ["ti2v_2_2"]: + self.vae_stride = (4, 16, 16) + vae_checkpoint = "Wan2.2_VAE.safetensors" + vae = Wan2_2_VAE + else: + self.vae_stride = config.vae_stride + vae_checkpoint = "Wan2.1_VAE.safetensors" + vae = WanVAE self.patch_size = config.patch_size - self.vae = WanVAE( - vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype, - device=self.device) + self.vae = vae( + vae_pth=os.path.join(checkpoint_dir, vae_checkpoint), dtype= VAE_dtype, + device="cpu") + self.vae.device = self.device - # xmodel_filename = 
"c:/ml/multitalk/multitalk.safetensors" - # config_filename= "configs/multitalk.json" + # config_filename= "configs/t2v_1.3B.json" # import json # with open(config_filename, 'r', encoding='utf-8') as f: # config = json.load(f) - # from mmgp import safetensors2 # sd = safetensors2.torch_load_file(xmodel_filename) - # model_filename = "c:/temp/flf/diffusion_pytorch_model-00001-of-00007.safetensors" + # model_filename = "c:/temp/wan2.2i2v/low/diffusion_pytorch_model-00001-of-00006.safetensors" base_config_file = f"configs/{base_model_type}.json" forcedConfigPath = base_config_file if len(model_filename) > 1 else None # forcedConfigPath = base_config_file = f"configs/flf2v_720p.json" # model_filename[1] = xmodel_filename - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + + source = model_def.get("source", None) + module_source = model_def.get("module_source", None) + if module_source is not None: + model_filename = [] + model_filename + model_filename[1] = module_source + self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + elif source is not None: + self.model = offload.fast_load_transformers_model(source, modelClass=WanModel, writable_tensors= False, forcedConfigPath= base_config_file) + elif self.transformer_switch: + shared_modules= {} + self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, return_shared_modules= shared_modules) + self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + shared_modules = None + else: + self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + # self.model = offload.load_model_data(self.model, xmodel_filename ) # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth") + self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) offload.change_dtype(self.model, dtype, True) - # offload.save_model(self.model, "flf2v_720p.safetensors", config_file_path=base_config_file) - # offload.save_model(self.model, "flf2v_quanto_int8_fp16_720p.safetensors", do_quantize= True, config_file_path=base_config_file) - # offload.save_model(self.model, "multitalk_quanto_fp16.safetensors", do_quantize= True, config_file_path=base_config_file, filter_sd=sd) + if self.model2 is not None: + self.model2.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) + offload.change_dtype(self.model2, dtype, True) - # offload.save_model(self.model, "wan2.1_selforcing_fp16.safetensors", config_file_path=base_config_file) - # offload.save_model(self.model, "wan2.1_text2video_14B_mbf16.safetensors", config_file_path=base_config_file) - # 
offload.save_model(self.model, "wan2.1_text2video_14B_quanto_mfp16_int8.safetensors", do_quantize=True, config_file_path=base_config_file) + # offload.save_model(self.model, "wan2.1_text2video_1.3B_mbf16.safetensors", do_quantize= False, config_file_path=base_config_file, filter_sd=sd) + # offload.save_model(self.model, "wan2.2_image2video_14B_low_mbf16.safetensors", config_file_path=base_config_file) + # offload.save_model(self.model, "wan2.2_image2video_14B_low_quanto_mbf16_int8.safetensors", do_quantize=True, config_file_path=base_config_file) self.model.eval().requires_grad_(False) + if self.model2 is not None: + self.model2.eval().requires_grad_(False) + if module_source is not None: + from wgp import save_model + from mmgp.safetensors2 import torch_load_file + filter = list(torch_load_file(module_source)) + save_model(self.model, model_type, dtype, None, is_module=True, filter=filter) + elif not source is None: + from wgp import save_model + save_model(self.model, model_type, dtype, None) + if save_quantized: from wgp import save_quantized_model save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file) - + if self.model2 is not None: + save_quantized_model(self.model2, model_type, model_filename[1], dtype, base_config_file, submodel_no=2) self.sample_neg_prompt = config.sample_neg_prompt if self.model.config.get("vace_in_dim", None) != None: @@ -142,7 +183,8 @@ class WanAny2V: seq_len=32760, keep_last=True) - self.adapt_vace_model() + self.adapt_vace_model(self.model) + if self.model2 is not None: self.adapt_vace_model(self.model2) self.num_timesteps = 1000 self.use_timestep_transform = True @@ -219,7 +261,7 @@ class WanAny2V: return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None, return_mask = False): - from wan.utils.utils import save_image + from shared.utils.utils import save_image ref_width, ref_height = ref_img.size if (ref_height, ref_width) == image_size and outpainting_dims == None: ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) @@ -344,7 +386,7 @@ class WanAny2V: input_frames= None, input_masks = None, input_ref_images = None, - input_video=None, + input_video = None, image_start = None, image_end = None, denoising_strength = 1.0, @@ -359,6 +401,12 @@ class WanAny2V: sample_solver='unipc', sampling_steps=50, guide_scale=5.0, + guide2_scale = 5.0, + guide3_scale = 5.0, + switch_threshold = 0, + switch2_threshold = 0, + guide_phases= 1 , + model_switch_phase = 1, n_prompt="", seed=-1, callback = None, @@ -380,6 +428,7 @@ class WanAny2V: conditioning_latents_size = 0, keep_frames_parsed = [], model_type = None, + model_mode = None, loras_slists = None, NAG_scale = 0, NAG_tau = 3.5, @@ -387,6 +436,14 @@ class WanAny2V: offloadobj = None, apg_switch = False, speakers_bboxes = None, + color_correction_strength = 1, + prefix_frames_count = 0, + image_mode = 0, + window_no = 0, + set_header_text = None, + pre_video_frame = None, + video_prompt_type= "", + original_input_ref_images = [], **bbargs ): @@ -420,15 +477,15 @@ class WanAny2V: sigmas=sampling_sigmas) else: raise NotImplementedError(f"Unsupported Scheduler {sample_solver}") + original_timesteps = timesteps seed_g = torch.Generator(device=self.device) seed_g.manual_seed(seed) - + image_outputs = image_mode == 1 kwargs = {'pipeline': self, 'callback': callback} - + color_reference_frame = None if self._interrupt: return None - # Text Encoder if n_prompt == 
"": n_prompt = self.sample_neg_prompt @@ -452,114 +509,157 @@ class WanAny2V: # if NAG_scale > 1: context = torch.cat([context, context_NAG], dim=0) if self._interrupt: return None - vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B"] + vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B", "vace_standin_14B"] phantom = model_type in ["phantom_1.3B", "phantom_14B"] fantasy = model_type in ["fantasy"] - multitalk = model_type in ["multitalk", "vace_multitalk_14B"] - + multitalk = model_type in ["multitalk", "infinitetalk", "vace_multitalk_14B", "i2v_2_2_multitalk"] + infinitetalk = model_type in ["infinitetalk"] + standin = model_type in ["standin", "vace_standin_14B"] + recam = model_type in ["recam_1.3B"] + ti2v = model_type in ["ti2v_2_2"] + start_step_no = 0 ref_images_count = 0 trim_frames = 0 extended_overlapped_latents = None - - # image2video + no_noise_latents_injection = infinitetalk + timestep_injection = False lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1 - if image_start != None: + # image2video + if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "flf2v_720p"]: any_end_frame = False - if input_frames != None: - _ , preframes_count, height, width = input_frames.shape - lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] - clip_context = self.clip.visual([input_frames[:, -1:]]) if model_type != "flf2v_720p" else self.clip.visual([input_frames[:, -1:], input_frames[:, -1:]]) - input_frames = input_frames.to(device=self.device).to(dtype= self.VAE_dtype) - enc = torch.concat( [input_frames, torch.zeros( (3, frame_num-preframes_count, height, width), - device=self.device, dtype= self.VAE_dtype)], - dim = 1).to(self.device) - input_frames = None + if image_start is None: + if infinitetalk: + if input_frames is not None: + image_ref = input_frames[:, -1] + if input_video is None: input_video = input_frames[:, -1:] + new_shot = "Q" in video_prompt_type + else: + if pre_video_frame is None: + new_shot = True + else: + if input_ref_images is None: + input_ref_images, new_shot = [pre_video_frame], False + else: + input_ref_images, new_shot = [img.resize(pre_video_frame.size, resample=Image.Resampling.LANCZOS) for img in input_ref_images], "Q" in video_prompt_type + if input_ref_images is None: raise Exception("Missing Reference Image") + new_shot = new_shot and window_no <= len(input_ref_images) + image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ]) + if new_shot: + input_video = image_ref.unsqueeze(1) + else: + color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot + _ , preframes_count, height, width = input_video.shape + input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype) + if infinitetalk: + image_for_clip = image_ref.to(input_video) + control_pre_frames_count = 1 + control_video = image_for_clip.unsqueeze(1) + else: + image_for_clip = input_video[:, -1] + control_pre_frames_count = preframes_count + control_video = input_video + lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] + if hasattr(self, "clip"): + clip_image_size = self.clip.model.image_size + clip_image = resize_lanczos(image_for_clip, clip_image_size, clip_image_size)[:, None, :, :] + clip_context = self.clip.visual([clip_image]) if model_type != "flf2v_720p" else 
self.clip.visual([clip_image , clip_image ]) + clip_image = None + else: + clip_context = None + enc = torch.concat( [control_video, torch.zeros( (3, frame_num-control_pre_frames_count, height, width), + device=self.device, dtype= self.VAE_dtype)], + dim = 1).to(self.device) + color_reference_frame = image_for_clip.unsqueeze(1).clone() else: - preframes_count = 1 - image_start = TF.to_tensor(image_start) - any_end_frame = image_end != None + preframes_count = control_pre_frames_count = 1 + any_end_frame = image_end is not None add_frames_for_end_image = any_end_frame and model_type == "i2v" if any_end_frame: - image_end = TF.to_tensor(image_end) if add_frames_for_end_image: frame_num +=1 lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2) trim_frames = 1 - h, w = image_start.shape[1:] + height, width = image_start.shape[1:] - h, w = calculate_new_dimensions(height, width, h, w, fit_into_canvas) - width, height = w, h - lat_h = round( - h // self.vae_stride[1] // + height // self.vae_stride[1] // self.patch_size[1] * self.patch_size[1]) lat_w = round( - w // self.vae_stride[2] // + width // self.vae_stride[2] // self.patch_size[2] * self.patch_size[2]) - h = lat_h * self.vae_stride[1] - w = lat_w * self.vae_stride[2] - clip_image_size = self.clip.model.image_size - img_interpolated = resize_lanczos(image_start, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype - image_start = resize_lanczos(image_start, clip_image_size, clip_image_size) - image_start = image_start.sub_(0.5).div_(0.5).to(self.device) #, self.dtype - if image_end!= None: - img_interpolated2 = resize_lanczos(image_end, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype - image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) - image_end = image_end.sub_(0.5).div_(0.5).to(self.device) #, self.dtype - if model_type == "flf2v_720p": - clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end != None else image_start[:, None, :, :]]) + height = lat_h * self.vae_stride[1] + width = lat_w * self.vae_stride[2] + image_start_frame = image_start.unsqueeze(1).to(self.device) + color_reference_frame = image_start_frame.clone() + if image_end is not None: + img_end_frame = image_end.unsqueeze(1).to(self.device) + + if hasattr(self, "clip"): + clip_image_size = self.clip.model.image_size + image_start = resize_lanczos(image_start, clip_image_size, clip_image_size) + if image_end is not None: image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) + if model_type == "flf2v_720p": + clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end is not None else image_start[:, None, :, :]]) + else: + clip_context = self.clip.visual([image_start[:, None, :, :]]) else: - clip_context = self.clip.visual([image_start[:, None, :, :]]) + clip_context = None if any_end_frame: enc= torch.concat([ - img_interpolated, - torch.zeros( (3, frame_num-2, h, w), device=self.device, dtype= self.VAE_dtype), - img_interpolated2, + image_start_frame, + torch.zeros( (3, frame_num-2, height, width), device=self.device, dtype= self.VAE_dtype), + img_end_frame, ], dim=1).to(self.device) else: enc= torch.concat([ - img_interpolated, - torch.zeros( (3, frame_num-1, h, w), device=self.device, dtype= self.VAE_dtype) + image_start_frame, + torch.zeros( (3, frame_num-1, height, width), device=self.device, dtype= self.VAE_dtype) ], dim=1).to(self.device) - image_start = image_end = 
img_interpolated = img_interpolated2 = None + image_start = image_end = image_start_frame = img_end_frame = image_for_clip = image_ref = None msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device) if any_end_frame: - msk[:, preframes_count: -1] = 0 + msk[:, control_pre_frames_count: -1] = 0 if add_frames_for_end_image: msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:-1], torch.repeat_interleave(msk[:, -1:], repeats=4, dim=1) ], dim=1) else: msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1) else: - msk[:, preframes_count:] = 0 + msk[:, control_pre_frames_count:] = 0 msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1) msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) msk = msk.transpose(1, 2)[0] - lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0] + y = torch.concat([msk, lat_y]) overlapped_latents_frames_num = int(1 + (preframes_count-1) // 4) - if overlapped_latents != None: + # if overlapped_latents != None: + if overlapped_latents_frames_num > 0: # disabled because looks worse if False and overlapped_latents_frames_num > 1: lat_y[:, :, 1:overlapped_latents_frames_num] = overlapped_latents[:, 1:] + if infinitetalk: + lat_y = self.vae.encode([input_video], VAE_tile_size)[0] extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone().unsqueeze(0) - y = torch.concat([msk, lat_y]) - lat_y = None - kwargs.update({'clip_fea': clip_context, 'y': y}) + # if control_pre_frames_count != pre_frames_count: + + lat_y = input_video = None + kwargs.update({ 'y': y}) + if not clip_context is None: + kwargs.update({'clip_fea': clip_context}) # Recam Master - if target_camera != None: - width = input_video.shape[2] - height = input_video.shape[1] + if recam: + # should be be in fact in input_frames since it is control video not a video to be extended + target_camera = model_mode + height,width = input_video.shape[-2:] input_video = input_video.to(dtype=self.dtype , device=self.device) - input_video = input_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.) - source_latents = self.vae.encode([input_video])[0] #.to(dtype=self.dtype, device=self.device) + source_latents = self.vae.encode([input_video])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device) del input_video # Process target camera (recammaster) - from wan.utils.cammmaster_tools import get_camera_embedding + from shared.utils.cammmaster_tools import get_camera_embedding cam_emb = get_camera_embedding(target_camera) cam_emb = cam_emb.to(dtype=self.dtype, device=self.device) kwargs['cam_emb'] = cam_emb @@ -567,19 +667,20 @@ class WanAny2V: # Video 2 Video if denoising_strength < 1. 
and input_frames != None: height, width = input_frames.shape[-2:] - source_latents = self.vae.encode([input_frames])[0] + source_latents = self.vae.encode([input_frames])[0].unsqueeze(0) injection_denoising_step = 0 inject_from_start = False if input_frames != None and denoising_strength < 1 : + color_reference_frame = input_frames[:, -1:].clone() if overlapped_latents != None: - overlapped_latents_frames_num = overlapped_latents.shape[1] + overlapped_latents_frames_num = overlapped_latents.shape[2] overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1 else: overlapped_latents_frames_num = overlapped_frames_num = 0 - if len(keep_frames_parsed) == 0 or self.image_outputs or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = [] + if len(keep_frames_parsed) == 0 or image_outputs or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = [] injection_denoising_step = int(sampling_steps * (1. - denoising_strength) ) latent_keep_frames = [] - if source_latents.shape[1] < lat_frames or len(keep_frames_parsed) > 0: + if source_latents.shape[2] < lat_frames or len(keep_frames_parsed) > 0: inject_from_start = True if len(keep_frames_parsed) >0 : if overlapped_frames_num > 0: keep_frames_parsed = [True] * overlapped_frames_num + keep_frames_parsed @@ -588,6 +689,7 @@ class WanAny2V: latent_keep_frames.append(all(keep_frames_parsed[i:i+4])) else: timesteps = timesteps[injection_denoising_step:] + start_step_no = injection_denoising_step if hasattr(sample_scheduler, "timesteps"): sample_scheduler.timesteps = timesteps if hasattr(sample_scheduler, "sigmas"): sample_scheduler.sigmas= sample_scheduler.sigmas[injection_denoising_step:] injection_denoising_step = 0 @@ -601,6 +703,14 @@ class WanAny2V: ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0 trim_frames = input_ref_images.shape[1] + if ti2v: + if input_video is None: + height, width = (height // 32) * 32, (width // 32) * 32 + else: + height, width = input_video.shape[-2:] + source_latents = self.vae.encode([input_video], tile_size = VAE_tile_size)[0].unsqueeze(0) + timestep_injection = True + # Vace if vace : # vace context encode @@ -611,6 +721,7 @@ class WanAny2V: z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents ) m0 = self.vace_encode_masks(input_masks, input_ref_images) if self.background_mask != None: + color_reference_frame = input_ref_images[0][0].clone() zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size ) mbg = self.vace_encode_masks(self.background_mask, None) for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg): @@ -626,6 +737,8 @@ class WanAny2V: if overlapped_latents != None : overlapped_latents_size = overlapped_latents.shape[2] extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0) + if prefix_frames_count > 0: + color_reference_frame = input_frames[0][:, prefix_frames_count -1:prefix_frames_count].clone() target_shape = list(z0[0].shape) target_shape[0] = int(target_shape[0] / 2) @@ -637,7 +750,7 @@ class WanAny2V: target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2]) if multitalk and audio_proj != None: - from wan.multitalk.multitalk import 
get_target_masks + from .multitalk.multitalk import get_target_masks audio_proj = [audio.to(self.dtype) for audio in audio_proj] human_no = len(audio_proj[0]) token_ref_target_masks = get_target_masks(human_no, lat_h, lat_w, height, width, face_scale = 0.05, bbox = speakers_bboxes).to(self.dtype) if human_no > 1 else None @@ -660,17 +773,37 @@ class WanAny2V: kwargs["freqs"] = freqs + #Standin + if standin: + from preprocessing.face_preprocessor import FaceProcessor + standin_ref_pos = 1 if "K" in video_prompt_type else 0 + if len(original_input_ref_images) < standin_ref_pos + 1: raise Exception("Missing Standin ref image") + standin_ref_pos = -1 + image_ref = original_input_ref_images[standin_ref_pos] + image_ref.save("si.png") + # face_processor = FaceProcessor(antelopv2_path="ckpts/antelopev2") + face_processor = FaceProcessor() + standin_ref = face_processor.process(image_ref, remove_bg = model_type in ["vace_standin_14B"]) + face_processor = None + gc.collect() + torch.cuda.empty_cache() + standin_freqs = get_nd_rotary_pos_embed((-1, int(target_shape[-2]/2), int(target_shape[-1]/2) ), (-1, int(target_shape[-2]/2 + standin_ref.height/16), int(target_shape[-1]/2 + standin_ref.width/16) )) + standin_ref = self.vae.encode([ convert_image_to_tensor(standin_ref).unsqueeze(1) ], VAE_tile_size)[0].unsqueeze(0) + kwargs.update({ "standin_freqs": standin_freqs, "standin_ref": standin_ref, }) + + # Steps Skipping - cache_type = self.model.enable_cache - if cache_type != None: + skip_steps_cache = self.model.cache + if skip_steps_cache != None: + cache_type = skip_steps_cache.cache_type x_count = 3 if phantom or fantasy or multitalk else 2 - self.model.previous_residual = [None] * x_count + skip_steps_cache.previous_residual = [None] * x_count if cache_type == "tea": - self.model.compute_teacache_threshold(self.model.cache_start_step, timesteps, self.model.cache_multiplier) + self.model.compute_teacache_threshold(max(skip_steps_cache.start_step, start_step_no), original_timesteps, skip_steps_cache.multiplier) else: - self.model.compute_magcache_threshold(self.model.cache_start_step, timesteps, self.model.cache_multiplier) - self.model.accumulated_err, self.model.accumulated_steps, self.model.accumulated_ratio = [0.0] * x_count, [0] * x_count, [1.0] * x_count - self.model.one_for_all = x_count > 2 + self.model.compute_magcache_threshold(max(skip_steps_cache.start_step, start_step_no), original_timesteps, skip_steps_cache.multiplier) + skip_steps_cache.accumulated_err, skip_steps_cache.accumulated_steps, skip_steps_cache.accumulated_ratio = [0.0] * x_count, [0] * x_count, [1.0] * x_count + skip_steps_cache.one_for_all = x_count > 2 if callback != None: callback(-1, None, True) @@ -682,10 +815,30 @@ class WanAny2V: # init denoising updated_num_steps= len(timesteps) - if callback != None: - from wan.utils.utils import update_loras_slists - update_loras_slists(self.model, loras_slists, updated_num_steps) - callback(-1, None, True, override_num_inference_steps = updated_num_steps) + + denoising_extra = "" + from shared.utils.loras_mutipliers import update_loras_slists, get_model_switch_steps + + phase_switch_step, phase_switch_step2, phases_description = get_model_switch_steps(timesteps, updated_num_steps, guide_phases, 0 if self.model2 is None else model_switch_phase, switch_threshold, switch2_threshold ) + if len(phases_description) > 0: set_header_text(phases_description) + guidance_switch_done = guidance_switch2_done = False + if guide_phases > 1: denoising_extra = f"Phase 1/{guide_phases} High 
Noise" if self.model2 is not None else f"Phase 1/{guide_phases}" + def update_guidance(step_no, t, guide_scale, new_guide_scale, guidance_switch_done, switch_threshold, trans, phase_no, denoising_extra): + if guide_phases >= phase_no and not guidance_switch_done and t <= switch_threshold: + if model_switch_phase == phase_no-1 and self.model2 is not None: trans = self.model2 + guide_scale, guidance_switch_done = new_guide_scale, True + denoising_extra = f"Phase {phase_no}/{guide_phases} {'Low Noise' if trans == self.model2 else 'High Noise'}" if self.model2 is not None else f"Phase {phase_no}/{guide_phases}" + callback(step_no-1, denoising_extra = denoising_extra) + return guide_scale, guidance_switch_done, trans, denoising_extra + update_loras_slists(self.model, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2) + if self.model2 is not None: update_loras_slists(self.model2, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2) + callback(-1, None, True, override_num_inference_steps = updated_num_steps, denoising_extra = denoising_extra) + + def clear(): + clear_caches() + gc.collect() + torch.cuda.empty_cache() + return None if sample_scheduler != None: scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g} @@ -696,12 +849,22 @@ class WanAny2V: apg_norm_threshold = 55 text_momentumbuffer = MomentumBuffer(apg_momentum) audio_momentumbuffer = MomentumBuffer(apg_momentum) - # self.image_outputs = False + + # denoising + trans = self.model for i, t in enumerate(tqdm(timesteps)): - offload.set_step_no_for_lora(self.model, i) + guide_scale, guidance_switch_done, trans, denoising_extra = update_guidance(i, t, guide_scale, guide2_scale, guidance_switch_done, switch_threshold, trans, 2, denoising_extra) + guide_scale, guidance_switch2_done, trans, denoising_extra = update_guidance(i, t, guide_scale, guide3_scale, guidance_switch2_done, switch2_threshold, trans, 3, denoising_extra) + offload.set_step_no_for_lora(trans, i) timestep = torch.stack([t]) - kwargs.update({"t": timestep, "current_step": i}) + + if timestep_injection: + latents[:, :, :source_latents.shape[2]] = source_latents + timestep = torch.full((target_shape[-3],), t, dtype=torch.int64, device=latents.device) + timestep[:source_latents.shape[2]] = 0 + + kwargs.update({"t": timestep, "current_step": start_step_no + i}) kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step: @@ -709,29 +872,32 @@ class WanAny2V: noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) if inject_from_start: new_latents = latents.clone() - new_latents[:,:, :source_latents.shape[1] ] = noise[:, :, :source_latents.shape[1] ] * sigma + (1 - sigma) * source_latents.unsqueeze(0) + new_latents[:,:, :source_latents.shape[2] ] = noise[:, :, :source_latents.shape[2] ] * sigma + (1 - sigma) * source_latents for latent_no, keep_latent in enumerate(latent_keep_frames): if not keep_latent: new_latents[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1] latents = new_latents new_latents = None else: - latents = noise * sigma + (1 - sigma) * source_latents.unsqueeze(0) + latents = noise * sigma + (1 - sigma) * source_latents noise = None if extended_overlapped_latents != None: - latent_noise_factor = t 
/ 1000 - latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor + if no_noise_latents_injection: + latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents + else: + latent_noise_factor = t / 1000 if vace: overlap_noise_factor = overlap_noise / 1000 for zz in z: zz[0:16, ref_images_count:extended_overlapped_latents.shape[2] ] = extended_overlapped_latents[0, :, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[0, :, ref_images_count:] ) * overlap_noise_factor if target_camera != None: - latent_model_input = torch.cat([latents, source_latents.unsqueeze(0).expand(*expand_shape)], dim=2) # !!!! + latent_model_input = torch.cat([latents, source_latents.expand(*expand_shape)], dim=2) else: latent_model_input = latents + any_guidance = guide_scale != 1 if phantom: gen_args = { "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 + @@ -745,33 +911,42 @@ class WanAny2V: "audio_scale": [audio_scale, None, None ] } elif multitalk and audio_proj != None: - gen_args = { - "x" : [latent_model_input, latent_model_input, latent_model_input], - "context" : [context, context_null, context_null], - "multitalk_audio": [audio_proj, audio_proj, [torch.zeros_like(audio_proj[0][-1:]), torch.zeros_like(audio_proj[1][-1:])]], - "multitalk_masks": [token_ref_target_masks, token_ref_target_masks, None] - } + if guide_scale == 1: + gen_args = { + "x" : [latent_model_input, latent_model_input], + "context" : [context, context], + "multitalk_audio": [audio_proj, [torch.zeros_like(audio_proj[0][-1:]), torch.zeros_like(audio_proj[1][-1:])]], + "multitalk_masks": [token_ref_target_masks, None] + } + any_guidance = audio_cfg_scale != 1 + else: + gen_args = { + "x" : [latent_model_input, latent_model_input, latent_model_input], + "context" : [context, context_null, context_null], + "multitalk_audio": [audio_proj, audio_proj, [torch.zeros_like(audio_proj[0][-1:]), torch.zeros_like(audio_proj[1][-1:])]], + "multitalk_masks": [token_ref_target_masks, token_ref_target_masks, None] + } else: gen_args = { "x" : [latent_model_input, latent_model_input], "context": [context, context_null] } - if joint_pass and guide_scale > 1: - ret_values = self.model( **gen_args , **kwargs) + if joint_pass and any_guidance: + ret_values = trans( **gen_args , **kwargs) if self._interrupt: - return None + return clear() else: - size = 1 if guide_scale == 1 else len(gen_args["x"]) + size = len(gen_args["x"]) if any_guidance else 1 ret_values = [None] * size for x_id in range(size): sub_gen_args = {k : [v[x_id]] for k, v in gen_args.items() } - ret_values[x_id] = self.model( **sub_gen_args, x_id= x_id , **kwargs)[0] + ret_values[x_id] = trans( **sub_gen_args, x_id= x_id , **kwargs)[0] if self._interrupt: - return None + return clear() sub_gen_args = None - if guide_scale == 1: - noise_pred = ret_values[0] + if not any_guidance: + noise_pred = ret_values[0] elif phantom: guide_scale_img= 5.0 guide_scale_text= guide_scale #7.5 @@ -782,20 +957,33 @@ class WanAny2V: noise_pred_cond, noise_pred_noaudio, noise_pred_uncond = ret_values noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio) noise_pred_noaudio = None - elif multitalk: - noise_pred_cond, noise_pred_drop_text, noise_pred_uncond 
= ret_values + elif multitalk and audio_proj != None: if apg_switch != 0: - noise_pred = noise_pred_cond + (guide_scale - 1) * adaptive_projected_guidance(noise_pred_cond - noise_pred_drop_text, - noise_pred_cond, - momentum_buffer=text_momentumbuffer, - norm_threshold=apg_norm_threshold) \ - + (audio_cfg_scale - 1) * adaptive_projected_guidance(noise_pred_drop_text - noise_pred_uncond, - noise_pred_cond, - momentum_buffer=audio_momentumbuffer, - norm_threshold=apg_norm_threshold) + if guide_scale == 1: + noise_pred_cond, noise_pred_drop_audio = ret_values + noise_pred = noise_pred_cond + (audio_cfg_scale - 1)* adaptive_projected_guidance(noise_pred_cond - noise_pred_drop_audio, + noise_pred_cond, + momentum_buffer=audio_momentumbuffer, + norm_threshold=apg_norm_threshold) + + else: + noise_pred_cond, noise_pred_drop_text, noise_pred_uncond = ret_values + noise_pred = noise_pred_cond + (guide_scale - 1) * adaptive_projected_guidance(noise_pred_cond - noise_pred_drop_text, + noise_pred_cond, + momentum_buffer=text_momentumbuffer, + norm_threshold=apg_norm_threshold) \ + + (audio_cfg_scale - 1) * adaptive_projected_guidance(noise_pred_drop_text - noise_pred_uncond, + noise_pred_cond, + momentum_buffer=audio_momentumbuffer, + norm_threshold=apg_norm_threshold) else: - noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_drop_text) + audio_cfg_scale * (noise_pred_drop_text - noise_pred_uncond) - noise_pred_uncond = noise_pred_cond = noise_pred_drop_text = None + if guide_scale == 1: + noise_pred_cond, noise_pred_drop_audio = ret_values + noise_pred = noise_pred_drop_audio + audio_cfg_scale* (noise_pred_cond - noise_pred_drop_audio) + else: + noise_pred_cond, noise_pred_drop_text, noise_pred_uncond = ret_values + noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_drop_text) + audio_cfg_scale * (noise_pred_drop_text - noise_pred_uncond) + noise_pred_uncond = noise_pred_cond = noise_pred_drop_text = noise_pred_drop_audio = None else: noise_pred_cond, noise_pred_uncond = ret_values if apg_switch != 0: @@ -835,10 +1023,15 @@ class WanAny2V: latents_preview = latents if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ] if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames] + if image_outputs: latents_preview= latents_preview[:, :,:1] if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2) - callback(i, latents_preview[0], False) + callback(i, latents_preview[0], False, denoising_extra =denoising_extra ) latents_preview = None + clear() + if timestep_injection: + latents[:, :, :source_latents.shape[2]] = source_latents + if vace and ref_images_count > 0: latents = latents[:, :, ref_images_count:] if trim_frames > 0: latents= latents[:, :,:-trim_frames] if return_latent_slice != None: @@ -851,19 +1044,29 @@ class WanAny2V: videos = self.vae.decode(x0, VAE_tile_size) - if self.image_outputs: - videos = torch.cat(videos, dim=1) if len(videos) > 1 else videos[0] + if image_outputs: + videos = torch.cat([video[:,:1] for video in videos], dim=1) if len(videos) > 1 else videos[0][:,:1] else: - videos = videos[0] # return only first video + videos = videos[0] # return only first video + if color_correction_strength > 0 and (prefix_frames_count > 0 and window_no > 1 or prefix_frames_count > 1 and window_no == 1): + if vace and False: + # videos = match_and_blend_colors_with_mask(videos.unsqueeze(0), input_frames[0].unsqueeze(0), input_masks[0][:1].unsqueeze(0), 
color_correction_strength,copy_mode= "progressive_blend").squeeze(0) + videos = match_and_blend_colors_with_mask(videos.unsqueeze(0), input_frames[0].unsqueeze(0), input_masks[0][:1].unsqueeze(0), color_correction_strength,copy_mode= "reference").squeeze(0) + # videos = match_and_blend_colors_with_mask(videos.unsqueeze(0), videos.unsqueeze(0), input_masks[0][:1].unsqueeze(0), color_correction_strength,copy_mode= "reference").squeeze(0) + elif color_reference_frame is not None: + videos = match_and_blend_colors(videos.unsqueeze(0), color_reference_frame.unsqueeze(0), color_correction_strength).squeeze(0) + if return_latent_slice != None: return { "x" : videos, "latent_slice" : latent_slice } return videos - def adapt_vace_model(self): - model = self.model + def adapt_vace_model(self, model): modules_dict= { k: m for k, m in model.named_modules()} for model_layer, vace_layer in model.vace_layers_mapping.items(): module = modules_dict[f"vace_blocks.{vace_layer}"] target = modules_dict[f"blocks.{model_layer}"] setattr(target, "vace", module ) delattr(model, "vace_blocks") + + + diff --git a/wan/camera_extrinsics.json b/models/wan/camera_extrinsics.json similarity index 100% rename from wan/camera_extrinsics.json rename to models/wan/camera_extrinsics.json diff --git a/wan/configs/__init__.py b/models/wan/configs/__init__.py similarity index 100% rename from wan/configs/__init__.py rename to models/wan/configs/__init__.py diff --git a/wan/configs/shared_config.py b/models/wan/configs/shared_config.py similarity index 100% rename from wan/configs/shared_config.py rename to models/wan/configs/shared_config.py diff --git a/wan/configs/wan_i2v_14B.py b/models/wan/configs/wan_i2v_14B.py similarity index 95% rename from wan/configs/wan_i2v_14B.py rename to models/wan/configs/wan_i2v_14B.py index 7812c92..623a51f 100644 --- a/wan/configs/wan_i2v_14B.py +++ b/models/wan/configs/wan_i2v_14B.py @@ -10,7 +10,7 @@ i2v_14B = EasyDict(__name__='Config: Wan I2V 14B') i2v_14B.update(wan_shared_cfg) i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' -i2v_14B.t5_tokenizer = 'google/umt5-xxl' +i2v_14B.t5_tokenizer = 'umt5-xxl' # clip i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14' diff --git a/wan/configs/wan_t2v_14B.py b/models/wan/configs/wan_t2v_14B.py similarity index 94% rename from wan/configs/wan_t2v_14B.py rename to models/wan/configs/wan_t2v_14B.py index 9d0ee69..f422d1f 100644 --- a/wan/configs/wan_t2v_14B.py +++ b/models/wan/configs/wan_t2v_14B.py @@ -10,7 +10,7 @@ t2v_14B.update(wan_shared_cfg) # t5 t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' -t2v_14B.t5_tokenizer = 'google/umt5-xxl' +t2v_14B.t5_tokenizer = 'umt5-xxl' # vae t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' diff --git a/wan/configs/wan_t2v_1_3B.py b/models/wan/configs/wan_t2v_1_3B.py similarity index 94% rename from wan/configs/wan_t2v_1_3B.py rename to models/wan/configs/wan_t2v_1_3B.py index ea9502b..ac23bff 100644 --- a/wan/configs/wan_t2v_1_3B.py +++ b/models/wan/configs/wan_t2v_1_3B.py @@ -10,7 +10,7 @@ t2v_1_3B.update(wan_shared_cfg) # t5 t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' -t2v_1_3B.t5_tokenizer = 'google/umt5-xxl' +t2v_1_3B.t5_tokenizer = 'umt5-xxl' # vae t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth' diff --git a/models/wan/df_handler.py b/models/wan/df_handler.py new file mode 100644 index 0000000..bc79e2e --- /dev/null +++ b/models/wan/df_handler.py @@ -0,0 +1,99 @@ +import torch + +class family_handler(): + + @staticmethod + def set_cache_parameters(cache_type, base_model_type, model_def, 
inputs, skip_steps_cache): + if base_model_type == "sky_df_1.3B": + coefficients= [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01] + else: + coefficients= [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] + + skip_steps_cache.coefficients = coefficients + + @staticmethod + def query_model_def(base_model_type, model_def): + extra_model_def = {} + if base_model_type in ["sky_df_14B"]: + fps = 24 + else: + fps = 16 + extra_model_def["fps"] =fps + extra_model_def["frames_minimum"] = 17 + extra_model_def["frames_steps"] = 20 + extra_model_def["sliding_window"] = True + extra_model_def["skip_layer_guidance"] = True + extra_model_def["tea_cache"] = True + extra_model_def["guidance_max_phases"] = 1 + + return extra_model_def + + @staticmethod + def query_supported_types(): + return ["sky_df_1.3B", "sky_df_14B"] + + + @staticmethod + def query_family_maps(): + models_eqv_map = { + "sky_df_1.3B" : "sky_df_14B", + } + + models_comp_map = { + "sky_df_14B": ["sky_df_1.3B"], + } + return models_eqv_map, models_comp_map + + + + @staticmethod + def query_model_family(): + return "wan" + + @staticmethod + def query_family_infos(): + return {} + + + + @staticmethod + def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization): + from .wan_handler import family_handler + return family_handler.query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization) + + @staticmethod + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): + from .configs import WAN_CONFIGS + from .wan_handler import family_handler + cfg = WAN_CONFIGS['t2v-14B'] + from . 
import DTT2V + wan_model = DTT2V( + config=cfg, + checkpoint_dir="ckpts", + model_filename=model_filename, + model_type = model_type, + model_def = model_def, + base_model_type=base_model_type, + text_encoder_filename= family_handler.get_wan_text_encoder_filename(text_encoder_quantization), + quantizeTransformer = quantizeTransformer, + dtype = dtype, + VAE_dtype = VAE_dtype, + mixed_precision_transformer = mixed_precision_transformer, + save_quantized = save_quantized + ) + + pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model } + return wan_model, pipe + + @staticmethod + def update_default_settings(base_model_type, model_def, ui_defaults): + ui_defaults.update({ + "guidance_scale": 6.0, + "flow_shift": 8, + "sliding_window_discard_last_frames" : 0, + "resolution": "1280x720" if "720" in base_model_type else "960x544", + "sliding_window_size" : 121 if "720" in base_model_type else 97, + "RIFLEx_setting": 2, + "guidance_scale": 6, + "flow_shift": 8, + }) \ No newline at end of file diff --git a/models/wan/diffusion_forcing copy.py b/models/wan/diffusion_forcing copy.py new file mode 100644 index 0000000..753fd45 --- /dev/null +++ b/models/wan/diffusion_forcing copy.py @@ -0,0 +1,479 @@ +import math +import os +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union +import logging +import numpy as np +import torch +from diffusers.image_processor import PipelineImageInput +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from tqdm import tqdm +from .modules.model import WanModel +from .modules.t5 import T5EncoderModel +from .modules.vae import WanVAE +from wan.modules.posemb_layers import get_rotary_pos_embed +from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, retrieve_timesteps) +from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +class DTT2V: + + + def __init__( + self, + config, + checkpoint_dir, + rank=0, + model_filename = None, + text_encoder_filename = None, + quantizeTransformer = False, + dtype = torch.bfloat16, + ): + self.device = torch.device(f"cuda") + self.config = config + self.rank = rank + self.dtype = dtype + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu'), + checkpoint_path=text_encoder_filename, + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + shard_fn= None) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + + + self.vae = WanVAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device) + + logging.info(f"Creating WanModel from {model_filename}") + from mmgp import offload + + self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False, forcedConfigPath="config.json") + # offload.load_model_data(self.model, "recam.ckpt") + # self.model.cpu() + # offload.save_model(self.model, "recam.safetensors") + if self.dtype == torch.float16 and not "fp16" in model_filename: + self.model.to(self.dtype) + # offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True) + if self.dtype == torch.float16: + self.vae.model.to(self.dtype) + self.model.eval().requires_grad_(False) + + self.scheduler = FlowUniPCMultistepScheduler() + + 
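+    # Classifier-free guidance: the property below only reports whether _guidance_scale > 1.
+    # When it is True, generate() runs the transformer on both the conditional and the
+    # unconditional ("negative") prompt embeddings and blends the two predictions as
+    #   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
+    # so guidance_scale == 1 collapses to a single conditional pass.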
@property + def do_classifier_free_guidance(self) -> bool: + return self._guidance_scale > 1 + + def encode_image( + self, image: PipelineImageInput, height: int, width: int, num_frames: int, tile_size = 0, causal_block_size = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # prefix_video + prefix_video = np.array(image.resize((width, height))).transpose(2, 0, 1) + prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1) + if prefix_video.dtype == torch.uint8: + prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0 + prefix_video = prefix_video.to(self.device) + prefix_video = [self.vae.encode(prefix_video.unsqueeze(0), tile_size = tile_size)[0]] # [(c, f, h, w)] + if prefix_video[0].shape[1] % causal_block_size != 0: + truncate_len = prefix_video[0].shape[1] % causal_block_size + print("the length of prefix video is truncated for the casual block size alignment.") + prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] + predix_video_latent_length = prefix_video[0].shape[1] + return prefix_video, predix_video_latent_length + + def prepare_latents( + self, + shape: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + ) -> torch.Tensor: + return randn_tensor(shape, generator, device=device, dtype=dtype) + + def generate_timestep_matrix( + self, + num_frames, + step_template, + base_num_frames, + ar_step=5, + num_pre_ready=0, + casual_block_size=1, + shrink_interval_with_mask=False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: + step_matrix, step_index = [], [] + update_mask, valid_interval = [], [] + num_iterations = len(step_template) + 1 + num_frames_block = num_frames // casual_block_size + base_num_frames_block = base_num_frames // casual_block_size + if base_num_frames_block < num_frames_block: + infer_step_num = len(step_template) + gen_block = base_num_frames_block + min_ar_step = infer_step_num / gen_block + assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting" + # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block) + step_template = torch.cat( + [ + torch.tensor([999], dtype=torch.int64, device=step_template.device), + step_template.long(), + torch.tensor([0], dtype=torch.int64, device=step_template.device), + ] + ) # to handle the counter in row works starting from 1 + pre_row = torch.zeros(num_frames_block, dtype=torch.long) + if num_pre_ready > 0: + pre_row[: num_pre_ready // casual_block_size] = num_iterations + + while torch.all(pre_row >= (num_iterations - 1)) == False: + new_row = torch.zeros(num_frames_block, dtype=torch.long) + for i in range(num_frames_block): + if i == 0 or pre_row[i - 1] >= ( + num_iterations - 1 + ): # the first frame or the last frame is completely denoised + new_row[i] = pre_row[i] + 1 + else: + new_row[i] = new_row[i - 1] - ar_step + new_row = new_row.clamp(0, num_iterations) + + update_mask.append( + (new_row != pre_row) & (new_row != num_iterations) + ) # False: no need to update, True: need to update + step_index.append(new_row) + step_matrix.append(step_template[new_row]) + pre_row = new_row + + # for long video we split into several sequences, base_num_frames is set to the model max length (for training) + terminal_flag = base_num_frames_block + if shrink_interval_with_mask: + idx_sequence = 
torch.arange(num_frames_block, dtype=torch.int64) + update_mask = update_mask[0] + update_mask_idx = idx_sequence[update_mask] + last_update_idx = update_mask_idx[-1].item() + terminal_flag = last_update_idx + 1 + # for i in range(0, len(update_mask)): + for curr_mask in update_mask: + if terminal_flag < num_frames_block and curr_mask[terminal_flag]: + terminal_flag += 1 + valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag)) + + step_update_mask = torch.stack(update_mask, dim=0) + step_index = torch.stack(step_index, dim=0) + step_matrix = torch.stack(step_matrix, dim=0) + + if casual_block_size > 1: + step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval] + + return step_matrix, step_index, step_update_mask, valid_interval + + @torch.no_grad() + def generate( + self, + prompt: Union[str, List[str]], + negative_prompt: Union[str, List[str]] = "", + image: PipelineImageInput = None, + height: int = 480, + width: int = 832, + num_frames: int = 97, + num_inference_steps: int = 50, + shift: float = 1.0, + guidance_scale: float = 5.0, + seed: float = 0.0, + overlap_history: int = 17, + addnoise_condition: int = 0, + base_num_frames: int = 97, + ar_step: int = 5, + causal_block_size: int = 1, + causal_attention: bool = False, + fps: int = 24, + VAE_tile_size = 0, + joint_pass = False, + callback = None, + ): + generator = torch.Generator(device=self.device) + generator.manual_seed(seed) + # if base_num_frames > base_num_frames: + # causal_block_size = 0 + self._guidance_scale = guidance_scale + + i2v_extra_kwrags = {} + prefix_video = None + predix_video_latent_length = 0 + if image: + frame_width, frame_height = image.size + scale = min(height / frame_height, width / frame_width) + height = (int(frame_height * scale) // 16) * 16 + width = (int(frame_width * scale) // 16) * 16 + + prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames, tile_size=VAE_tile_size, causal_block_size=causal_block_size) + + latent_length = (num_frames - 1) // 4 + 1 + latent_height = height // 8 + latent_width = width // 8 + + prompt_embeds = self.text_encoder([prompt], self.device) + prompt_embeds = [u.to(self.dtype).to(self.device) for u in prompt_embeds] + if self.do_classifier_free_guidance: + negative_prompt_embeds = self.text_encoder([negative_prompt], self.device) + negative_prompt_embeds = [u.to(self.dtype).to(self.device) for u in negative_prompt_embeds] + + + + self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift) + init_timesteps = self.scheduler.timesteps + fps_embeds = [fps] * prompt_embeds[0].shape[0] + fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] + transformer_dtype = self.dtype + # with torch.cuda.amp.autocast(dtype=self.dtype), torch.no_grad(): + if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: + # short video generation + latent_shape = [16, latent_length, latent_height, latent_width] + latents = self.prepare_latents( + latent_shape, dtype=torch.float32, device=self.device, generator=generator + ) + latents = [latents] + if prefix_video is not None: + latents[0][:, :predix_video_latent_length] = 
prefix_video[0].to(torch.float32) + base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + latent_length, init_timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size + ) + sample_schedulers = [] + for _ in range(latent_length): + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, shift=1, use_dynamic_shifting=False + ) + sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift) + sample_schedulers.append(sample_scheduler) + sample_schedulers_counter = [0] * latent_length + + if callback != None: + callback(-1, None, True) + + freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False) + for i, timestep_i in enumerate(tqdm(step_matrix)): + update_mask_i = step_update_mask[i] + valid_interval_i = valid_interval[i] + valid_interval_start, valid_interval_end = valid_interval_i + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] + if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + timestep_for_noised_condition = addnoise_condition + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor) + + torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length]) + * noise_factor + ) + timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + kwrags = { + "x" : torch.stack([latent_model_input[0]]), + "t" : timestep, + "freqs" :freqs, + "fps" : fps_embeds, + # "causal_block_size" : causal_block_size, + "callback" : callback, + "pipeline" : self + } + kwrags.update(i2v_extra_kwrags) + + + if not self.do_classifier_free_guidance: + noise_pred = self.model( + context=prompt_embeds, + **kwrags, + )[0] + if self._interrupt: + return None + noise_pred= noise_pred.to(torch.float32) + else: + if joint_pass: + noise_pred_cond, noise_pred_uncond = self.model( + context=prompt_embeds, + context2=negative_prompt_embeds, + **kwrags, + ) + if self._interrupt: + return None + else: + noise_pred_cond = self.model( + context=prompt_embeds, + **kwrags, + )[0] + if self._interrupt: + return None + noise_pred_uncond = self.model( + context=negative_prompt_embeds, + **kwrags, + )[0] + if self._interrupt: + return None + noise_pred_cond= noise_pred_cond.to(torch.float32) + noise_pred_uncond= noise_pred_uncond.to(torch.float32) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + del noise_pred_cond, noise_pred_uncond + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[0][:, idx] = sample_schedulers[idx].step( + noise_pred[:, idx - valid_interval_start], + timestep_i[idx], + latents[0][:, idx], + return_dict=False, + generator=generator, + )[0] + sample_schedulers_counter[idx] += 1 + if callback is not None: + callback(i, latents[0], False) + + x0 = latents[0].unsqueeze(0) + videos = self.vae.decode(x0, tile_size= VAE_tile_size) + videos = (videos / 2 + 0.5).clamp(0, 1) + videos = [video for video in videos] + videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] + videos = [video.cpu().numpy().astype(np.uint8) for video in videos] + 
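+            # At this point the frames are plain uint8 arrays: vae.decode() returns videos in
+            # [-1, 1], which are rescaled to [0, 1], scaled to [0, 255], permuted from
+            # (c, f, h, w) to (f, h, w, c) and moved to numpy before being returned.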
return videos + else: + # long video generation + base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length + overlap_history_frames = (overlap_history - 1) // 4 + 1 + n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 + print(f"n_iter:{n_iter}") + output_video = None + for i in range(n_iter): + if output_video is not None: # i !=0 + prefix_video = output_video[:, -overlap_history:].to(self.device) + prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] + if prefix_video[0].shape[1] % causal_block_size != 0: + truncate_len = prefix_video[0].shape[1] % causal_block_size + print("the length of prefix video is truncated for the casual block size alignment.") + prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] + predix_video_latent_length = prefix_video[0].shape[1] + finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames + left_frame_num = latent_length - finished_frame_num + base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) + else: # i == 0 + base_num_frames_iter = base_num_frames + latent_shape = [16, base_num_frames_iter, latent_height, latent_width] + latents = self.prepare_latents( + latent_shape, dtype=torch.float32, device=self.device, generator=generator + ) + latents = [latents] + if prefix_video is not None: + latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32) + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + base_num_frames_iter, + init_timesteps, + base_num_frames_iter, + ar_step, + predix_video_latent_length, + causal_block_size, + ) + sample_schedulers = [] + for _ in range(base_num_frames_iter): + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, shift=1, use_dynamic_shifting=False + ) + sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift) + sample_schedulers.append(sample_scheduler) + sample_schedulers_counter = [0] * base_num_frames_iter + if callback != None: + callback(-1, None, True) + + freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False) + for i, timestep_i in enumerate(tqdm(step_matrix)): + update_mask_i = step_update_mask[i] + valid_interval_i = valid_interval[i] + valid_interval_start, valid_interval_end = valid_interval_i + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] + if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + timestep_for_noised_condition = addnoise_condition + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] + * (1.0 - noise_factor) + + torch.randn_like( + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] + ) + * noise_factor + ) + timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + kwrags = { + "x" : torch.stack([latent_model_input[0]]), + "t" : timestep, + "freqs" :freqs, + "fps" : fps_embeds, + "causal_block_size" : causal_block_size, + "causal_attention" : causal_attention, + "callback" : callback, + "pipeline" : self + } + kwrags.update(i2v_extra_kwrags) + + if not self.do_classifier_free_guidance: + noise_pred = self.model( + 
context=prompt_embeds, + **kwrags, + )[0] + if self._interrupt: + return None + noise_pred= noise_pred.to(torch.float32) + else: + if joint_pass: + noise_pred_cond, noise_pred_uncond = self.model( + context=prompt_embeds, + context2=negative_prompt_embeds, + **kwrags, + ) + if self._interrupt: + return None + else: + noise_pred_cond = self.model( + context=prompt_embeds, + **kwrags, + )[0] + if self._interrupt: + return None + noise_pred_uncond = self.model( + context=negative_prompt_embeds, + )[0] + if self._interrupt: + return None + noise_pred_cond= noise_pred_cond.to(torch.float32) + noise_pred_uncond= noise_pred_uncond.to(torch.float32) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + del noise_pred_cond, noise_pred_uncond + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[0][:, idx] = sample_schedulers[idx].step( + noise_pred[:, idx - valid_interval_start], + timestep_i[idx], + latents[0][:, idx], + return_dict=False, + generator=generator, + )[0] + sample_schedulers_counter[idx] += 1 + if callback is not None: + callback(i, latents[0].squeeze(0), False) + + x0 = latents[0].unsqueeze(0) + videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]] + if output_video is None: + output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w + else: + output_video = torch.cat( + [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1 + ) # c, f, h, w + return output_video diff --git a/wan/diffusion_forcing.py b/models/wan/diffusion_forcing.py similarity index 89% rename from wan/diffusion_forcing.py rename to models/wan/diffusion_forcing.py index 85343fd..5d5e42f 100644 --- a/wan/diffusion_forcing.py +++ b/models/wan/diffusion_forcing.py @@ -14,12 +14,12 @@ from tqdm import tqdm from .modules.model import WanModel from .modules.t5 import T5EncoderModel from .modules.vae import WanVAE -from wan.modules.posemb_layers import get_rotary_pos_embed -from wan.utils.utils import calculate_new_dimensions -from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, +from .modules.posemb_layers import get_rotary_pos_embed +from shared.utils.utils import calculate_new_dimensions +from shared.utils.fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler -from wan.utils.utils import update_loras_slists +from shared.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from shared.utils.loras_mutipliers import update_loras_slists class DTT2V: @@ -199,7 +199,6 @@ class DTT2V: self, input_prompt: Union[str, List[str]], n_prompt: Union[str, List[str]] = "", - image_start: PipelineImageInput = None, input_video = None, height: int = 480, width: int = 832, @@ -211,7 +210,7 @@ class DTT2V: guide_scale: float = 5.0, seed: float = 0.0, overlap_noise: int = 0, - ar_step: int = 5, + model_mode: int = 5, causal_block_size: int = 5, causal_attention: bool = True, fps: int = 24, @@ -231,7 +230,7 @@ class DTT2V: if frame_num > 1: frame_num = max(17, frame_num) # must match causal_block_size for value of 5 frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 ) - + ar_step = model_mode if ar_step == 0: causal_block_size = 1 causal_attention = False @@ -242,11 +241,6 @@ class DTT2V: if input_video != None: _ , _ , height, width = input_video.shape - elif image_start != None: - image_start = image_start - frame_width, frame_height = image_start.size - height, width = calculate_new_dimensions(height, width, 
frame_height, frame_width, fit_into_canvas) - image_start = np.array(image_start.resize((width, height))).transpose(2, 0, 1) latent_length = (frame_num - 1) // 4 + 1 @@ -276,18 +270,8 @@ class DTT2V: output_video = input_video - if image_start is not None or output_video is not None: # i !=0 - if output_video is not None: - prefix_video = output_video.to(self.device) - else: - causal_block_size = 1 - causal_attention = False - ar_step = 0 - prefix_video = image_start - prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1) - if prefix_video.dtype == torch.uint8: - prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0 - prefix_video = prefix_video.to(self.device) + if output_video is not None: # i !=0 + prefix_video = output_video.to(self.device) prefix_video = self.vae.encode(prefix_video.unsqueeze(0))[0] # [(c, f, h, w)] predix_video_latent_length = prefix_video.shape[1] truncate_len = predix_video_latent_length % causal_block_size @@ -329,21 +313,24 @@ class DTT2V: if callback != None: update_loras_slists(self.model, loras_slists, updated_num_steps) callback(-1, None, True, override_num_inference_steps = updated_num_steps) - if self.model.enable_cache == "tea": - x_count = 2 if self.do_classifier_free_guidance else 1 - self.model.previous_residual = [None] * x_count - time_steps_comb = [] - self.model.num_steps = updated_num_steps - for i, timestep_i in enumerate(step_matrix): - valid_interval_start, valid_interval_end = valid_interval[i] - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - if overlap_noise > 0 and valid_interval_start < predix_video_latent_length: - timestep[:, valid_interval_start:predix_video_latent_length] = overlap_noise - time_steps_comb.append(timestep) - self.model.compute_teacache_threshold(self.model.cache_start_step, time_steps_comb, self.model.cache_multiplier) - del time_steps_comb - else: - self.model.enable_cache = None + skip_steps_cache = self.model.cache + if skip_steps_cache != None: + skip_steps_cache.num_steps = updated_num_steps + if skip_steps_cache.cache_type == "tea": + x_count = 2 if self.do_classifier_free_guidance else 1 + skip_steps_cache.previous_residual = [None] * x_count + time_steps_comb = [] + skip_steps_cache.steps = updated_num_steps + for i, timestep_i in enumerate(step_matrix): + valid_interval_start, valid_interval_end = valid_interval[i] + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + if overlap_noise > 0 and valid_interval_start < predix_video_latent_length: + timestep[:, valid_interval_start:predix_video_latent_length] = overlap_noise + time_steps_comb.append(timestep) + self.model.compute_teacache_threshold(skip_steps_cache.start_step, time_steps_comb, skip_steps_cache.multiplier) + del time_steps_comb + else: + self.model.cache = None from mmgp import offload freqs = get_rotary_pos_embed(latents.shape[2 :], enable_RIFLEx= False) kwrags = { @@ -446,3 +433,4 @@ class DTT2V: videos = videos[0] # return only first video return videos + diff --git a/ltx_video/utils/__init__.py b/models/wan/distributed/__init__.py similarity index 100% rename from ltx_video/utils/__init__.py rename to models/wan/distributed/__init__.py diff --git a/wan/distributed/fsdp.py b/models/wan/distributed/fsdp.py similarity index 100% rename from wan/distributed/fsdp.py rename to models/wan/distributed/fsdp.py diff --git a/wan/distributed/xdit_context_parallel.py b/models/wan/distributed/xdit_context_parallel.py similarity index 100% rename from 
wan/distributed/xdit_context_parallel.py rename to models/wan/distributed/xdit_context_parallel.py diff --git a/wan/fantasytalking/infer.py b/models/wan/fantasytalking/infer.py similarity index 91% rename from wan/fantasytalking/infer.py rename to models/wan/fantasytalking/infer.py index 80d1945..d96bea0 100644 --- a/wan/fantasytalking/infer.py +++ b/models/wan/fantasytalking/infer.py @@ -6,7 +6,7 @@ from .model import FantasyTalkingAudioConditionModel from .utils import get_audio_features import gc, torch -def parse_audio(audio_path, num_frames, fps = 23, device = "cuda"): +def parse_audio(audio_path, start_frame, num_frames, fps = 23, device = "cuda"): fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to(device) from mmgp import offload from accelerate import init_empty_weights @@ -24,7 +24,7 @@ def parse_audio(audio_path, num_frames, fps = 23, device = "cuda"): wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir, device_map="cpu").eval().requires_grad_(False) wav2vec.to(device) proj_model.to(device) - audio_wav2vec_fea = get_audio_features( wav2vec, wav2vec_processor, audio_path, fps, num_frames ) + audio_wav2vec_fea = get_audio_features( wav2vec, wav2vec_processor, audio_path, fps, start_frame, num_frames) audio_proj_fea = proj_model(audio_wav2vec_fea) pos_idx_ranges = fantasytalking.split_audio_sequence( audio_proj_fea.size(1), num_frames=num_frames ) diff --git a/wan/fantasytalking/model.py b/models/wan/fantasytalking/model.py similarity index 99% rename from wan/fantasytalking/model.py rename to models/wan/fantasytalking/model.py index 5ec3655..d0eb74d 100644 --- a/wan/fantasytalking/model.py +++ b/models/wan/fantasytalking/model.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from wan.modules.attention import pay_attention +from shared.attention import pay_attention class AudioProjModel(nn.Module): diff --git a/wan/fantasytalking/utils.py b/models/wan/fantasytalking/utils.py similarity index 79% rename from wan/fantasytalking/utils.py rename to models/wan/fantasytalking/utils.py index e044934..51f6678 100644 --- a/wan/fantasytalking/utils.py +++ b/models/wan/fantasytalking/utils.py @@ -26,13 +26,18 @@ def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None): writer.close() -def get_audio_features(wav2vec, audio_processor, audio_path, fps, num_frames): +def get_audio_features(wav2vec, audio_processor, audio_path, fps, start_frame, num_frames): sr = 16000 - audio_input, sample_rate = librosa.load(audio_path, sr=sr) # 采样率为 16kHz + audio_input, sample_rate = librosa.load(audio_path, sr=sr) # 采样率为 16kHz start_time = 0 + if start_frame < 0: + pad = int(abs(start_frame)/ fps * sr) + audio_input = np.concatenate([np.zeros(pad), audio_input]) + end_frame = num_frames + else: + end_frame = start_frame + num_frames - start_time = 0 - # end_time = (0 + (num_frames - 1) * 1) / fps - end_time = num_frames / fps + start_time = start_frame / fps + end_time = end_frame / fps start_sample = int(start_time * sr) end_sample = int(end_time * sr) diff --git a/wan/modules/__init__.py b/models/wan/modules/__init__.py similarity index 81% rename from wan/modules/__init__.py rename to models/wan/modules/__init__.py index 38c29ce..56aea65 100644 --- a/wan/modules/__init__.py +++ b/models/wan/modules/__init__.py @@ -1,8 +1,9 @@ -from .attention import pay_attention +from shared.attention import pay_attention from .model import WanModel from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model from .tokenizers import 
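
`get_audio_features` now takes a `start_frame` so a sliding window can grab the matching slice of the 16 kHz audio track, and a negative `start_frame` is handled by left-padding the audio with silence. A simplified sketch of that frame-to-sample mapping (my reading of the intent, with illustrative names):

```python
import numpy as np

def audio_window(audio: np.ndarray, start_frame: int, num_frames: int,
                 fps: float = 23, sr: int = 16000) -> np.ndarray:
    """Map a window of video frames onto a window of audio samples."""
    if start_frame < 0:
        # The window starts before the audio does: prepend silence and start at 0.
        pad = int(abs(start_frame) / fps * sr)
        audio = np.concatenate([np.zeros(pad), audio])
        start_frame, end_frame = 0, num_frames
    else:
        end_frame = start_frame + num_frames
    start_sample, end_sample = int(start_frame / fps * sr), int(end_frame / fps * sr)
    return audio[start_sample:end_sample]
```
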
HuggingfaceTokenizer from .vae import WanVAE +from .vae2_2 import Wan2_2_VAE __all__ = [ 'WanVAE', diff --git a/wan/modules/clip.py b/models/wan/modules/clip.py similarity index 99% rename from wan/modules/clip.py rename to models/wan/modules/clip.py index da91a00..fc29893 100644 --- a/wan/modules/clip.py +++ b/models/wan/modules/clip.py @@ -8,7 +8,7 @@ import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as T -from .attention import pay_attention +from shared.attention import pay_attention from .tokenizers import HuggingfaceTokenizer from .xlm_roberta import XLMRoberta diff --git a/wan/modules/model.py b/models/wan/modules/model.py similarity index 84% rename from wan/modules/model.py rename to models/wan/modules/model.py index 7d6357d..95faa4d 100644 --- a/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -12,13 +12,29 @@ from diffusers.models.modeling_utils import ModelMixin import numpy as np from typing import Union,Optional from mmgp import offload -from .attention import pay_attention +from shared.attention import pay_attention from torch.backends.cuda import sdp_kernel -from wan.multitalk.multitalk_utils import get_attn_map_with_target +from ..multitalk.multitalk_utils import get_attn_map_with_target __all__ = ['WanModel'] +def get_cache(cache_name): + all_cache = offload.shared_state.get("_cache", None) + if all_cache is None: + all_cache = {} + offload.shared_state["_cache"]= all_cache + cache = offload.shared_state.get(cache_name, None) + if cache is None: + cache = {} + offload.shared_state[cache_name] = cache + return cache + +def clear_caches(): + all_cache = offload.shared_state.get("_cache", None) + if all_cache is not None: + all_cache.clear() + def sinusoidal_embedding_1d(dim, position): # preprocess assert dim % 2 == 0 @@ -93,6 +109,32 @@ def relative_l1_distance(last_tensor, current_tensor): relative_l1_distance = l1_distance / norm return relative_l1_distance.to(torch.float32) +class LoRALinearLayer(nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + rank: int = 128, + dtype: Optional[torch.dtype] = torch.float32, + ): + super().__init__() + self.down = nn.Linear(in_features, rank, bias=False, dtype=dtype) + self.up = nn.Linear(rank, out_features, bias=False, dtype=dtype) + self.rank = rank + self.out_features = out_features + self.in_features = in_features + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + return up_hidden_states.to(orig_dtype) + class WanRMSNorm(nn.Module): def __init__(self, dim, eps=1e-5): @@ -150,7 +192,7 @@ class WanLayerNorm(nn.LayerNorm): return x # return super().forward(x).type_as(x) -from wan.modules.posemb_layers import apply_rotary_emb +from .posemb_layers import apply_rotary_emb class WanSelfAttention(nn.Module): @@ -244,7 +286,7 @@ class WanSelfAttention(nn.Module): else: return x, None - def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks = None, ref_images_count = 0): + def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks = None, ref_images_count = 0, standin_phase =-1): r""" Args: x(Tensor): Shape [B, L, num_heads, C / num_heads] @@ -257,19 +299,28 @@ class WanSelfAttention(nn.Module): b, s, n, d = *x.shape[:2], self.num_heads, 
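
`LoRALinearLayer` initialises its up-projection to zero, so attaching the adapters (as StandIn does for the q/k/v projections below) leaves the pretrained weights numerically unchanged until LoRA weights are actually loaded. A quick self-contained check of that property (a sketch, not the class above):

```python
import torch
import torch.nn as nn

class TinyLoRA(nn.Module):
    def __init__(self, dim: int, rank: int = 128):
        super().__init__()
        self.down = nn.Linear(dim, rank, bias=False)
        self.up = nn.Linear(rank, dim, bias=False)
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)      # zero up-projection => zero contribution at init

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.up(self.down(x))

base, lora = nn.Linear(64, 64), TinyLoRA(64)
x = torch.randn(2, 64)
print(torch.allclose(base(x) + lora(x), base(x)))   # True: adapter is a no-op until trained
```
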
self.head_dim # query, key, value function - q = self.q(x) + q, k, v = self.q(x), self.k(x), self.v(x) + if standin_phase == 1: + q += self.q_loras(x) + k += self.k_loras(x) + v += self.v_loras(x) self.norm_q(q) - q = q.view(b, s, n, d) - k = self.k(x) self.norm_k(k) - k = k.view(b, s, n, d) - v = self.v(x).view(b, s, n, d) + q,k,v = q.view(b, s, n, d), k.view(b, s, n, d), v.view(b, s, n, d) del x qklist = [q,k] del q,k q,k = apply_rotary_emb(qklist, freqs, head_first=False) + if standin_phase >= 1: + standin_cache = get_cache("standin") + if standin_phase == 1: + standin_cache[self.block_no] = (k,v) + elif standin_phase == 2: + k_ip, v_ip = standin_cache[self.block_no] + k, v = torch.concat([k, k_ip], dim=1), torch.concat([v, v_ip], dim=1) + del k_ip, v_ip if ref_target_masks != None: x_ref_attn_map = get_attn_map_with_target(q, k , grid_sizes, ref_target_masks=ref_target_masks, ref_images_count = ref_images_count) else: @@ -289,6 +340,7 @@ class WanSelfAttention(nn.Module): x = pay_attention( qkv_list, window_size=self.window_size) + else: with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): x = ( @@ -429,7 +481,7 @@ class WanAttentionBlock(nn.Module): self.block_id = block_id if output_dim > 0: - from wan.multitalk.attention import SingleStreamMutiAttention + from ..multitalk.attention import SingleStreamMutiAttention # init audio module self.audio_cross_attn = SingleStreamMutiAttention( dim=dim, @@ -461,6 +513,7 @@ class WanAttentionBlock(nn.Module): multitalk_audio=None, multitalk_masks=None, ref_images_count=0, + standin_phase=-1, ): r""" Args: @@ -504,7 +557,7 @@ class WanAttentionBlock(nn.Module): xlist = [x_mod.to(attention_dtype)] del x_mod - y, x_ref_attn_map = self.self_attn( xlist, grid_sizes, freqs, block_mask = block_mask, ref_target_masks = multitalk_masks, ref_images_count = ref_images_count) + y, x_ref_attn_map = self.self_attn( xlist, grid_sizes, freqs, block_mask = block_mask, ref_target_masks = multitalk_masks, ref_images_count = ref_images_count, standin_phase= standin_phase, ) y = y.to(dtype) if cam_emb != None: y = self.projector(y) @@ -513,11 +566,13 @@ class WanAttentionBlock(nn.Module): x.addcmul_(y, e[2]) x, y = restore_latent_shape(x), restore_latent_shape(y) del y - y = self.norm3(x) - y = y.to(attention_dtype) - ylist= [y] - del y - x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype) + + if context is not None: + y = self.norm3(x) + y = y.to(attention_dtype) + ylist= [y] + del y + x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype) if multitalk_audio != None: # cross attn of multitalk audio @@ -552,7 +607,7 @@ class WanAttentionBlock(nn.Module): y_shape = y.shape y = y.view(-1, y_shape[-1]) - chunk_size = int(y_shape[1]/2.7) + chunk_size = int(y.shape[0]/2.7) chunks =torch.split(y, chunk_size) for y_chunk in chunks: mlp_chunk = ffn(y_chunk) @@ -766,7 +821,22 @@ class WanModel(ModelMixin, ConfigMixin): first = next(iter(sd), None) if first == None: return sd - + + new_sd = {} + + # for k,v in sd.items(): + # if k.endswith("modulation.diff"): + # pass + # else: + # new_sd[ k] = v + # sd = new_sd + + # if first.startswith("blocks."): + # new_sd = {} + # for k,v in sd.items(): + # new_sd["diffusion_model." 
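
The `standin_phase` branches above implement the identity transfer in two passes over the self-attention blocks: phase 1 runs the reference image through the blocks and stores its K/V per block, phase 2 lets the video tokens attend over their own K/V concatenated with the cached reference K/V. A stripped-down sketch (the module-level dict stands in for the `get_cache("standin")` helper):

```python
from typing import Dict, Tuple
import torch

_standin_kv: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}

def standin_kv(block_no: int, k: torch.Tensor, v: torch.Tensor, phase: int):
    if phase == 1:                       # reference pass: remember this block's K/V
        _standin_kv[block_no] = (k, v)
    elif phase == 2:                     # video pass: attend over video + reference tokens
        k_ref, v_ref = _standin_kv[block_no]
        k, v = torch.cat([k, k_ref], dim=1), torch.cat([v, v_ref], dim=1)
    return k, v
```
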
+ k] = v + # sd = new_sd + if first.startswith("lora_unet_"): new_sd = {} print("Converting Lora Safetensors format to Lora Diffusers format") @@ -789,7 +859,7 @@ class WanModel(ModelMixin, ConfigMixin): sd = new_sd from wgp import test_class_i2v - if not test_class_i2v(model_type): + if not test_class_i2v(model_type) or model_type in ["i2v_2_2"]: new_sd = {} # convert loras for i2v to t2v for k,v in sd.items(): @@ -838,11 +908,11 @@ class WanModel(ModelMixin, ConfigMixin): vae_scale=4, # vae timedownsample scale norm_input_visual=True, norm_output_audio=True, + standin= False, ): super().__init__() - assert model_type in ['t2v', 'i2v'] self.model_type = model_type self.patch_size = patch_size @@ -889,7 +959,7 @@ class WanModel(ModelMixin, ConfigMixin): # blocks if vace_layers == None: - cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' + cross_attn_type = 't2v_cross_attn' if model_type in ['t2v','i2v2_2', 'ti2v2_2'] else 'i2v_cross_attn' self.blocks = nn.ModuleList([ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, block_no =i, output_dim=multitalk_output_dim, norm_input_visual=norm_input_visual) @@ -958,10 +1028,15 @@ class WanModel(ModelMixin, ConfigMixin): block.projector.bias = nn.Parameter(torch.zeros(dim)) if fantasytalking_dim > 0: - from wan.fantasytalking.model import WanCrossAttentionProcessor + from ..fantasytalking.model import WanCrossAttentionProcessor for block in self.blocks: block.cross_attn.processor = WanCrossAttentionProcessor(fantasytalking_dim, dim) + if standin: + for block in self.blocks: + block.self_attn.q_loras = LoRALinearLayer(dim, dim, rank=128) + block.self_attn.k_loras = LoRALinearLayer(dim, dim, rank=128) + block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128) def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32): layer_list = [self.head, self.head.head, self.patch_embedding] @@ -1005,6 +1080,7 @@ class WanModel(ModelMixin, ConfigMixin): self._lock_dtype = dtype def compute_magcache_threshold(self, start_step, timesteps = None, speed_factor =0): + skips_step_cache = self.cache def nearest_interp(src_array, target_length): src_length = len(src_array) if target_length == 1: return np.array([src_array[-1]]) @@ -1012,13 +1088,14 @@ class WanModel(ModelMixin, ConfigMixin): mapped_indices = np.round(np.arange(target_length) * scale).astype(int) return src_array[mapped_indices] num_inference_steps = len(timesteps) - if len(self.def_mag_ratios) != num_inference_steps*2: - mag_ratio_con = nearest_interp(self.def_mag_ratios[0::2], num_inference_steps) - mag_ratio_ucon = nearest_interp(self.def_mag_ratios[1::2], num_inference_steps) + def_mag_ratios = np.array([1.0]*2+ skips_step_cache.def_mag_ratios) + if len(def_mag_ratios) != num_inference_steps*2: + mag_ratio_con = nearest_interp(def_mag_ratios[0::2], num_inference_steps) + mag_ratio_ucon = nearest_interp(def_mag_ratios[1::2], num_inference_steps) interpolated_mag_ratios = np.concatenate([mag_ratio_con.reshape(-1, 1), mag_ratio_ucon.reshape(-1, 1)], axis=1).reshape(-1) - self.mag_ratios = interpolated_mag_ratios + skips_step_cache.mag_ratios = interpolated_mag_ratios else: - self.mag_ratios = self.def_mag_ratios + skips_step_cache.mag_ratios = def_mag_ratios best_deltas = None @@ -1039,12 +1116,12 @@ class WanModel(ModelMixin, ConfigMixin): else: x_should_calc = [] for cur_x_id in range(x_id_max): - cur_mag_ratio = self.mag_ratios[i * 2 + cur_x_id] # conditional and unconditional in one list + 
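
`compute_magcache_threshold` searches for the smallest threshold that hits the requested speed-up by simulating which steps MagCache would skip: the per-step magnitude ratios are multiplied into an accumulated ratio, the drift from 1 is accumulated as an error, and steps are skipped while that error stays small (it also tracks how many steps in a row were skipped). A hedged sketch of the per-step decision with illustrative field names:

```python
def magcache_step(state: dict, step_mag_ratio: float, threshold: float, max_skips: int = 2) -> bool:
    """Return True if this step can be skipped and the cached residual reused."""
    state["accumulated_ratio"] *= step_mag_ratio          # ratio vs. the last computed step
    state["accumulated_steps"] += 1
    state["accumulated_err"] += abs(1 - state["accumulated_ratio"])
    if state["accumulated_err"] < threshold and state["accumulated_steps"] <= max_skips:
        return True
    # too much drift: recompute this step and reset the accumulators
    state.update(accumulated_ratio=1.0, accumulated_steps=0, accumulated_err=0.0)
    return False

state = {"accumulated_ratio": 1.0, "accumulated_steps": 0, "accumulated_err": 0.0}
print(magcache_step(state, 0.98, threshold=0.06))         # True: 2% drift is under the threshold
```
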
cur_mag_ratio = skips_step_cache.mag_ratios[i * 2 + cur_x_id] # conditional and unconditional in one list accumulated_ratio[cur_x_id] *= cur_mag_ratio # magnitude ratio between current step and the cached step accumulated_steps[cur_x_id] += 1 # skip steps plus 1 cur_skip_err = np.abs(1-accumulated_ratio[cur_x_id]) # skip error of current steps accumulated_err[cur_x_id] += cur_skip_err # accumulated error of multiple steps - if accumulated_err[cur_x_id] best_diff: break threshold += 0.01 - self.magcache_thresh = best_threshold + skips_step_cache.magcache_thresh = best_threshold print(f"Mag Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}") return best_threshold def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0): + skips_step_cache = self.cache modulation_dtype = self.time_projection[1].weight.dtype - rescale_func = np.poly1d(self.coefficients) + rescale_func = np.poly1d(skips_step_cache.coefficients) e_list = [] for t in timesteps: t = torch.stack([t]) @@ -1107,7 +1185,7 @@ class WanModel(ModelMixin, ConfigMixin): elif diff > best_diff: break threshold += 0.01 - self.rel_l1_thresh = best_threshold + skips_step_cache.rel_l1_thresh = best_threshold print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}") # print(f"deltas:{best_deltas}") return best_threshold @@ -1138,8 +1216,9 @@ class WanModel(ModelMixin, ConfigMixin): audio_scale=None, multitalk_audio = None, multitalk_masks = None, - ref_images_count = 0, - + ref_images_count = 0, + standin_freqs = None, + standin_ref = None, ): # patch_dtype = self.patch_embedding.weight.dtype modulation_dtype = self.time_projection[1].weight.dtype @@ -1203,6 +1282,18 @@ class WanModel(ModelMixin, ConfigMixin): offload.shared_state["embed_sizes"] = grid_sizes offload.shared_state["step_no"] = current_step offload.shared_state["max_steps"] = max_steps + if current_step == 0 and x_id == 0: clear_caches() + # arguments + + kwargs = dict( + grid_sizes=grid_sizes, + freqs=freqs, + cam_emb = cam_emb, + block_mask = block_mask, + audio_proj=audio_proj, + audio_context_lens=audio_context_lens, + ref_images_count=ref_images_count, + ) _flag_df = t.dim() == 2 @@ -1211,6 +1302,16 @@ class WanModel(ModelMixin, ConfigMixin): ) # b, dim e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype) + standin_x = None + if standin_ref is not None: + standin_cache_enabled = False + kwargs["standin_phase"] = 2 + if (current_step == 0 or not standin_cache_enabled) and x_id == 0: + standin_x = self.patch_embedding(standin_ref).to(modulation_dtype).flatten(2).transpose(1, 2) + standin_e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, torch.zeros_like(t)).to(modulation_dtype) ) + standin_e0 = self.time_projection(standin_e).unflatten(1, (6, self.dim)).to(e.dtype) + standin_e = standin_ref = None + if self.inject_sample_info and fps!=None: fps = torch.tensor(fps, dtype=torch.long, device=device) @@ -1237,8 +1338,9 @@ class WanModel(ModelMixin, ConfigMixin): if multitalk_audio != None: multitalk_audio_list = [] for audio in multitalk_audio: - audio = self.audio_proj(*audio) - audio = torch.concat(audio.split(1), dim=2).to(context[0]) + if audio is not None: + audio = self.audio_proj(*audio) + audio = torch.concat(audio.split(1), dim=2).to(context[0]) multitalk_audio_list.append(audio) audio = None else: @@ -1254,17 
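
`compute_teacache_threshold` plays the same game for TeaCache: it collects the timestep-modulation tensors, measures their step-to-step relative L1 change, rescales that with the fitted polynomial (`np.poly1d(skips_step_cache.coefficients)`), and accumulates the result until it crosses `rel_l1_thresh`. A hedged sketch of the per-step test (names are illustrative):

```python
import numpy as np
import torch

def teacache_step(e_prev: torch.Tensor, e_cur: torch.Tensor, state: dict,
                  coefficients, rel_l1_thresh: float) -> bool:
    """Return True if the transformer pass can be skipped (cached residual reused)."""
    rel_l1 = ((e_cur - e_prev).abs().mean() / e_prev.abs().mean()).item()
    state["accumulated_rel_l1"] += float(np.poly1d(coefficients)(rel_l1))
    if state["accumulated_rel_l1"] < rel_l1_thresh:
        return True
    state["accumulated_rel_l1"] = 0.0     # recompute and restart the accumulation
    return False
```
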
+1356,6 @@ class WanModel(ModelMixin, ConfigMixin): else: audio_scale_list = [None] * len(x_list) - # arguments - - kwargs = dict( - grid_sizes=grid_sizes, - freqs=freqs, - cam_emb = cam_emb, - block_mask = block_mask, - audio_proj=audio_proj, - audio_context_lens=audio_context_lens, - ref_images_count=ref_images_count, - ) if vace_context == None: hints_list = [None ] *len(x_list) @@ -1277,72 +1368,73 @@ class WanModel(ModelMixin, ConfigMixin): del c should_calc = True x_should_calc = None - if self.enable_cache != None: - if self.enable_cache == "mag": - if current_step <= self.cache_start_step: + skips_steps_cache = self.cache + if skips_steps_cache != None: + if skips_steps_cache.cache_type == "mag": + if current_step <= skips_steps_cache.start_step: should_calc = True - elif self.one_for_all and x_id != 0: # not joint pass, not main pas, one for all + elif skips_steps_cache.one_for_all and x_id != 0: # not joint pass, not main pas, one for all assert len(x_list) == 1 - should_calc = self.should_calc + should_calc = skips_steps_cache.should_calc else: x_should_calc = [] - for i in range(1 if self.one_for_all else len(x_list)): + for i in range(1 if skips_steps_cache.one_for_all else len(x_list)): cur_x_id = i if joint_pass else x_id - cur_mag_ratio = self.mag_ratios[current_step * 2 + cur_x_id] # conditional and unconditional in one list - self.accumulated_ratio[cur_x_id] *= cur_mag_ratio # magnitude ratio between current step and the cached step - self.accumulated_steps[cur_x_id] += 1 # skip steps plus 1 - cur_skip_err = np.abs(1-self.accumulated_ratio[cur_x_id]) # skip error of current steps - self.accumulated_err[cur_x_id] += cur_skip_err # accumulated error of multiple steps - if self.accumulated_err[cur_x_id] 0: - return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] + return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] else: - return [ self.model.encode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] + return [ self.model.encode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] def decode(self, zs, tile_size, any_end_frame = False): + scale = [u.to(device = self.device) for u in self.scale] if tile_size > 0: - return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] + return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] else: - return [ self.model.decode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] + return [ self.model.decode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] diff --git a/models/wan/modules/vae2_2.py b/models/wan/modules/vae2_2.py new file mode 100644 index 0000000..c1a88f5 --- /dev/null +++ b/models/wan/modules/vae2_2.py @@ -0,0 +1,1211 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
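
Both cache flavours now live on a single `model.cache` object instead of loose attributes on the model. An illustrative container limited to the fields this diff references (defaults shown here are assumptions, not the repo's values):

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional

@dataclass
class SkipStepsCache:
    cache_type: str = "tea"                  # "tea" or "mag"
    start_step: int = 0
    num_steps: int = 0
    multiplier: float = 0.0                  # target speed-up used during calibration
    one_for_all: bool = False                # reuse the cond decision for the uncond pass
    should_calc: bool = True
    previous_residual: List[Any] = field(default_factory=list)
    coefficients: Optional[list] = None      # TeaCache polynomial -> rel_l1_thresh
    def_mag_ratios: Optional[list] = None    # MagCache reference ratios -> magcache_thresh
```
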
+import logging + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +__all__ = [ + "Wan2_2_VAE", +] + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + 2 * self.padding[0], + 0, + ) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + cache_x = None + x = F.pad(x, padding) + try: + out = super().forward(x) + return out + except RuntimeError as e: + if "miopenStatus" in str(e): + print("⚠️ MIOpen fallback: AMD gets upset when trying to work with large areas, and so CPU will be " + "used for this decoding (which is very slow). Consider using tiled VAE Decoding.") + x_cpu = x.float().cpu() + weight_cpu = self.weight.float().cpu() + bias_cpu = self.bias.float().cpu() if self.bias is not None else None + print(f"[Fallback] x shape: {x_cpu.shape}, weight shape: {weight_cpu.shape}") + out = F.conv3d(x_cpu, weight_cpu, bias_cpu, + self.stride, (0, 0, 0), # avoid double padding here + self.dilation, self.groups) + out = out.to(x.device) + if x.dtype in (torch.float16, torch.bfloat16): + out = out.half() + if x.dtype != out.dtype: + out = out.to(x.dtype) + return out + raise +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return (F.normalize(x, dim=(1 if self.channel_first else -1)) * + self.scale * self.gamma + self.bias) + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. 
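
`CausalConv3d` above turns a regular `nn.Conv3d` padding into a causal one: height and width are padded symmetrically, but all temporal padding goes to the left, so frame `t` never convolves with future frames (and a small cache of trailing frames lets chunks be processed sequentially). A minimal sketch of the padding itself:

```python
import torch
import torch.nn.functional as F

def causal_pad_3d(x: torch.Tensor, pad_t: int, pad_h: int, pad_w: int) -> torch.Tensor:
    """Pad like CausalConv3d: symmetric in H/W, all temporal padding on the past side."""
    # F.pad's 3-D padding order is (w_left, w_right, h_left, h_right, t_left, t_right)
    return F.pad(x, (pad_w, pad_w, pad_h, pad_h, 2 * pad_t, 0))

x = torch.randn(1, 3, 5, 8, 8)          # (b, c, t, h, w)
print(causal_pad_3d(x, pad_t=1, pad_h=1, pad_w=1).shape)   # torch.Size([1, 3, 7, 10, 10])
```
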
+ """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ( + "none", + "upsample2d", + "upsample3d", + "downsample2d", + "downsample3d", + ) + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + # nn.Conv2d(dim, dim//2, 3, padding=1) + ) + self.time_conv = CausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + elif mode == "downsample2d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and + feat_cache[idx] != "Rep"): + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and + feat_cache[idx] == "Rep"): + cache_x = torch.cat( + [ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2, + ) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.resample(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight.detach().clone() + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5 + conv.weight = nn.Parameter(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data.detach().clone() + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight = nn.Parameter(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + 
# layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), + nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), + nn.SiLU(), + nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1), + ) + self.shortcut = ( + CausalConv3d(in_dim, out_dim, 1) + if in_dim != out_dim else nn.Identity()) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.norm(x) + # compute query, key, value + q, k, v = ( + self.to_qkv(x).reshape(b * t, 1, c * 3, + -1).permute(0, 1, 3, + 2).contiguous().chunk(3, dim=-1)) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, "(b t) c h w-> b c t h w", t=t) + return x + identity + + +def patchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + r=patch_size, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size, + ) + return x + + +class AvgDown3D(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W 
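
`patchify`/`unpatchify` above fold a 2x2 spatial patch into the channel dimension and back, which is why the 2.2 VAE's first convolution takes 12 input channels for RGB video. A quick round-trip check of the rearrange patterns:

```python
import torch
from einops import rearrange

x = torch.randn(1, 3, 9, 64, 64)                       # (b, c, f, h, w)
p = rearrange(x, "b c f (h q) (w r) -> b (c r q) f h w", q=2, r=2)
print(p.shape)                                          # torch.Size([1, 12, 9, 32, 32])
back = rearrange(p, "b (c r q) f h w -> b c f (h q) (w r)", q=2, r=2)
print(torch.equal(back, x))                             # True: lossless round trip
```
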
// self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1:, :, :] + return x + + +class Down_ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + out_dim, + dropout, + mult, + temperal_downsample=False, + down_flag=False): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + downsamples = [] + for _ in range(mult): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + downsamples.append(Resample(out_dim, mode=mode)) + + self.downsamples = nn.Sequential(*downsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for module in self.downsamples: + x = module(x, feat_cache, feat_idx) + + return x + self.avg_shortcut(x_copy) + + +class Up_ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + out_dim, + dropout, + mult, + temperal_upsample=False, + up_flag=False): + super().__init__() + # Shortcut path with upsample + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2 if up_flag else 1, + ) + else: + self.avg_shortcut = None + + # Main path with residual blocks and upsample + upsamples = [] + for _ in range(mult): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final upsample block + if up_flag: + mode = "upsample3d" if temperal_upsample else "upsample2d" + upsamples.append(Resample(out_dim, mode=mode)) + + self.upsamples = nn.Sequential(*upsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + x_main = x.clone() + for module in self.upsamples: + x_main = module(x_main, feat_cache, feat_idx) + if self.avg_shortcut is not None: + x_shortcut = self.avg_shortcut(x, first_chunk) + return x_main + x_shortcut + else: + return x_main + + +class Encoder3d(nn.Module): + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 
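
`AvgDown3D` folds a `factor_t x factor_s x factor_s` block into the channel dimension and then averages groups of channels down to `out_channels`, while `DupUp3D` is its repeat-interleave inverse; the only constraint is that the folded channel count divides evenly. A small bookkeeping helper (illustrative, not the module's API):

```python
def avgdown_group_size(in_channels: int, out_channels: int, factor_t: int, factor_s: int = 1) -> int:
    factor = factor_t * factor_s * factor_s
    assert in_channels * factor % out_channels == 0, "folded channels must divide evenly"
    return in_channels * factor // out_channels

# e.g. a stage that halves T, H and W while keeping the channel count:
print(avgdown_group_size(in_channels=160, out_channels=160, factor_t=2, factor_s=2))  # 8
```
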
= CausalConv3d(12, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_down_flag = ( + temperal_downsample[i] + if i < len(temperal_downsample) else False) + downsamples.append( + Down_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks, + temperal_downsample=t_down_flag, + down_flag=i != len(dim_mult) - 1, + )) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1), + ) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + + return x + + +class Decoder3d(nn.Module): + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout), + ) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_up_flag = temperal_upsample[i] if i < len( + temperal_upsample) else False + upsamples.append( + Up_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks + 1, + temperal_upsample=t_up_flag, + up_flag=i != len(dim_mult) - 1, + )) + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, 12, 3, padding=1), + ) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + 
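
The encoder and decoder channel ladders above are both derived from `dim_mult`; with the values `WanVAE_` below constructs them with (`dim=160` for the encoder, `dec_dim=256` for the decoder) they work out as follows:

```python
dim, dec_dim, dim_mult = 160, 256, [1, 2, 4, 4]
enc_dims = [dim * u for u in [1] + dim_mult]
dec_dims = [dec_dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
print(enc_dims)   # [160, 160, 320, 640, 640]
print(dec_dims)   # [1024, 1024, 1024, 512, 256]
```
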
if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, first_chunk) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE_(nn.Module): + _offload_hooks = ['encode', 'decode'] + def __init__( + self, + dim=160, + dec_dim=256, + z_dim=16, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d( + dim, + z_dim * 2, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_downsample, + dropout, + ) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d( + dec_dim, + z_dim, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_upsample, + dropout, + ) + + def forward(self, x, scale=[0, 1]): + mu = self.encode(x, scale) + x_recon = self.decode(mu, scale) + return x_recon, mu + + def encode(self, x, scale = None, any_end_frame = False): + self.clear_cache() + x = patchify(x, patch_size=2) + ## cache + t = x.shape[2] + if any_end_frame: + iter_ = 2 + (t - 2) // 4 + else: + iter_ = 1 + (t - 1) // 4 + ## 对encode输入的x,按时间拆分为1、4、4、4.... 
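
The comment above (in Chinese) notes that the encoder input is split along the time axis into chunks of 1, 4, 4, 4, ... frames, matching the VAE's 4x temporal compression; each chunk is encoded with the rolling feature cache. A helper that makes the chunk boundaries explicit (a sketch, not part of the module):

```python
def temporal_chunks(t: int):
    """First frame alone, then groups of four, as the encode() loop above walks time."""
    chunks, i = [(0, 1)], 1
    while i < t:
        chunks.append((i, min(i + 4, t)))
        i += 4
    return chunks

print(temporal_chunks(17))   # [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)]
```
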
+ out_list = [] + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out_list.append(self.encoder( + x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx)) + elif any_end_frame and i== iter_ -1: + out_list.append(self.encoder( + x[:, :, -1:, :, :], + feat_cache= None, + feat_idx=self._enc_conv_idx)) + else: + out_list.append(self.encoder( + x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx)) + + self.clear_cache() + out = torch.cat(out_list, 2) + out_list = None + + mu, log_var = self.conv1(out).chunk(2, dim=1) + if scale != None: + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + return mu + + + def decode(self, z, scale=None, any_end_frame = False): + self.clear_cache() + # z: [b,c,t,h,w] + if scale != None: + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + out_list = [] + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out_list.append(self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + first_chunk = True) + ) + elif any_end_frame and i==iter_-1: + out_list.append(self.decoder( + x[:, :, -1:, :, :], + feat_cache=None , + feat_idx=self._conv_idx)) + else: + out_list.append(self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx)) + self.clear_cache() + out = torch.cat(out_list, 2) + out = unpatchify(out, patch_size=2) + return out + + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def spatial_tiled_decode(self, z, scale, tile_size, any_end_frame= False): + tile_sample_min_size = tile_size + tile_latent_min_size = int(tile_sample_min_size / 16) + tile_overlap_factor = 0.25 + + # z: [b,c,t,h,w] + + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + + + overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor)) #8 0.75 + blend_extent = int(tile_sample_min_size * tile_overlap_factor) #256 0.25 + row_limit = tile_sample_min_size - blend_extent + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
+ rows = [] + for i in range(0, z.shape[-2], overlap_size): + row = [] + for j in range(0, z.shape[-1], overlap_size): + tile = z[:, :, :, i: i + tile_latent_min_size, j: j + tile_latent_min_size] + decoded = self.decode(tile, any_end_frame= any_end_frame) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + return torch.cat(result_rows, dim=-2) + + + def spatial_tiled_encode(self, x, scale, tile_size, any_end_frame = False) : + tile_sample_min_size = tile_size + tile_latent_min_size = int(tile_sample_min_size / 16) + tile_overlap_factor = 0.25 + + overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor)) + blend_extent = int(tile_latent_min_size * tile_overlap_factor) + row_limit = tile_latent_min_size - blend_extent + + # Split video into tiles and encode them separately. + rows = [] + for i in range(0, x.shape[-2], overlap_size): + row = [] + for j in range(0, x.shape[-1], overlap_size): + tile = x[:, :, :, i: i + tile_sample_min_size, j: j + tile_sample_min_size] + tile = self.encode(tile, any_end_frame= any_end_frame) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + mu = torch.cat(result_rows, dim=-2) + + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + + return mu + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs): + # params + cfg = dict( + dim=dim, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, True], + dropout=0.0, + ) + cfg.update(**kwargs) + + # init model + with torch.device("meta"): + model = WanVAE_(**cfg) + + from mmgp import offload + # load checkpoint + logging.info(f"loading {pretrained_path}") + # model.load_state_dict( + # torch.load(pretrained_path, map_location=device), assign=True) + # offload.save_model(model, "Wan_vae_2_2.safetensors") + # model.to(torch.bfloat16) + # offload.save_model(model, "Wan_vae_2_2_bf16.safetensors") + 
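
`spatial_tiled_decode` above decodes overlapping latent tiles and linearly blends the seams: with `tile_size=256` the latent tiles are 16 latents wide on a stride of 12, and the decoded tiles overlap by 64 pixels that are faded into each other before cropping to `row_limit`. A 1-D sketch of that blend (the real `blend_v`/`blend_h` apply it along the H and W axes):

```python
import torch

def blend_1d(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Fade the leading edge of tile b into the trailing edge of tile a."""
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    b = b.clone()                       # the originals modify b in place; keep the sketch pure
    for x in range(blend_extent):
        w = x / blend_extent
        b[..., x] = a[..., -blend_extent + x] * (1 - w) + b[..., x] * w
    return b
```
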
offload.load_model_data(model, pretrained_path.replace(".pth", ".safetensors"), writable_tensors= False) + + return model + + +class Wan2_2_VAE: + + def __init__( + self, + z_dim=48, + c_dim=160, + vae_pth=None, + dim_mult=[1, 2, 4, 4], + temperal_downsample=[False, True, True], + dtype=torch.float, + device="cuda", + ): + + self.dtype = dtype + self.device = device + + mean = torch.tensor( + [ + -0.2289, + -0.0052, + -0.1323, + -0.2339, + -0.2799, + 0.0174, + 0.1838, + 0.1557, + -0.1382, + 0.0542, + 0.2813, + 0.0891, + 0.1570, + -0.0098, + 0.0375, + -0.1825, + -0.2246, + -0.1207, + -0.0698, + 0.5109, + 0.2665, + -0.2108, + -0.2158, + 0.2502, + -0.2055, + -0.0322, + 0.1109, + 0.1567, + -0.0729, + 0.0899, + -0.2799, + -0.1230, + -0.0313, + -0.1649, + 0.0117, + 0.0723, + -0.2839, + -0.2083, + -0.0520, + 0.3748, + 0.0152, + 0.1957, + 0.1433, + -0.2944, + 0.3573, + -0.0548, + -0.1681, + -0.0667, + ], + dtype=dtype, + device=device, + ) + std = torch.tensor( + [ + 0.4765, + 1.0364, + 0.4514, + 1.1677, + 0.5313, + 0.4990, + 0.4818, + 0.5013, + 0.8158, + 1.0344, + 0.5894, + 1.0901, + 0.6885, + 0.6165, + 0.8454, + 0.4978, + 0.5759, + 0.3523, + 0.7135, + 0.6804, + 0.5833, + 1.4146, + 0.8986, + 0.5659, + 0.7069, + 0.5338, + 0.4889, + 0.4917, + 0.4069, + 0.4999, + 0.6866, + 0.4093, + 0.5709, + 0.6065, + 0.6415, + 0.4944, + 0.5726, + 1.2042, + 0.5458, + 1.6887, + 0.3971, + 1.0600, + 0.3943, + 0.5537, + 0.5444, + 0.4089, + 0.7468, + 0.7744, + ], + dtype=dtype, + device=device, + ) + self.scale = [mean, 1.0 / std] + + # init model + self.model = ( + _video_vae( + pretrained_path=vae_pth, + z_dim=z_dim, + dim=c_dim, + dim_mult=dim_mult, + temperal_downsample=temperal_downsample, + ).eval().requires_grad_(False).to(device)) + + self.model._model_dtype = dtype + + + @staticmethod + def get_VAE_tile_size(vae_config, device_mem_capacity, mixed_precision): + # VAE Tiling + if vae_config == 0: + if mixed_precision: + device_mem_capacity = device_mem_capacity / 2 + if device_mem_capacity >= 24000: + use_vae_config = 1 + elif device_mem_capacity >= 8000: + use_vae_config = 2 + else: + use_vae_config = 3 + else: + use_vae_config = vae_config + + if use_vae_config == 1: + VAE_tile_size = 0 + elif use_vae_config == 2: + VAE_tile_size = 256 + else: + VAE_tile_size = 128 + + return VAE_tile_size + + def encode(self, videos, tile_size = 256, any_end_frame = False): + """ + videos: A list of videos each with shape [C, T, H, W]. 
+ """ + scale = [u.to(device = self.device) for u in self.scale] + + if tile_size > 0 and False : + return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] + else: + return [ self.model.encode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] + + + def decode(self, zs, tile_size = 256, any_end_frame = False): + scale = [u.to(device = self.device) for u in self.scale] + if tile_size > 0 : + return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] + else: + return [ self.model.decode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] + diff --git a/wan/modules/xlm_roberta.py b/models/wan/modules/xlm_roberta.py similarity index 100% rename from wan/modules/xlm_roberta.py rename to models/wan/modules/xlm_roberta.py diff --git a/wan/multitalk/attention.py b/models/wan/multitalk/attention.py similarity index 99% rename from wan/multitalk/attention.py rename to models/wan/multitalk/attention.py index 12fb317..27d488f 100644 --- a/wan/multitalk/attention.py +++ b/models/wan/multitalk/attention.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn from einops import rearrange, repeat from .multitalk_utils import RotaryPositionalEmbedding1D, normalize_and_scale, split_token_counts_and_frame_ids -from wan.modules.attention import pay_attention +from shared.attention import pay_attention # import xformers.ops diff --git a/wan/multitalk/kokoro/__init__.py b/models/wan/multitalk/kokoro/__init__.py similarity index 100% rename from wan/multitalk/kokoro/__init__.py rename to models/wan/multitalk/kokoro/__init__.py diff --git a/wan/multitalk/kokoro/__main__.py b/models/wan/multitalk/kokoro/__main__.py similarity index 100% rename from wan/multitalk/kokoro/__main__.py rename to models/wan/multitalk/kokoro/__main__.py diff --git a/wan/multitalk/kokoro/custom_stft.py b/models/wan/multitalk/kokoro/custom_stft.py similarity index 100% rename from wan/multitalk/kokoro/custom_stft.py rename to models/wan/multitalk/kokoro/custom_stft.py diff --git a/wan/multitalk/kokoro/istftnet.py b/models/wan/multitalk/kokoro/istftnet.py similarity index 100% rename from wan/multitalk/kokoro/istftnet.py rename to models/wan/multitalk/kokoro/istftnet.py diff --git a/wan/multitalk/kokoro/model.py b/models/wan/multitalk/kokoro/model.py similarity index 100% rename from wan/multitalk/kokoro/model.py rename to models/wan/multitalk/kokoro/model.py diff --git a/wan/multitalk/kokoro/modules.py b/models/wan/multitalk/kokoro/modules.py similarity index 100% rename from wan/multitalk/kokoro/modules.py rename to models/wan/multitalk/kokoro/modules.py diff --git a/wan/multitalk/kokoro/pipeline.py b/models/wan/multitalk/kokoro/pipeline.py similarity index 100% rename from wan/multitalk/kokoro/pipeline.py rename to models/wan/multitalk/kokoro/pipeline.py diff --git a/wan/multitalk/multitalk.py b/models/wan/multitalk/multitalk.py similarity index 87% rename from wan/multitalk/multitalk.py rename to models/wan/multitalk/multitalk.py index 038efdf..fbf9175 100644 --- a/wan/multitalk/multitalk.py +++ b/models/wan/multitalk/multitalk.py @@ -7,10 +7,7 @@ import subprocess import torchvision.transforms as transforms import torch.nn.functional as F import torch.nn as nn -import wan -from wan.configs import SIZE_CONFIGS, 
SUPPORTED_SIZES, WAN_CONFIGS -from wan.utils.utils import cache_image, cache_video, str2bool -# from wan.utils.multitalk_utils import save_video_ffmpeg +# from shared.utils.multitalk_utils import save_video_ffmpeg # from .kokoro import KPipeline from transformers import Wav2Vec2FeatureExtractor from .wav2vec2 import Wav2Vec2Model @@ -74,25 +71,52 @@ def audio_prepare_single(audio_path, sample_rate=16000, duration = 0): return human_speech_array -def audio_prepare_multi(left_path, right_path, audio_type = "add", sample_rate=16000, duration = 0): +def audio_prepare_multi(left_path, right_path, audio_type = "add", sample_rate=16000, duration = 0, pad = 0, min_audio_duration = 0): if not (left_path==None or right_path==None): human_speech_array1 = audio_prepare_single(left_path, duration = duration) human_speech_array2 = audio_prepare_single(right_path, duration = duration) - elif left_path==None: - human_speech_array2 = audio_prepare_single(right_path, duration = duration) - human_speech_array1 = np.zeros(human_speech_array2.shape[0]) - elif right_path==None: - human_speech_array1 = audio_prepare_single(left_path, duration = duration) - human_speech_array2 = np.zeros(human_speech_array1.shape[0]) + else: + audio_type='para' + if left_path==None: + human_speech_array2 = audio_prepare_single(right_path, duration = duration) + human_speech_array1 = np.zeros(human_speech_array2.shape[0]) + elif right_path==None: + human_speech_array1 = audio_prepare_single(left_path, duration = duration) + human_speech_array2 = np.zeros(human_speech_array1.shape[0]) if audio_type=='para': new_human_speech1 = human_speech_array1 new_human_speech2 = human_speech_array2 + if len(new_human_speech1) != len(new_human_speech2): + if len(new_human_speech1) < len(new_human_speech2): + new_human_speech1 = np.pad(new_human_speech1, (0, len(new_human_speech2) - len(new_human_speech1))) + else: + new_human_speech2 = np.pad(new_human_speech2, (0, len(new_human_speech1) - len(new_human_speech2))) elif audio_type=='add': new_human_speech1 = np.concatenate([human_speech_array1[: human_speech_array1.shape[0]], np.zeros(human_speech_array2.shape[0])]) new_human_speech2 = np.concatenate([np.zeros(human_speech_array1.shape[0]), human_speech_array2[:human_speech_array2.shape[0]]]) + + + duration_changed = False + if min_audio_duration > 0: + min_samples = math.ceil( min_audio_duration * sample_rate) + if len(new_human_speech1) < min_samples: + new_human_speech1 = np.concatenate([new_human_speech1, np.zeros(min_samples -len(new_human_speech1)) ]) + duration_changed = True + if len(new_human_speech2) < min_samples: + new_human_speech2 = np.concatenate([new_human_speech2, np.zeros(min_samples -len(new_human_speech2)) ]) + duration_changed = True + + #dont include the padding on the summed audio which is used to build the output audio track sum_human_speechs = new_human_speech1 + new_human_speech2 - return new_human_speech1, new_human_speech2, sum_human_speechs + + if pad > 0: + duration_changed = True + new_human_speech1 = np.concatenate([np.zeros(pad), new_human_speech1]) + new_human_speech2 = np.concatenate([np.zeros(pad), new_human_speech2]) + + return new_human_speech1, new_human_speech2, sum_human_speechs, duration_changed + def process_tts_single(text, save_dir, voice1): s1_sentences = [] @@ -167,19 +191,18 @@ def process_tts_multi(text, save_dir, voice1, voice2): return s1, s2, save_path_sum -def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000): +def 
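
`audio_prepare_multi` above supports two ways of combining the two speakers: `para` plays both at once (the shorter track is now zero-padded to match), while `add` plays them one after the other, each padded with silence while the other speaks, before both are summed into the output track. A compact sketch of the two modes:

```python
import numpy as np

def combine_speakers(a: np.ndarray, b: np.ndarray, mode: str = "add"):
    """Return (speaker1, speaker2, summed) tracks for 'para' or 'add' combination."""
    if mode == "para":
        n = max(len(a), len(b))
        a, b = np.pad(a, (0, n - len(a))), np.pad(b, (0, n - len(b)))
    elif mode == "add":
        a, b = np.concatenate([a, np.zeros(len(b))]), np.concatenate([np.zeros(len(a)), b])
    return a, b, a + b
```
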
get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000, padded_frames_for_embeddings = 0, min_audio_duration = 0): wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/chinese-wav2vec2-base") # wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/wav2vec") - - new_human_speech1, new_human_speech2, sum_human_speechs = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps) + pad = int(padded_frames_for_embeddings/ fps * sr) + new_human_speech1, new_human_speech2, sum_human_speechs, duration_changed = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps, pad = pad, min_audio_duration = min_audio_duration ) audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) - full_audio_embs = [] if audio_guide1 != None: full_audio_embs.append(audio_embedding_1) # if audio_guide1 != None: full_audio_embs.append(audio_embedding_1) if audio_guide2 != None: full_audio_embs.append(audio_embedding_2) - if audio_guide2 == None: sum_human_speechs = None + if audio_guide2 == None and not duration_changed: sum_human_speechs = None return full_audio_embs, sum_human_speechs diff --git a/wan/multitalk/multitalk_model.py b/models/wan/multitalk/multitalk_model.py similarity index 100% rename from wan/multitalk/multitalk_model.py rename to models/wan/multitalk/multitalk_model.py diff --git a/models/wan/multitalk/multitalk_utils.py b/models/wan/multitalk/multitalk_utils.py new file mode 100644 index 0000000..6e2b2c3 --- /dev/null +++ b/models/wan/multitalk/multitalk_utils.py @@ -0,0 +1,882 @@ +import os +from einops import rearrange + +import torch +import torch.nn as nn + +from einops import rearrange, repeat +from functools import lru_cache +import imageio +import uuid +from tqdm import tqdm +import numpy as np +import subprocess +import soundfile as sf +import torchvision +import binascii +import os.path as osp +from skimage import color + + +VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") +ASPECT_RATIO_627 = { + '0.26': ([320, 1216], 1), '0.38': ([384, 1024], 1), '0.50': ([448, 896], 1), '0.67': ([512, 768], 1), + '0.82': ([576, 704], 1), '1.00': ([640, 640], 1), '1.22': ([704, 576], 1), '1.50': ([768, 512], 1), + '1.86': ([832, 448], 1), '2.00': ([896, 448], 1), '2.50': ([960, 384], 1), '2.83': ([1088, 384], 1), + '3.60': ([1152, 320], 1), '3.80': ([1216, 320], 1), '4.00': ([1280, 320], 1)} + + +ASPECT_RATIO_960 = { + '0.22': ([448, 2048], 1), '0.29': ([512, 1792], 1), '0.36': ([576, 1600], 1), '0.45': ([640, 1408], 1), + '0.55': ([704, 1280], 1), '0.63': ([768, 1216], 1), '0.76': ([832, 1088], 1), '0.88': ([896, 1024], 1), + '1.00': ([960, 960], 1), '1.14': ([1024, 896], 1), '1.31': ([1088, 832], 1), '1.50': ([1152, 768], 1), + '1.58': ([1216, 768], 1), '1.82': ([1280, 704], 1), '1.91': ([1344, 704], 1), '2.20': ([1408, 640], 1), + '2.30': ([1472, 640], 1), '2.67': ([1536, 576], 1), '2.89': ([1664, 576], 1), '3.62': ([1856, 512], 1), + '3.75': ([1920, 512], 1)} + + + +def torch_gc(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + + +def split_token_counts_and_frame_ids(T, token_frame, world_size, rank): + + S = T * token_frame + split_sizes = [S // world_size + (1 if i < S % world_size else 0) for i in range(world_size)] + start = sum(split_sizes[:rank]) + 
end = start + split_sizes[rank] + counts = [0] * T + for idx in range(start, end): + t = idx // token_frame + counts[t] += 1 + + counts_filtered = [] + frame_ids = [] + for t, c in enumerate(counts): + if c > 0: + counts_filtered.append(c) + frame_ids.append(t) + return counts_filtered, frame_ids + + +def normalize_and_scale(column, source_range, target_range, epsilon=1e-8): + + source_min, source_max = source_range + new_min, new_max = target_range + + normalized = (column - source_min) / (source_max - source_min + epsilon) + scaled = normalized * (new_max - new_min) + new_min + return scaled + + +# @torch.compile +def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count, mode='mean', attn_bias=None): + + ref_k = ref_k.to(visual_q.dtype).to(visual_q.device) + scale = 1.0 / visual_q.shape[-1] ** 0.5 + visual_q = visual_q * scale + visual_q = visual_q.transpose(1, 2) + ref_k = ref_k.transpose(1, 2) + attn = visual_q @ ref_k.transpose(-2, -1) + + if attn_bias is not None: attn += attn_bias + + x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens + + x_ref_attn_maps = [] + ref_target_masks = ref_target_masks.to(visual_q.dtype) + x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype) + + for class_idx, ref_target_mask in enumerate(ref_target_masks): + ref_target_mask = ref_target_mask[None, None, None, ...] + x_ref_attnmap = x_ref_attn_map_source * ref_target_mask + x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens + x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1) # B, x_seqlens, H + + if mode == 'mean': + x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens + elif mode == 'max': + x_ref_attnmap = x_ref_attnmap.max(-1) # B, x_seqlens + + x_ref_attn_maps.append(x_ref_attnmap) + + del attn + del x_ref_attn_map_source + + return torch.concat(x_ref_attn_maps, dim=0) + + +def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count = 0): + """Args: + query (torch.tensor): B M H K + key (torch.tensor): B M H K + shape (tuple): (N_t, N_h, N_w) + ref_target_masks: [B, N_h * N_w] + """ + + N_t, N_h, N_w = shape + + x_seqlens = N_h * N_w + ref_k = ref_k[:, :x_seqlens] + if ref_images_count > 0 : + visual_q_shape = visual_q.shape + visual_q = visual_q.reshape(visual_q_shape[0], N_t, -1) + visual_q = visual_q[:, ref_images_count:] + visual_q = visual_q.reshape(visual_q_shape[0], -1, *visual_q_shape[-2:]) + + _, seq_lens, heads, _ = visual_q.shape + class_num, _ = ref_target_masks.shape + x_ref_attn_maps = torch.zeros(class_num, seq_lens, dtype=visual_q.dtype, device=visual_q.device) + + split_chunk = heads // split_num + + for i in range(split_num): + x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count) + x_ref_attn_maps += x_ref_attn_maps_perhead + + x_ref_attn_maps /= split_num + return x_ref_attn_maps + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... 
(d r)") + + +class RotaryPositionalEmbedding1D(nn.Module): + + def __init__(self, + head_dim, + ): + super().__init__() + self.head_dim = head_dim + self.base = 10000 + + + @lru_cache(maxsize=32) + def precompute_freqs_cis_1d(self, pos_indices): + + freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim)) + freqs = freqs.to(pos_indices.device) + freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + return freqs + + def forward(self, x, pos_indices): + """1D RoPE. + + Args: + query (torch.tensor): [B, head, seq, head_dim] + pos_indices (torch.tensor): [seq,] + Returns: + query with the same shape as input. + """ + freqs_cis = self.precompute_freqs_cis_1d(pos_indices) + + x_ = x.float() + + freqs_cis = freqs_cis.float().to(x.device) + cos, sin = freqs_cis.cos(), freqs_cis.sin() + cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') + x_ = (x_ * cos) + (rotate_half(x_) * sin) + + return x_.type_as(x) + + + +def rand_name(length=8, suffix=''): + name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') + if suffix: + if not suffix.startswith('.'): + suffix = '.' + suffix + name += suffix + return name + +def cache_video(tensor, + save_file=None, + fps=30, + suffix='.mp4', + nrow=8, + normalize=True, + value_range=(-1, 1), + retry=5): + + # cache file + cache_file = osp.join('/tmp', rand_name( + suffix=suffix)) if save_file is None else save_file + + # save to cache + error = None + for _ in range(retry): + + # preprocess + tensor = tensor.clamp(min(value_range), max(value_range)) + tensor = torch.stack([ + torchvision.utils.make_grid( + u, nrow=nrow, normalize=normalize, value_range=value_range) + for u in tensor.unbind(2) + ], + dim=1).permute(1, 2, 3, 0) + tensor = (tensor * 255).type(torch.uint8).cpu() + + # write video + writer = imageio.get_writer(cache_file, fps=fps, codec='libx264', quality=10, ffmpeg_params=["-crf", "10"]) + for frame in tensor.numpy(): + writer.append_data(frame) + writer.close() + return cache_file + +def save_video_ffmpeg(gen_video_samples, save_path, vocal_audio_list, fps=25, quality=5, high_quality_save=False): + + def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None): + writer = imageio.get_writer( + save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params + ) + for frame in tqdm(frames, desc="Saving video"): + frame = np.array(frame) + writer.append_data(frame) + writer.close() + save_path_tmp = save_path + "-temp.mp4" + + if high_quality_save: + cache_video( + tensor=gen_video_samples.unsqueeze(0), + save_file=save_path_tmp, + fps=fps, + nrow=1, + normalize=True, + value_range=(-1, 1) + ) + else: + video_audio = (gen_video_samples+1)/2 # C T H W + video_audio = video_audio.permute(1, 2, 3, 0).cpu().numpy() + video_audio = np.clip(video_audio * 255, 0, 255).astype(np.uint8) # to [0, 255] + save_video(video_audio, save_path_tmp, fps=fps, quality=quality) + + + # crop audio according to video length + _, T, _, _ = gen_video_samples.shape + duration = T / fps + save_path_crop_audio = save_path + "-cropaudio.wav" + final_command = [ + "ffmpeg", + "-i", + vocal_audio_list[0], + "-t", + f'{duration}', + save_path_crop_audio, + ] + subprocess.run(final_command, check=True) + + save_path = save_path + ".mp4" + if high_quality_save: + final_command = [ + "ffmpeg", + "-y", + "-i", save_path_tmp, + "-i", save_path_crop_audio, + "-c:v", "libx264", + "-crf", "0", + "-preset", "veryslow", + "-c:a", "aac", 
+ "-shortest", + save_path, + ] + subprocess.run(final_command, check=True) + os.remove(save_path_tmp) + os.remove(save_path_crop_audio) + else: + final_command = [ + "ffmpeg", + "-y", + "-i", + save_path_tmp, + "-i", + save_path_crop_audio, + "-c:v", + "libx264", + "-c:a", + "aac", + "-shortest", + save_path, + ] + subprocess.run(final_command, check=True) + os.remove(save_path_tmp) + os.remove(save_path_crop_audio) + + +class MomentumBuffer: + def __init__(self, momentum: float): + self.momentum = momentum + self.running_average = 0 + + def update(self, update_value: torch.Tensor): + new_average = self.momentum * self.running_average + self.running_average = update_value + new_average + + + +def project( + v0: torch.Tensor, # [B, C, T, H, W] + v1: torch.Tensor, # [B, C, T, H, W] + ): + dtype = v0.dtype + v0, v1 = v0.double(), v1.double() + v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3, -4]) + v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3, -4], keepdim=True) * v1 + v0_orthogonal = v0 - v0_parallel + return v0_parallel.to(dtype), v0_orthogonal.to(dtype) + + +def adaptive_projected_guidance( + diff: torch.Tensor, # [B, C, T, H, W] + pred_cond: torch.Tensor, # [B, C, T, H, W] + momentum_buffer: MomentumBuffer = None, + eta: float = 0.0, + norm_threshold: float = 55, + ): + if momentum_buffer is not None: + momentum_buffer.update(diff) + diff = momentum_buffer.running_average + if norm_threshold > 0: + ones = torch.ones_like(diff) + diff_norm = diff.norm(p=2, dim=[-1, -2, -3, -4], keepdim=True) + print(f"diff_norm: {diff_norm}") + scale_factor = torch.minimum(ones, norm_threshold / diff_norm) + diff = diff * scale_factor + diff_parallel, diff_orthogonal = project(diff, pred_cond) + normalized_update = diff_orthogonal + eta * diff_parallel + return normalized_update + +def match_and_blend_colors(source_chunk: torch.Tensor, reference_image: torch.Tensor, strength: float) -> torch.Tensor: + """ + Matches the color of a source video chunk to a reference image and blends with the original. + + Args: + source_chunk (torch.Tensor): The video chunk to be color-corrected (B, C, T, H, W) in range [-1, 1]. + Assumes B=1 (batch size of 1). + reference_image (torch.Tensor): The reference image (B, C, 1, H, W) in range [-1, 1]. + Assumes B=1 and T=1 (single reference frame). + strength (float): The strength of the color correction (0.0 to 1.0). + 0.0 means no correction, 1.0 means full correction. + + Returns: + torch.Tensor: The color-corrected and blended video chunk. 
+ """ + # print(f"[match_and_blend_colors] Input source_chunk shape: {source_chunk.shape}, reference_image shape: {reference_image.shape}, strength: {strength}") + + if strength == 0.0: + # print(f"[match_and_blend_colors] Strength is 0, returning original source_chunk.") + return source_chunk + + if not 0.0 <= strength <= 1.0: + raise ValueError(f"Strength must be between 0.0 and 1.0, got {strength}") + + device = source_chunk.device + dtype = source_chunk.dtype + + # Squeeze batch dimension, permute to T, H, W, C for skimage + # Source: (1, C, T, H, W) -> (T, H, W, C) + source_np = source_chunk.squeeze(0).permute(1, 2, 3, 0).cpu().numpy() + # Reference: (1, C, 1, H, W) -> (H, W, C) + ref_np = reference_image.squeeze(0).squeeze(1).permute(1, 2, 0).cpu().numpy() # Squeeze T dimension as well + + # Normalize from [-1, 1] to [0, 1] for skimage + source_np_01 = (source_np + 1.0) / 2.0 + ref_np_01 = (ref_np + 1.0) / 2.0 + + # Clip to ensure values are strictly in [0, 1] after potential float precision issues + source_np_01 = np.clip(source_np_01, 0.0, 1.0) + ref_np_01 = np.clip(ref_np_01, 0.0, 1.0) + + # Convert reference to Lab + try: + ref_lab = color.rgb2lab(ref_np_01) + except ValueError as e: + # Handle potential errors if image data is not valid for conversion + print(f"Warning: Could not convert reference image to Lab: {e}. Skipping color correction for this chunk.") + return source_chunk + + + corrected_frames_np_01 = [] + for i in range(source_np_01.shape[0]): # Iterate over time (T) + source_frame_rgb_01 = source_np_01[i] + + try: + source_lab = color.rgb2lab(source_frame_rgb_01) + except ValueError as e: + print(f"Warning: Could not convert source frame {i} to Lab: {e}. Using original frame.") + corrected_frames_np_01.append(source_frame_rgb_01) + continue + + corrected_lab_frame = source_lab.copy() + + # Perform color transfer for L, a, b channels + for j in range(3): # L, a, b + mean_src, std_src = source_lab[:, :, j].mean(), source_lab[:, :, j].std() + mean_ref, std_ref = ref_lab[:, :, j].mean(), ref_lab[:, :, j].std() + + # Avoid division by zero if std_src is 0 + if std_src == 0: + # If source channel has no variation, keep it as is, but shift by reference mean + # This case is debatable, could also just copy source or target mean. + # Shifting by target mean helps if source is flat but target isn't. + corrected_lab_frame[:, :, j] = mean_ref + else: + corrected_lab_frame[:, :, j] = (corrected_lab_frame[:, :, j] - mean_src) * (std_ref / std_src) + mean_ref + + try: + fully_corrected_frame_rgb_01 = color.lab2rgb(corrected_lab_frame) + except ValueError as e: + print(f"Warning: Could not convert corrected frame {i} back to RGB: {e}. 
Using original frame.") + corrected_frames_np_01.append(source_frame_rgb_01) + continue + + # Clip again after lab2rgb as it can go slightly out of [0,1] + fully_corrected_frame_rgb_01 = np.clip(fully_corrected_frame_rgb_01, 0.0, 1.0) + + # Blend with original source frame (in [0,1] RGB) + blended_frame_rgb_01 = (1 - strength) * source_frame_rgb_01 + strength * fully_corrected_frame_rgb_01 + corrected_frames_np_01.append(blended_frame_rgb_01) + + corrected_chunk_np_01 = np.stack(corrected_frames_np_01, axis=0) + + # Convert back to [-1, 1] + corrected_chunk_np_minus1_1 = (corrected_chunk_np_01 * 2.0) - 1.0 + + # Permute back to (C, T, H, W), add batch dim, and convert to original torch.Tensor type and device + # (T, H, W, C) -> (C, T, H, W) + corrected_chunk_tensor = torch.from_numpy(corrected_chunk_np_minus1_1).permute(3, 0, 1, 2).unsqueeze(0) + corrected_chunk_tensor = corrected_chunk_tensor.contiguous() # Ensure contiguous memory layout + output_tensor = corrected_chunk_tensor.to(device=device, dtype=dtype) + # print(f"[match_and_blend_colors] Output tensor shape: {output_tensor.shape}") + return output_tensor + + +from skimage import color +from scipy import ndimage +from scipy.ndimage import binary_erosion, distance_transform_edt + + +def match_and_blend_colors_with_mask( + source_chunk: torch.Tensor, + reference_video: torch.Tensor, + mask: torch.Tensor, + strength: float, + copy_mode: str = "corrected", # "corrected", "reference", "source", "progressive_blend" + source_border_distance: int = 10, + reference_border_distance: int = 10 +) -> torch.Tensor: + """ + Matches the color of a source video chunk to a reference video using mask-based region sampling. + + Args: + source_chunk (torch.Tensor): The video chunk to be color-corrected (B, C, T, H, W) in range [-1, 1]. + Assumes B=1 (batch size of 1). + reference_video (torch.Tensor): The reference video (B, C, T, H, W) in range [-1, 1]. + Must have same temporal dimension as source_chunk. + mask (torch.Tensor): Binary mask (B, 1, T, H, W) or (T, H, W) or (H, W) with values 0 and 1. + Color correction is applied to pixels where mask=1. + strength (float): The strength of the color correction (0.0 to 1.0). + 0.0 means no correction, 1.0 means full correction. + copy_mode (str): What to do with mask=0 pixels: + "corrected" (keep original), "reference", "source", + "progressive_blend" (double-sided progressive blending near borders). + source_border_distance (int): Distance in pixels from mask border to sample source video (mask=1 side). + reference_border_distance (int): Distance in pixels from mask border to sample reference video (mask=0 side). + For "progressive_blend" mode, this also defines the blending falloff distance. + + Returns: + torch.Tensor: The color-corrected and blended video chunk. 
+ + Notes: + - Color statistics are sampled from border regions to determine source and reference tints + - Progressive blending creates smooth double-sided transitions: + * mask=1 side: 60% source + 40% reference at border → 100% source deeper in + * mask=0 side: 60% reference + 40% source at border → 100% reference deeper in + """ + + if strength == 0.0: + return source_chunk + + if not 0.0 <= strength <= 1.0: + raise ValueError(f"Strength must be between 0.0 and 1.0, got {strength}") + + if copy_mode not in ["corrected", "reference", "source", "progressive_blend"]: + raise ValueError(f"copy_mode must be 'corrected', 'reference', 'source', or 'progressive_blend', got {copy_mode}") + + device = source_chunk.device + dtype = source_chunk.dtype + B, C, T, H, W = source_chunk.shape + + # Handle different mask dimensions + if mask.dim() == 2: # (H, W) + mask = mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand(B, 1, T, H, W) + elif mask.dim() == 3: # (T, H, W) + mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, T, H, W) + elif mask.dim() == 4: # (B, T, H, W) - missing channel dim + mask = mask.unsqueeze(1) + # mask should now be (B, 1, T, H, W) + + # Convert to numpy for processing + source_np = source_chunk.squeeze(0).permute(1, 2, 3, 0).cpu().numpy() # (T, H, W, C) + reference_np = reference_video.squeeze(0).permute(1, 2, 3, 0).cpu().numpy() # (T, H, W, C) + mask_np = mask.squeeze(0).squeeze(0).cpu().numpy() # (T, H, W) + + # Normalize from [-1, 1] to [0, 1] for skimage + source_np_01 = (source_np + 1.0) / 2.0 + reference_np_01 = (reference_np + 1.0) / 2.0 + + # Clip to ensure values are in [0, 1] + source_np_01 = np.clip(source_np_01, 0.0, 1.0) + reference_np_01 = np.clip(reference_np_01, 0.0, 1.0) + + corrected_frames_np_01 = [] + + for t in range(T): + source_frame = source_np_01[t] # (H, W, C) + reference_frame = reference_np_01[t] # (H, W, C) + frame_mask = mask_np[t] # (H, W) + + # Find mask borders and create distance maps + border_regions = get_border_sampling_regions(frame_mask, source_border_distance, reference_border_distance) + source_sample_region = border_regions['source_region'] # mask=1 side + reference_sample_region = border_regions['reference_region'] # mask=0 side + + # Sample pixels for color statistics + try: + source_stats = compute_color_stats(source_frame, source_sample_region) + reference_stats = compute_color_stats(reference_frame, reference_sample_region) + except ValueError as e: + print(f"Warning: Could not compute color statistics for frame {t}: {e}. 
Using original frame.") + corrected_frames_np_01.append(source_frame) + continue + + # Apply color correction to mask=1 area and handle mask=0 area based on copy_mode + corrected_frame = apply_color_correction_with_mask( + source_frame, frame_mask, source_stats, reference_stats, strength + ) + + # Handle mask=0 pixels based on copy_mode + if copy_mode == "reference": + corrected_frame = apply_copy_with_mask(corrected_frame, reference_frame, frame_mask, "reference") + elif copy_mode == "source": + corrected_frame = apply_copy_with_mask(corrected_frame, source_frame, frame_mask, "source") + elif copy_mode == "progressive_blend": + # Apply progressive blending in mask=1 border area (source side) + corrected_frame = apply_progressive_blend_in_corrected_area( + corrected_frame, reference_frame, frame_mask, + border_regions['source_region'], border_regions['source_distances'], + border_regions['reference_region'], source_border_distance + ) + # Copy reference pixels to mask=0 area first + corrected_frame = apply_copy_with_mask(corrected_frame, reference_frame, frame_mask, "reference") + # Then apply progressive blending in mask=0 border area (reference side) + corrected_frame = apply_progressive_blend_in_reference_area( + corrected_frame, source_frame, frame_mask, + border_regions['reference_region'], border_regions['reference_distances'], + reference_border_distance + ) + + corrected_frames_np_01.append(corrected_frame) + + corrected_chunk_np_01 = np.stack(corrected_frames_np_01, axis=0) + + # Convert back to [-1, 1] and return to tensor format + corrected_chunk_np_minus1_1 = (corrected_chunk_np_01 * 2.0) - 1.0 + corrected_chunk_tensor = torch.from_numpy(corrected_chunk_np_minus1_1).permute(3, 0, 1, 2).unsqueeze(0) + corrected_chunk_tensor = corrected_chunk_tensor.contiguous() + output_tensor = corrected_chunk_tensor.to(device=device, dtype=dtype) + + return output_tensor + + +def get_border_sampling_regions(mask, source_border_distance, reference_border_distance): + """ + Create regions for sampling near mask borders with separate distances for source and reference. + + Args: + mask: Binary mask (H, W) with 0s and 1s + source_border_distance: Distance from border to include in source sampling (mask=1 side) + reference_border_distance: Distance from border to include in reference sampling (mask=0 side) + + Returns: + Dict with sampling regions and distance maps for blending + """ + # Convert to boolean for safety + mask_bool = mask.astype(bool) + + # Distance from mask=0 regions (distance into mask=1 areas from border) + dist_from_mask0 = distance_transform_edt(mask_bool) + + # Distance from mask=1 regions (distance into mask=0 areas from border) + dist_from_mask1 = distance_transform_edt(~mask_bool) + + # Source region: mask=1 pixels within source_border_distance of mask=0 pixels + source_region = mask_bool & (dist_from_mask0 <= source_border_distance) + + # Reference region: mask=0 pixels within reference_border_distance of mask=1 pixels + reference_region = (~mask_bool) & (dist_from_mask1 <= reference_border_distance) + + return { + 'source_region': source_region, + 'reference_region': reference_region, + 'source_distances': dist_from_mask0, # Distance into mask=1 from border + 'reference_distances': dist_from_mask1 # Distance into mask=0 from border + } + + +def compute_color_stats(image, sample_region): + """ + Compute color statistics (mean and std) for Lab channels in the sampling region. 
+ + Args: + image: RGB image (H, W, C) in range [0, 1] + sample_region: Boolean mask (H, W) indicating pixels to sample + + Returns: + Dict with 'mean' and 'std' for Lab components + """ + if not np.any(sample_region): + raise ValueError("No pixels in sampling region") + + # Convert to Lab + try: + image_lab = color.rgb2lab(image) + except ValueError as e: + raise ValueError(f"Could not convert image to Lab: {e}") + + # Extract pixels in sampling region + sampled_pixels = image_lab[sample_region] # (N, 3) where N is number of sampled pixels + + # Compute statistics for each Lab channel + stats = { + 'mean': np.mean(sampled_pixels, axis=0), # (3,) for L, a, b + 'std': np.std(sampled_pixels, axis=0) # (3,) for L, a, b + } + + return stats + + +def apply_color_correction_with_mask(source_frame, mask, source_stats, reference_stats, strength): + """ + Apply color correction to pixels where mask=1. + + Args: + source_frame: RGB image (H, W, C) in range [0, 1] + mask: Binary mask (H, W) + source_stats: Color statistics from source sampling region + reference_stats: Color statistics from reference sampling region + strength: Blending strength + + Returns: + Corrected RGB image (H, W, C) + """ + try: + source_lab = color.rgb2lab(source_frame) + except ValueError as e: + print(f"Warning: Could not convert source frame to Lab: {e}. Using original frame.") + return source_frame + + corrected_lab = source_lab.copy() + correction_region = (mask == 1) # Apply correction to mask=1 pixels + + # Apply color transfer to pixels where mask=1 + for c in range(3): # L, a, b channels + mean_src = source_stats['mean'][c] + std_src = source_stats['std'][c] + mean_ref = reference_stats['mean'][c] + std_ref = reference_stats['std'][c] + + if std_src == 0: + # Handle case where source channel has no variation + corrected_lab[correction_region, c] = mean_ref + else: + # Standard color transfer formula + corrected_lab[correction_region, c] = ( + (corrected_lab[correction_region, c] - mean_src) * (std_ref / std_src) + mean_ref + ) + + try: + fully_corrected_rgb = color.lab2rgb(corrected_lab) + except ValueError as e: + print(f"Warning: Could not convert corrected frame back to RGB: {e}. Using original frame.") + return source_frame + + # Clip to [0, 1] + fully_corrected_rgb = np.clip(fully_corrected_rgb, 0.0, 1.0) + + # Blend with original (only in correction region) + result = source_frame.copy() + result[correction_region] = ( + (1 - strength) * source_frame[correction_region] + + strength * fully_corrected_rgb[correction_region] + ) + + return result + + +def apply_progressive_blend_in_corrected_area(corrected_frame, reference_frame, mask, source_region, source_distances, reference_region, source_border_distance): + """ + Apply progressive blending in the corrected area (mask=1) near the border. 
+ + Args: + corrected_frame: RGB image (H, W, C) - the color-corrected source frame + reference_frame: RGB image (H, W, C) - the reference frame + mask: Binary mask (H, W) + source_region: Boolean mask (H, W) indicating the source blending region (mask=1 near border) + source_distances: Distance map (H, W) into mask=1 area from mask=0 border + reference_region: Boolean mask (H, W) indicating the reference sampling region (mask=0 near border) + source_border_distance: Maximum distance for source blending + + Returns: + Blended RGB image (H, W, C) + + Notes: + - Each source pixel blends with its closest reference border pixel (for speed) + - At mask border: 60% source + 40% reference + - Deeper into mask=1 area: 100% corrected source + """ + result = corrected_frame.copy() + + # Blend in the source region (mask=1 pixels near border) + blend_region = source_region + + if np.any(blend_region): + # Find immediate border pixels (mask=0 pixels adjacent to mask=1 pixels) + # This is much faster than using the entire reference region + from scipy.ndimage import binary_dilation + + # Dilate mask=1 by 1 pixel, then find intersection with mask=0 + mask_1_dilated = binary_dilation(mask == 1, structure=np.ones((3, 3))) + border_pixels = (mask == 0) & mask_1_dilated + + if np.any(border_pixels): + # Find closest border pixel for each source pixel + source_coords = np.column_stack(np.where(blend_region)) # (N, 2) - y, x coordinates + border_coords = np.column_stack(np.where(border_pixels)) # (M, 2) - much smaller set! + + # For each source pixel, find closest border pixel + from scipy.spatial.distance import cdist + distances_matrix = cdist(source_coords, border_coords, metric='euclidean') + closest_border_indices = np.argmin(distances_matrix, axis=1) + + # Normalize source distances for blending weights + min_distance_in_region = np.min(source_distances[blend_region]) + max_distance_in_region = np.max(source_distances[blend_region]) + + if max_distance_in_region > min_distance_in_region: + # Calculate blend weights: 0.4 at border (60% source + 40% reference), 0.0 at max distance (100% source) + source_dist_values = source_distances[blend_region] + normalized_distances = (source_dist_values - min_distance_in_region) / (max_distance_in_region - min_distance_in_region) + blend_weights = 0.4 * (1.0 - normalized_distances) # Start with 40% reference influence at border + + # Apply blending with closest border pixels + for i, (source_y, source_x) in enumerate(source_coords): + closest_border_idx = closest_border_indices[i] + border_y, border_x = border_coords[closest_border_idx] + + weight = blend_weights[i] + # Blend with closest border pixel + result[source_y, source_x] = ( + (1.0 - weight) * corrected_frame[source_y, source_x] + + weight * reference_frame[border_y, border_x] + ) + + return result + + +def apply_progressive_blend_in_reference_area(reference_frame, source_frame, mask, reference_region, reference_distances, reference_border_distance): + """ + Apply progressive blending in the reference area (mask=0) near the border. 
+ + Args: + reference_frame: RGB image (H, W, C) - the reference frame with copied reference pixels + source_frame: RGB image (H, W, C) - the original source frame + mask: Binary mask (H, W) + reference_region: Boolean mask (H, W) indicating the reference blending region (mask=0 near border) + reference_distances: Distance map (H, W) into mask=0 area from mask=1 border + reference_border_distance: Maximum distance for reference blending + + Returns: + Blended RGB image (H, W, C) + + Notes: + - Each reference pixel blends with its closest source border pixel (for speed) + - At mask border: 60% reference + 40% source + - Deeper into mask=0 area: 100% reference + """ + result = reference_frame.copy() + + # Blend in the reference region (mask=0 pixels near border) + blend_region = reference_region + + if np.any(blend_region): + # Find immediate border pixels (mask=1 pixels adjacent to mask=0 pixels) + from scipy.ndimage import binary_dilation + + # Dilate mask=0 by 1 pixel, then find intersection with mask=1 + mask_0_dilated = binary_dilation(mask == 0, structure=np.ones((3, 3))) + source_border_pixels = (mask == 1) & mask_0_dilated + + if np.any(source_border_pixels): + # Find closest source border pixel for each reference pixel + reference_coords = np.column_stack(np.where(blend_region)) # (N, 2) - y, x coordinates + source_border_coords = np.column_stack(np.where(source_border_pixels)) # (M, 2) + + # For each reference pixel, find closest source border pixel + from scipy.spatial.distance import cdist + distances_matrix = cdist(reference_coords, source_border_coords, metric='euclidean') + closest_source_indices = np.argmin(distances_matrix, axis=1) + + # Normalize reference distances for blending weights + min_distance_in_region = np.min(reference_distances[blend_region]) + max_distance_in_region = np.max(reference_distances[blend_region]) + + if max_distance_in_region > min_distance_in_region: + # Calculate blend weights: 0.4 at border (60% reference + 40% source), 0.0 at max distance (100% reference) + reference_dist_values = reference_distances[blend_region] + normalized_distances = (reference_dist_values - min_distance_in_region) / (max_distance_in_region - min_distance_in_region) + blend_weights = 0.4 * (1.0 - normalized_distances) # Start with 40% source influence at border + + # Apply blending with closest source border pixels + for i, (ref_y, ref_x) in enumerate(reference_coords): + closest_source_idx = closest_source_indices[i] + source_y, source_x = source_border_coords[closest_source_idx] + + weight = blend_weights[i] + # Blend: weight=0.4 means 60% reference + 40% source at border + result[ref_y, ref_x] = ( + (1.0 - weight) * reference_frame[ref_y, ref_x] + + weight * source_frame[source_y, source_x] + ) + + return result + + +def apply_copy_with_mask(source_frame, reference_frame, mask, copy_source): + """ + Copy pixels to mask=0 regions based on copy_source parameter. 
+ + Args: + source_frame: RGB image (H, W, C) + reference_frame: RGB image (H, W, C) + mask: Binary mask (H, W) + copy_source: "reference" or "source" + + Returns: + Combined RGB image (H, W, C) + """ + result = source_frame.copy() + mask_0_region = (mask == 0) + + if copy_source == "reference": + result[mask_0_region] = reference_frame[mask_0_region] + # If "source", we keep the original source pixels (no change needed) + + return result \ No newline at end of file diff --git a/wan/multitalk/torch_utils.py b/models/wan/multitalk/torch_utils.py similarity index 100% rename from wan/multitalk/torch_utils.py rename to models/wan/multitalk/torch_utils.py diff --git a/wan/multitalk/wav2vec2.py b/models/wan/multitalk/wav2vec2.py similarity index 98% rename from wan/multitalk/wav2vec2.py rename to models/wan/multitalk/wav2vec2.py index 5ec9c2b..9ab590c 100644 --- a/wan/multitalk/wav2vec2.py +++ b/models/wan/multitalk/wav2vec2.py @@ -20,7 +20,7 @@ class Wav2Vec2Model(Wav2Vec2Model): output_hidden_states=None, return_dict=None, ): - self.config.output_attentions = True + # self.config.output_attentions = True output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states diff --git a/models/wan/text2video fuse attempt.py b/models/wan/text2video fuse attempt.py new file mode 100644 index 0000000..8af9458 --- /dev/null +++ b/models/wan/text2video fuse attempt.py @@ -0,0 +1,698 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import gc +import logging +import math +import os +import random +import sys +import types +from contextlib import contextmanager +from functools import partial +from mmgp import offload +import torch +import torch.nn as nn +import torch.cuda.amp as amp +import torch.distributed as dist +from tqdm import tqdm +from PIL import Image +import torchvision.transforms.functional as TF +import torch.nn.functional as F +from .distributed.fsdp import shard_model +from .modules.model import WanModel +from .modules.t5 import T5EncoderModel +from .modules.vae import WanVAE +from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, retrieve_timesteps) +from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from wan.modules.posemb_layers import get_rotary_pos_embed +from .utils.vace_preprocessor import VaceVideoProcessor + + +def optimized_scale(positive_flat, negative_flat): + + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 + + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + return st_star + + +class WanT2V: + + def __init__( + self, + config, + checkpoint_dir, + rank=0, + model_filename = None, + text_encoder_filename = None, + quantizeTransformer = False, + dtype = torch.bfloat16 + ): + self.device = torch.device(f"cuda") + self.config = config + self.rank = rank + self.dtype = dtype + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu'), + checkpoint_path=text_encoder_filename, + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + shard_fn= None) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + + + self.vae = WanVAE( + 
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device) + + logging.info(f"Creating WanModel from {model_filename}") + from mmgp import offload + + self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) + # offload.load_model_data(self.model, "recam.ckpt") + # self.model.cpu() + # offload.save_model(self.model, "recam.safetensors") + if self.dtype == torch.float16 and not "fp16" in model_filename: + self.model.to(self.dtype) + # offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True) + if self.dtype == torch.float16: + self.vae.model.to(self.dtype) + self.model.eval().requires_grad_(False) + + + self.sample_neg_prompt = config.sample_neg_prompt + + if "Vace" in model_filename: + self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]), + min_area=480*832, + max_area=480*832, + min_fps=config.sample_fps, + max_fps=config.sample_fps, + zero_start=True, + seq_len=32760, + keep_last=True) + + self.adapt_vace_model() + + self.scheduler = FlowUniPCMultistepScheduler() + + def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0): + if ref_images is None: + ref_images = [None] * len(frames) + else: + assert len(frames) == len(ref_images) + + if masks is None: + latents = self.vae.encode(frames, tile_size = tile_size) + else: + inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] + reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] + inactive = self.vae.encode(inactive, tile_size = tile_size) + reactive = self.vae.encode(reactive, tile_size = tile_size) + latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + + cat_latents = [] + for latent, refs in zip(latents, ref_images): + if refs is not None: + if masks is None: + ref_latent = self.vae.encode(refs, tile_size = tile_size) + else: + ref_latent = self.vae.encode(refs, tile_size = tile_size) + ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + assert all([x.shape[1] == 1 for x in ref_latent]) + latent = torch.cat([*ref_latent, latent], dim=1) + cat_latents.append(latent) + return cat_latents + + def vace_encode_masks(self, masks, ref_images=None): + if ref_images is None: + ref_images = [None] * len(masks) + else: + assert len(masks) == len(ref_images) + + result_masks = [] + for mask, refs in zip(masks, ref_images): + c, depth, height, width = mask.shape + new_depth = int((depth + 3) // self.vae_stride[0]) + height = 2 * (int(height) // (self.vae_stride[1] * 2)) + width = 2 * (int(width) // (self.vae_stride[2] * 2)) + + # reshape + mask = mask[0, :, :, :] + mask = mask.view( + depth, height, self.vae_stride[1], width, self.vae_stride[1] + ) # depth, height, 8, width, 8 + mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width + mask = mask.reshape( + self.vae_stride[1] * self.vae_stride[2], depth, height, width + ) # 8*8, depth, height, width + + # interpolation + mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) + + if refs is not None: + length = len(refs) + mask_pad = torch.zeros_like(mask[:, :length, :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + result_masks.append(mask) + return result_masks + + def vace_latent(self, z, m): + return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] + + def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, 
original_video = False, keep_frames= [], start_frame = 0, pre_src_video = None): + image_sizes = [] + trim_video = len(keep_frames) + + for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)): + prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1] + num_frames = total_frames - prepend_count + if sub_src_mask is not None and sub_src_video is not None: + src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame) + # src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255]) + # src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255]) + src_video[i] = src_video[i].to(device) + src_mask[i] = src_mask[i].to(device) + if prepend_count > 0: + src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) + src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) + src_video_shape = src_video[i].shape + if src_video_shape[1] != total_frames: + src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) + src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) + src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1) + image_sizes.append(src_video[i].shape[2:]) + elif sub_src_video is None: + if prepend_count > 0: + src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1) + src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1) + else: + src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(image_size) + else: + src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame) + src_video[i] = src_video[i].to(device) + src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device) + if prepend_count > 0: + src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) + src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) + src_video_shape = src_video[i].shape + if src_video_shape[1] != total_frames: + src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) + src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) + image_sizes.append(src_video[i].shape[2:]) + for k, keep in enumerate(keep_frames): + if not keep: + src_video[i][:, k:k+1] = 0 + src_mask[i][:, k:k+1] = 1 + + for i, ref_images in enumerate(src_ref_images): + if ref_images is not None: + image_size = image_sizes[i] + for j, ref_img in enumerate(ref_images): + if ref_img is not None: + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + if ref_img.shape[-2:] != image_size: + canvas_height, canvas_width = image_size + ref_height, ref_width = ref_img.shape[-2:] + white_canvas = torch.ones((3, 
1, canvas_height, canvas_width), device=device) # [-1, 1] + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image + ref_img = white_canvas + src_ref_images[i][j] = ref_img.to(device) + return src_video, src_mask, src_ref_images + + def decode_latent(self, zs, ref_images=None, tile_size= 0 ): + if ref_images is None: + ref_images = [None] * len(zs) + else: + assert len(zs) == len(ref_images) + + trimed_zs = [] + for z, refs in zip(zs, ref_images): + if refs is not None: + z = z[:, len(refs):, :, :] + trimed_zs.append(z) + + return self.vae.decode(trimed_zs, tile_size= tile_size) + + def generate_timestep_matrix( + self, + num_frames, + step_template, + base_num_frames, + ar_step=5, + num_pre_ready=0, + casual_block_size=1, + shrink_interval_with_mask=False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: + step_matrix, step_index = [], [] + update_mask, valid_interval = [], [] + num_iterations = len(step_template) + 1 + num_frames_block = num_frames // casual_block_size + base_num_frames_block = base_num_frames // casual_block_size + if base_num_frames_block < num_frames_block: + infer_step_num = len(step_template) + gen_block = base_num_frames_block + min_ar_step = infer_step_num / gen_block + assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting" + # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block) + step_template = torch.cat( + [ + torch.tensor([999], dtype=torch.int64, device=step_template.device), + step_template.long(), + torch.tensor([0], dtype=torch.int64, device=step_template.device), + ] + ) # to handle the counter in row works starting from 1 + pre_row = torch.zeros(num_frames_block, dtype=torch.long) + if num_pre_ready > 0: + pre_row[: num_pre_ready // casual_block_size] = num_iterations + + while torch.all(pre_row >= (num_iterations - 1)) == False: + new_row = torch.zeros(num_frames_block, dtype=torch.long) + for i in range(num_frames_block): + if i == 0 or pre_row[i - 1] >= ( + num_iterations - 1 + ): # the first frame or the last frame is completely denoised + new_row[i] = pre_row[i] + 1 + else: + new_row[i] = new_row[i - 1] - ar_step + new_row = new_row.clamp(0, num_iterations) + + update_mask.append( + (new_row != pre_row) & (new_row != num_iterations) + ) # False: no need to update, True: need to update + step_index.append(new_row) + step_matrix.append(step_template[new_row]) + pre_row = new_row + + # for long video we split into several sequences, base_num_frames is set to the model max length (for training) + terminal_flag = base_num_frames_block + if shrink_interval_with_mask: + idx_sequence = torch.arange(num_frames_block, dtype=torch.int64) + update_mask = update_mask[0] + update_mask_idx = idx_sequence[update_mask] + last_update_idx = update_mask_idx[-1].item() + terminal_flag = last_update_idx + 1 + # for i in range(0, len(update_mask)): + for curr_mask in update_mask: + if terminal_flag < num_frames_block and curr_mask[terminal_flag]: + terminal_flag += 1 + valid_interval.append((max(terminal_flag - 
base_num_frames_block, 0), terminal_flag)) + + step_update_mask = torch.stack(update_mask, dim=0) + step_index = torch.stack(step_index, dim=0) + step_matrix = torch.stack(step_matrix, dim=0) + + if casual_block_size > 1: + step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval] + + return step_matrix, step_index, step_update_mask, valid_interval + + def generate(self, + input_prompt, + input_frames= None, + input_masks = None, + input_ref_images = None, + source_video=None, + target_camera=None, + context_scale=1.0, + size=(1280, 720), + frame_num=81, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True, + callback = None, + enable_RIFLEx = None, + VAE_tile_size = 0, + joint_pass = False, + slg_layers = None, + slg_start = 0.0, + slg_end = 1.0, + cfg_star_switch = True, + cfg_zero_step = 5, + ): + r""" + Generates video frames from text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation + size (tuple[`int`], *optional*, defaults to (1280,720)): + Controls video resolution, (width,height). + frame_num (`int`, *optional*, defaults to 81): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 50): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults to 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. Dimensions: (C, N, H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width (from size) + """ + # preprocess + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + + frame_num = max(17, frame_num) # must match causal_block_size for value of 5 + frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 ) + num_frames = frame_num + addnoise_condition = 20 + causal_attention = True + fps = 16 + ar_step = 5 + + + + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + if target_camera != None: + size = (source_video.shape[2], source_video.shape[1]) + source_video = source_video.to(dtype=self.dtype , device=self.device) + source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
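# Editorial annotation (not part of the upstream patch): the two lines above move the
# ReCamMaster source video into the model's working format. A minimal sanity sketch,
# assuming the incoming tensor is laid out as (T, H, W, C) with pixel values in [0, 255]:
#   permute(3, 0, 1, 2)    -> (C, T, H, W), channel-first video layout
#   div_(127.5).sub_(1.)   -> 0 maps to -1.0, 127.5 to 0.0, 255 to 1.0
# i.e. the clip is rescaled to the [-1, 1] range expected by the VAE encoder called just below.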
+ source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device) + del source_video + # Process target camera (recammaster) + from wan.utils.cammmaster_tools import get_camera_embedding + cam_emb = get_camera_embedding(target_camera) + cam_emb = cam_emb.to(dtype=self.dtype, device=self.device) + + if input_frames != None: + # vace context encode + input_frames = [u.to(self.device) for u in input_frames] + input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] + input_masks = [u.to(self.device) for u in input_masks] + + z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size) + m0 = self.vace_encode_masks(input_masks, input_ref_images) + z = self.vace_latent(z0, m0) + + target_shape = list(z0[0].shape) + target_shape[0] = int(target_shape[0] / 2) + else: + F = frame_num + target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, + size[1] // self.vae_stride[1], + size[0] // self.vae_stride[2]) + + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1]) + + context = [u.to(self.dtype) for u in context] + context_null = [u.to(self.dtype) for u in context_null] + + noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ] + + # evaluation mode + + # if sample_solver == 'unipc': + # sample_scheduler = FlowUniPCMultistepScheduler( + # num_train_timesteps=self.num_train_timesteps, + # shift=1, + # use_dynamic_shifting=False) + # sample_scheduler.set_timesteps( + # sampling_steps, device=self.device, shift=shift) + # timesteps = sample_scheduler.timesteps + # elif sample_solver == 'dpm++': + # sample_scheduler = FlowDPMSolverMultistepScheduler( + # num_train_timesteps=self.num_train_timesteps, + # shift=1, + # use_dynamic_shifting=False) + # sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + # timesteps, _ = retrieve_timesteps( + # sample_scheduler, + # device=self.device, + # sigmas=sampling_sigmas) + # else: + # raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noise + del noise + batch_size =len(latents) + if target_camera != None: + shape = list(latents[0].shape[1:]) + shape[0] *= 2 + freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False) + else: + freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx) + # arg_c = {'context': context, 'freqs': freqs, 'pipeline': self, 'callback': callback} + # arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback} + # arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback} + + i2v_extra_kwrags = {} + + if target_camera != None: + recam_dict = {'cam_emb': cam_emb} + i2v_extra_kwrags.update(recam_dict) + + if input_frames != None: + vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale} + i2v_extra_kwrags.update(vace_dict) + + + latent_length = (num_frames - 1) // 4 + 1 + latent_height = height // 8 + latent_width = width // 8 + if ar_step == 0: + causal_block_size = 1 + fps_embeds = [fps] #* prompt_embeds[0].shape[0] + fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] + + self.scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift) + init_timesteps = self.scheduler.timesteps + base_num_frames_iter = latent_length + latent_shape = [16, base_num_frames_iter, latent_height, latent_width] + + prefix_video = None + 
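# Editorial annotation (not part of the upstream patch), two hedged observations on this block:
# 1) `latent_height = height // 8` and `latent_width = width // 8` above reference names that
#    are never defined inside generate(); presumably they are meant to be derived from `size`
#    (e.g. `width, height = size`), which fits this file being an experimental "fuse attempt".
# 2) generate_timestep_matrix() builds the per-frame denoising schedule consumed by the loop
#    below: with ar_step > 0 each latent frame roughly lags the previous one by ar_step
#    scheduler steps, so earlier frames finish denoising first, and any prefix frames
#    (predix_video_latent_length > 0) are treated as already denoised and masked out via
#    step_update_mask.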
predix_video_latent_length = 0 + + if prefix_video is not None: + latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32) + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + base_num_frames_iter, + init_timesteps, + base_num_frames_iter, + ar_step, + predix_video_latent_length, + causal_block_size, + ) + sample_schedulers = [] + for _ in range(base_num_frames_iter): + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, shift=1, use_dynamic_shifting=False + ) + sample_scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift) + sample_schedulers.append(sample_scheduler) + sample_schedulers_counter = [0] * base_num_frames_iter + + updated_num_steps= len(step_matrix) + + if callback != None: + callback(-1, None, True, override_num_inference_steps = updated_num_steps) + if self.model.enable_teacache: + self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier) + # if callback != None: + # callback(-1, None, True) + + for i, timestep_i in enumerate(tqdm(step_matrix)): + update_mask_i = step_update_mask[i] + valid_interval_i = valid_interval[i] + valid_interval_start, valid_interval_end = valid_interval_i + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] + if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + timestep_for_noised_condition = addnoise_condition + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] + * (1.0 - noise_factor) + + torch.randn_like( + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] + ) + * noise_factor + ) + timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + kwrags = { + "x" : torch.stack([latent_model_input[0]]), + "t" : timestep, + "freqs" :freqs, + "fps" : fps_embeds, + "causal_block_size" : causal_block_size, + "causal_attention" : causal_attention, + "callback" : callback, + "pipeline" : self, + "current_step" : i, + } + kwrags.update(i2v_extra_kwrags) + + if not self.do_classifier_free_guidance: + noise_pred = self.model( + context=context, + **kwrags, + )[0] + if self._interrupt: + return None + noise_pred= noise_pred.to(torch.float32) + else: + if joint_pass: + noise_pred_cond, noise_pred_uncond = self.model( + context=context, + context2=context_null, + **kwrags, + ) + if self._interrupt: + return None + else: + noise_pred_cond = self.model( + context=context, + **kwrags, + )[0] + if self._interrupt: + return None + noise_pred_uncond = self.model( + context=context_null, + )[0] + if self._interrupt: + return None + noise_pred_cond= noise_pred_cond.to(torch.float32) + noise_pred_uncond= noise_pred_uncond.to(torch.float32) + noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond) + del noise_pred_cond, noise_pred_uncond + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[0][:, idx] = sample_schedulers[idx].step( + noise_pred[:, idx - valid_interval_start], + timestep_i[idx], + latents[0][:, idx], + return_dict=False, + generator=seed_g, + )[0] + sample_schedulers_counter[idx] += 1 + if callback is not None: + callback(i, latents[0].squeeze(0), False) + + # for i, t in 
enumerate(tqdm(timesteps)): + # if target_camera != None: + # latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )] + # else: + # latent_model_input = latents + # slg_layers_local = None + # if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps): + # slg_layers_local = slg_layers + # timestep = [t] + # offload.set_step_no_for_lora(self.model, i) + # timestep = torch.stack(timestep) + + # if joint_pass: + # noise_pred_cond, noise_pred_uncond = self.model( + # latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both) + # if self._interrupt: + # return None + # else: + # noise_pred_cond = self.model( + # latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0] + # if self._interrupt: + # return None + # noise_pred_uncond = self.model( + # latent_model_input, t=timestep,current_step=i, is_uncond = True, slg_layers=slg_layers_local, **arg_null)[0] + # if self._interrupt: + # return None + + # # del latent_model_input + + # # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/ + # noise_pred_text = noise_pred_cond + # if cfg_star_switch: + # positive_flat = noise_pred_text.view(batch_size, -1) + # negative_flat = noise_pred_uncond.view(batch_size, -1) + + # alpha = optimized_scale(positive_flat,negative_flat) + # alpha = alpha.view(batch_size, 1, 1, 1) + + # if (i <= cfg_zero_step): + # noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred... + # else: + # noise_pred_uncond *= alpha + # noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond) + # del noise_pred_uncond + + # temp_x0 = sample_scheduler.step( + # noise_pred[:, :target_shape[1]].unsqueeze(0), + # t, + # latents[0].unsqueeze(0), + # return_dict=False, + # generator=seed_g)[0] + # latents = [temp_x0.squeeze(0)] + # del temp_x0 + + # if callback is not None: + # callback(i, latents[0], False) + + x0 = latents + + if input_frames == None: + videos = self.vae.decode(x0, VAE_tile_size) + else: + videos = self.decode_latent(x0, input_ref_images, VAE_tile_size) + + del latents + del sample_scheduler + + return videos[0] if self.rank == 0 else None + + def adapt_vace_model(self): + model = self.model + modules_dict= { k: m for k, m in model.named_modules()} + for model_layer, vace_layer in model.vace_layers_mapping.items(): + module = modules_dict[f"vace_blocks.{vace_layer}"] + target = modules_dict[f"blocks.{model_layer}"] + setattr(target, "vace", module ) + delattr(model, "vace_blocks") + + \ No newline at end of file diff --git a/wan/trajectory_editor/app.py b/models/wan/trajectory_editor/app.py similarity index 100% rename from wan/trajectory_editor/app.py rename to models/wan/trajectory_editor/app.py diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py new file mode 100644 index 0000000..c3ad012 --- /dev/null +++ b/models/wan/wan_handler.py @@ -0,0 +1,311 @@ +import torch +import numpy as np + +def test_class_i2v(base_model_type): + return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk" ] + +def test_class_1_3B(base_model_type): + return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"] + +def test_multitalk(base_model_type): + return base_model_type in ["multitalk", "vace_multitalk_14B", "i2v_2_2_multitalk", "infinitetalk"] + +def test_standin(base_model_type): + return base_model_type in ["standin", "vace_standin_14B"] + 
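+# family_handler centralizes per-model configuration for the Wan family: skip-step cache calibration
+# (MagCache ratios / TeaCache-style coefficients), UI defaults, the list of supported base model types,
+# download manifests, and WanAny2V instantiation.
+# The module-level test_* helpers above classify a base_model_type string, e.g.:
+#   test_class_i2v("multitalk")     -> True   (image-to-video family)
+#   test_class_1_3B("t2v_1.3B")     -> True   (1.3B-parameter variants)
+#   test_multitalk("infinitetalk")  -> True   (MultiTalk / InfiniteTalk variants)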
+class family_handler(): + + @staticmethod + def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_steps_cache): + i2v = test_class_i2v(base_model_type) + + resolution = inputs["resolution"] + width, height = resolution.split("x") + pixels = int(width) * int(height) + + if cache_type == "mag": + skip_steps_cache.update({ + "magcache_thresh" : 0, + "magcache_K" : 2, + }) + if base_model_type in ["t2v"] and "URLs2" in model_def: + def_mag_ratios = [1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181] + elif base_model_type in ["i2v_2_2"]: + def_mag_ratios = [0.99191, 0.99144, 0.99356, 0.99337, 0.99326, 0.99285, 0.99251, 0.99264, 0.99393, 0.99366, 0.9943, 0.9943, 0.99276, 0.99288, 0.99389, 0.99393, 0.99274, 0.99289, 0.99316, 0.9931, 0.99379, 0.99377, 0.99268, 0.99271, 0.99222, 0.99227, 0.99175, 0.9916, 0.91076, 0.91046, 0.98931, 0.98933, 0.99087, 0.99088, 0.98852, 0.98855, 0.98895, 0.98896, 0.98806, 0.98808, 0.9871, 0.98711, 0.98613, 0.98618, 0.98434, 0.98435, 0.983, 0.98307, 0.98185, 0.98187, 0.98131, 0.98131, 0.9783, 0.97835, 0.97619, 0.9762, 0.97264, 0.9727, 0.97088, 0.97098, 0.96568, 0.9658, 0.96045, 0.96055, 0.95322, 0.95335, 0.94579, 0.94594, 0.93297, 0.93311, 0.91699, 0.9172, 0.89174, 0.89202, 0.8541, 0.85446, 0.79823, 0.79902] + elif base_model_type in ["ti2v_2_2"]: + if inputs.get("image_start", None) is not None and inputs.get("video_source", None) is not None : # t2v + def_mag_ratios = [0.99505, 0.99389, 0.99441, 0.9957, 0.99558, 0.99551, 0.99499, 0.9945, 0.99534, 0.99548, 0.99468, 0.9946, 0.99463, 0.99458, 0.9946, 0.99453, 0.99408, 0.99404, 0.9945, 0.99441, 0.99409, 0.99398, 0.99403, 0.99397, 0.99382, 0.99377, 0.99349, 0.99343, 0.99377, 0.99378, 0.9933, 0.99328, 0.99303, 0.99301, 0.99217, 0.99216, 0.992, 0.99201, 0.99201, 0.99202, 0.99133, 0.99132, 0.99112, 0.9911, 0.99155, 0.99155, 0.98958, 0.98957, 0.98959, 0.98958, 0.98838, 0.98835, 0.98826, 0.98825, 0.9883, 0.98828, 0.98711, 0.98709, 0.98562, 0.98561, 0.98511, 0.9851, 0.98414, 0.98412, 0.98284, 0.98282, 0.98104, 0.98101, 0.97981, 0.97979, 0.97849, 0.97849, 0.97557, 0.97554, 0.97398, 0.97395, 0.97171, 0.97166, 0.96917, 0.96913, 0.96511, 0.96507, 0.96263, 0.96257, 0.95839, 0.95835, 0.95483, 0.95475, 0.94942, 0.94936, 0.9468, 0.94678, 0.94583, 0.94594, 0.94843, 0.94872, 0.96949, 0.97015] + else: # i2v + def_mag_ratios = [0.99512, 0.99559, 0.99559, 0.99561, 0.99595, 0.99577, 0.99512, 0.99512, 0.99546, 0.99534, 0.99543, 0.99531, 0.99496, 0.99491, 0.99504, 0.99499, 0.99444, 0.99449, 0.99481, 0.99481, 0.99435, 0.99435, 0.9943, 0.99431, 0.99411, 0.99406, 0.99373, 0.99376, 0.99413, 0.99405, 0.99363, 0.99359, 0.99335, 0.99331, 0.99244, 0.99243, 0.99229, 0.99229, 0.99239, 0.99236, 0.99163, 0.9916, 0.99149, 0.99151, 0.99191, 0.99192, 0.9898, 0.98981, 0.9899, 0.98987, 0.98849, 0.98849, 0.98846, 0.98846, 0.98861, 0.98861, 0.9874, 0.98738, 0.98588, 0.98589, 0.98539, 0.98534, 0.98444, 0.98439, 
0.9831, 0.98309, 0.98119, 0.98118, 0.98001, 0.98, 0.97862, 0.97859, 0.97555, 0.97558, 0.97392, 0.97388, 0.97152, 0.97145, 0.96871, 0.9687, 0.96435, 0.96434, 0.96129, 0.96127, 0.95639, 0.95638, 0.95176, 0.95175, 0.94446, 0.94452, 0.93972, 0.93974, 0.93575, 0.9359, 0.93537, 0.93552, 0.96655, 0.96616] + elif test_class_1_3B(base_model_type): #text 1.3B + def_mag_ratios = [1.0124, 1.02213, 1.00166, 1.0041, 0.99791, 1.00061, 0.99682, 0.99762, 0.99634, 0.99685, 0.99567, 0.99586, 0.99416, 0.99422, 0.99578, 0.99575, 0.9957, 0.99563, 0.99511, 0.99506, 0.99535, 0.99531, 0.99552, 0.99549, 0.99541, 0.99539, 0.9954, 0.99536, 0.99489, 0.99485, 0.99518, 0.99514, 0.99484, 0.99478, 0.99481, 0.99479, 0.99415, 0.99413, 0.99419, 0.99416, 0.99396, 0.99393, 0.99388, 0.99386, 0.99349, 0.99349, 0.99309, 0.99304, 0.9927, 0.9927, 0.99228, 0.99226, 0.99171, 0.9917, 0.99137, 0.99135, 0.99068, 0.99063, 0.99005, 0.99003, 0.98944, 0.98942, 0.98849, 0.98849, 0.98758, 0.98757, 0.98644, 0.98643, 0.98504, 0.98503, 0.9836, 0.98359, 0.98202, 0.98201, 0.97977, 0.97978, 0.97717, 0.97718, 0.9741, 0.97411, 0.97003, 0.97002, 0.96538, 0.96541, 0.9593, 0.95933, 0.95086, 0.95089, 0.94013, 0.94019, 0.92402, 0.92414, 0.90241, 0.9026, 0.86821, 0.86868, 0.81838, 0.81939]#**(0.5)# In our papaer, we utilize the sqrt to smooth the ratio, which has little impact on the performance and can be deleted. + elif i2v: + if pixels >= 1280*720: + def_mag_ratios = [0.99428, 0.99498, 0.98588, 0.98621, 0.98273, 0.98281, 0.99018, 0.99023, 0.98911, 0.98917, 0.98646, 0.98652, 0.99454, 0.99456, 0.9891, 0.98909, 0.99124, 0.99127, 0.99102, 0.99103, 0.99215, 0.99212, 0.99515, 0.99515, 0.99576, 0.99572, 0.99068, 0.99072, 0.99097, 0.99097, 0.99166, 0.99169, 0.99041, 0.99042, 0.99201, 0.99198, 0.99101, 0.99101, 0.98599, 0.98603, 0.98845, 0.98844, 0.98848, 0.98851, 0.98862, 0.98857, 0.98718, 0.98719, 0.98497, 0.98497, 0.98264, 0.98263, 0.98389, 0.98393, 0.97938, 0.9794, 0.97535, 0.97536, 0.97498, 0.97499, 0.973, 0.97301, 0.96827, 0.96828, 0.96261, 0.96263, 0.95335, 0.9534, 0.94649, 0.94655, 0.93397, 0.93414, 0.91636, 0.9165, 0.89088, 0.89109, 0.8679, 0.86768] + else: + def_mag_ratios = [0.98783, 0.98993, 0.97559, 0.97593, 0.98311, 0.98319, 0.98202, 0.98225, 0.9888, 0.98878, 0.98762, 0.98759, 0.98957, 0.98971, 0.99052, 0.99043, 0.99383, 0.99384, 0.98857, 0.9886, 0.99065, 0.99068, 0.98845, 0.98847, 0.99057, 0.99057, 0.98957, 0.98961, 0.98601, 0.9861, 0.98823, 0.98823, 0.98756, 0.98759, 0.98808, 0.98814, 0.98721, 0.98724, 0.98571, 0.98572, 0.98543, 0.98544, 0.98157, 0.98165, 0.98411, 0.98413, 0.97952, 0.97953, 0.98149, 0.9815, 0.9774, 0.97742, 0.97825, 0.97826, 0.97355, 0.97361, 0.97085, 0.97087, 0.97056, 0.97055, 0.96588, 0.96587, 0.96113, 0.96124, 0.9567, 0.95681, 0.94961, 0.94969, 0.93973, 0.93988, 0.93217, 0.93224, 0.91878, 0.91896, 0.90955, 0.90954, 0.92617, 0.92616] + else: # text 14B + def_mag_ratios = [1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 
0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189] + skip_steps_cache.def_mag_ratios = def_mag_ratios + else: + if i2v: + if pixels >= 1280*720: + coefficients= [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] + else: + coefficients= [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01] + else: + if test_class_1_3B(base_model_type): + coefficients= [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01] + else: + coefficients= [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] + skip_steps_cache.coefficients = coefficients + + @staticmethod + def get_wan_text_encoder_filename(text_encoder_quantization): + text_encoder_filename = "ckpts/umt5-xxl/models_t5_umt5-xxl-enc-bf16.safetensors" + if text_encoder_quantization =="int8": + text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_int8") + return text_encoder_filename + + @staticmethod + def query_model_def(base_model_type, model_def): + extra_model_def = {} + if "URLs2" in model_def: + extra_model_def["no_steps_skipping"] = True + i2v = test_class_i2v(base_model_type) + extra_model_def["i2v_class"] = i2v + extra_model_def["multitalk_class"] = test_multitalk(base_model_type) + extra_model_def["standin_class"] = test_standin(base_model_type) + vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"] + extra_model_def["vace_class"] = vace_class + + if test_multitalk(base_model_type): + fps = 25 + elif base_model_type in ["fantasy"]: + fps = 23 + elif base_model_type in ["ti2v_2_2"]: + fps = 24 + else: + fps = 16 + extra_model_def["fps"] =fps + multiple_submodels = "URLs2" in model_def + if vace_class: + frames_minimum, frames_steps = 17, 4 + else: + frames_minimum, frames_steps = 5, 4 + extra_model_def.update({ + "frames_minimum" : frames_minimum, + "frames_steps" : frames_steps, + "sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy"] or test_class_i2v(base_model_type) or vace_class, #"ti2v_2_2", + "multiple_submodels" : multiple_submodels, + "guidance_max_phases" : 3, + "skip_layer_guidance" : True, + "cfg_zero" : True, + "cfg_star" : True, + "adaptive_projected_guidance" : True, + "tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels), + "mag_cache" : True, + "keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"], + "sample_solvers":[ + ("unipc", "unipc"), + ("euler", "euler"), + ("dpm++", "dpm++"), + ("flowmatch causvid", "causvid"), ] + }) + if base_model_type in ["infinitetalk"]: + extra_model_def["no_background_removal"] = True + # extra_model_def["at_least_one_image_ref_needed"] = True + + return extra_model_def + + @staticmethod + def query_supported_types(): + return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B", + "t2v_1.3B", "standin", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B", + "recam_1.3B", + "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"] + + + @staticmethod + def query_family_maps(): + + models_eqv_map = { + "flf2v_720p" : "i2v", + "t2v_1.3B" : "t2v", + } + + models_comp_map = { + "vace_14B" : [ "vace_multitalk_14B", "vace_standin_14B"], + "t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B", 
"standin"], + "i2v" : [ "fantasy", "multitalk", "flf2v_720p" ], + "i2v_2_2" : ["i2v_2_2_multitalk"], + "fantasy": ["multitalk"], + } + return models_eqv_map, models_comp_map + + @staticmethod + def query_model_family(): + return "wan" + + @staticmethod + def query_family_infos(): + return {"wan":(0, "Wan2.1"), "wan2_2":(1, "Wan2.2") } + + @staticmethod + def get_vae_block_size(base_model_type): + return 32 if base_model_type == "ti2v_2_2" else 16 + + @staticmethod + def get_rgb_factors(base_model_type ): + from shared.RGB_factors import get_rgb_factors + latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type) + return latent_rgb_factors, latent_rgb_factors_bias + + @staticmethod + def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization): + text_encoder_filename = family_handler.get_wan_text_encoder_filename(text_encoder_quantization) + + download_def = [{ + "repoId" : "DeepBeepMeep/Wan2.1", + "sourceFolderList" : ["xlm-roberta-large", "umt5-xxl", "" ], + "fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(model_filename) ] + }] + + if base_model_type == "ti2v_2_2": + download_def += [ { + "repoId" : "DeepBeepMeep/Wan2.2", + "sourceFolderList" : [""], + "fileList" : [ [ "Wan2.2_VAE.safetensors" ] ] + }] + + return download_def + + + @staticmethod + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): + from .configs import WAN_CONFIGS + + if test_class_i2v(base_model_type): + cfg = WAN_CONFIGS['i2v-14B'] + else: + cfg = WAN_CONFIGS['t2v-14B'] + # cfg = WAN_CONFIGS['t2v-1.3B'] + from . 
import WanAny2V + wan_model = WanAny2V( + config=cfg, + checkpoint_dir="ckpts", + model_filename=model_filename, + model_type = model_type, + model_def = model_def, + base_model_type=base_model_type, + text_encoder_filename= family_handler.get_wan_text_encoder_filename(text_encoder_quantization), + quantizeTransformer = quantizeTransformer, + dtype = dtype, + VAE_dtype = VAE_dtype, + mixed_precision_transformer = mixed_precision_transformer, + save_quantized = save_quantized + ) + + pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model } + if hasattr(wan_model,"model2") and wan_model.model2 is not None: + pipe["transformer2"] = wan_model.model2 + if hasattr(wan_model, "clip"): + pipe["text_encoder_2"] = wan_model.clip.model + return wan_model, pipe + + @staticmethod + def fix_settings(base_model_type, settings_version, model_def, ui_defaults): + if ui_defaults.get("sample_solver", "") == "": + ui_defaults["sample_solver"] = "unipc" + + if settings_version < 2.24: + if (model_def.get("multiple_submodels", False) or ui_defaults.get("switch_threshold", 0) > 0) and ui_defaults.get("guidance_phases",0)<2: + ui_defaults["guidance_phases"] = 2 + + if settings_version == 2.24 and ui_defaults.get("guidance_phases",0) ==2: + mult = model_def.get("loras_multipliers","") + if len(mult)> 1 and len(mult[0].split(";"))==3: ui_defaults["guidance_phases"] = 3 + + if settings_version < 2.27: + if base_model_type in "infinitetalk": + guidance_scale = ui_defaults.get("guidance_scale", None) + if guidance_scale == 1: + ui_defaults["audio_guidance_scale"]= 1 + video_prompt_type = ui_defaults.get("video_prompt_type", "") + if "I" in video_prompt_type: + video_prompt_type = video_prompt_type.replace("KI", "QKI") + ui_defaults["video_prompt_type"] = video_prompt_type + @staticmethod + def update_default_settings(base_model_type, model_def, ui_defaults): + ui_defaults.update({ + "sample_solver": "unipc", + }) + if base_model_type in ["fantasy"]: + ui_defaults.update({ + "audio_guidance_scale": 5.0, + "sliding_window_size": 1, + }) + + elif base_model_type in ["multitalk"]: + ui_defaults.update({ + "guidance_scale": 5.0, + "flow_shift": 7, # 11 for 720p + "sliding_window_discard_last_frames" : 4, + "sample_solver" : "euler", + "adaptive_switch" : 1, + }) + + elif base_model_type in ["infinitetalk"]: + ui_defaults.update({ + "guidance_scale": 5.0, + "flow_shift": 7, # 11 for 720p + "sliding_window_overlap" : 9, + "sample_solver" : "euler", + "video_prompt_type": "QKI", + "remove_background_images_ref" : 0, + "adaptive_switch" : 1, + }) + + elif base_model_type in ["standin"]: + ui_defaults.update({ + "guidance_scale": 5.0, + "flow_shift": 7, # 11 for 720p + "sliding_window_overlap" : 9, + "video_prompt_type": "I", + "remove_background_images_ref" : 1, + }) + elif base_model_type in ["phantom_1.3B", "phantom_14B"]: + ui_defaults.update({ + "guidance_scale": 7.5, + "flow_shift": 5, + "remove_background_images_ref": 1, + "video_prompt_type": "I", + # "resolution": "1280x720" + }) + + elif base_model_type in ["vace_14B", "vace_multitalk_14B"]: + ui_defaults.update({ + "sliding_window_discard_last_frames": 0, + }) + + elif base_model_type in ["ti2v_2_2"]: + ui_defaults.update({ + "image_prompt_type": "T", + }) + + if test_multitalk(base_model_type): + ui_defaults["audio_guidance_scale"] = 4 + + if model_def.get("multiple_submodels", False): + ui_defaults["guidance_phases"] = 2 + + @staticmethod + def validate_generative_settings(base_model_type, model_def, 
inputs): + if base_model_type in ["infinitetalk"]: + video_source = inputs["video_source"] + image_refs = inputs["image_refs"] + video_prompt_type = inputs["video_prompt_type"] + image_prompt_type = inputs["image_prompt_type"] + if ("V" in image_prompt_type or "L" in image_prompt_type) and image_refs is None: + video_prompt_type = video_prompt_type.replace("I", "").replace("K","") + inputs["video_prompt_type"] = video_prompt_type diff --git a/postprocessing/mmaudio/data/av_utils.py b/postprocessing/mmaudio/data/av_utils.py index 6fd0b1d..7866d6c 100644 --- a/postprocessing/mmaudio/data/av_utils.py +++ b/postprocessing/mmaudio/data/av_utils.py @@ -131,12 +131,14 @@ from pathlib import Path import torch def remux_with_audio(video_path: Path, output_path: Path, audio: torch.Tensor, sampling_rate: int): - from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files + from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files + with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f: temp_path = Path(f.name) temp_path_str= str(temp_path) import torchaudio torchaudio.save(temp_path_str, audio.unsqueeze(0) if audio.dim() == 1 else audio, sampling_rate) + combine_video_with_audio_tracks(video_path, [temp_path_str], output_path ) temp_path.unlink(missing_ok=True) diff --git a/postprocessing/mmaudio/mmaudio.py b/postprocessing/mmaudio/mmaudio.py index e153b09..c4f8ce6 100644 --- a/postprocessing/mmaudio/mmaudio.py +++ b/postprocessing/mmaudio/mmaudio.py @@ -76,7 +76,7 @@ def get_model(persistent_models = False, verboseLevel = 1) -> tuple[MMAudio, Fea @torch.inference_mode() def video_to_audio(video, prompt: str, negative_prompt: str, seed: int, num_steps: int, - cfg_strength: float, duration: float, video_save_path , persistent_models = False, verboseLevel = 1): + cfg_strength: float, duration: float, save_path , persistent_models = False, audio_file_only = False, verboseLevel = 1): global device @@ -110,11 +110,17 @@ def video_to_audio(video, prompt: str, negative_prompt: str, seed: int, num_step ) audio = audios.float().cpu()[0] - make_video(video, video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate) + + if audio_file_only: + import torchaudio + torchaudio.save(save_path, audio.unsqueeze(0) if audio.dim() == 1 else audio, seq_cfg.sampling_rate) + else: + make_video(video, video_info, save_path, audio, sampling_rate=seq_cfg.sampling_rate) + offloadobj.unload_all() if not persistent_models: offloadobj.release() torch.cuda.empty_cache() gc.collect() - return video_save_path + return save_path diff --git a/preprocessing/canny.py b/preprocessing/canny.py new file mode 100644 index 0000000..df89dde --- /dev/null +++ b/preprocessing/canny.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from PIL import Image + + +norm_layer = nn.InstanceNorm2d + +def convert_to_torch(image): + if isinstance(image, Image.Image): + image = torch.from_numpy(np.array(image)).float() + elif isinstance(image, torch.Tensor): + image = image.clone() + elif isinstance(image, np.ndarray): + image = torch.from_numpy(image.copy()).float() + else: + raise f'Unsurpport datatype{type(image)}, only surpport np.ndarray, torch.Tensor, Pillow Image.' 
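+ # PIL images and numpy arrays come back as float32 tensors in HWC layout with 0-255 values;
+ # torch tensors are returned as an untouched clone.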
+ return image + +class ResidualBlock(nn.Module): + def __init__(self, in_features): + super(ResidualBlock, self).__init__() + + conv_block = [ + nn.ReflectionPad2d(1), + nn.Conv2d(in_features, in_features, 3), + norm_layer(in_features), + nn.ReLU(inplace=True), + nn.ReflectionPad2d(1), + nn.Conv2d(in_features, in_features, 3), + norm_layer(in_features) + ] + + self.conv_block = nn.Sequential(*conv_block) + + def forward(self, x): + return x + self.conv_block(x) + + +class ContourInference(nn.Module): + def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True): + super(ContourInference, self).__init__() + + # Initial convolution block + model0 = [ + nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, 64, 7), + norm_layer(64), + nn.ReLU(inplace=True) + ] + self.model0 = nn.Sequential(*model0) + + # Downsampling + model1 = [] + in_features = 64 + out_features = in_features * 2 + for _ in range(2): + model1 += [ + nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), + norm_layer(out_features), + nn.ReLU(inplace=True) + ] + in_features = out_features + out_features = in_features * 2 + self.model1 = nn.Sequential(*model1) + + model2 = [] + # Residual blocks + for _ in range(n_residual_blocks): + model2 += [ResidualBlock(in_features)] + self.model2 = nn.Sequential(*model2) + + # Upsampling + model3 = [] + out_features = in_features // 2 + for _ in range(2): + model3 += [ + nn.ConvTranspose2d(in_features, + out_features, + 3, + stride=2, + padding=1, + output_padding=1), + norm_layer(out_features), + nn.ReLU(inplace=True) + ] + in_features = out_features + out_features = in_features // 2 + self.model3 = nn.Sequential(*model3) + + # Output layer + model4 = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)] + if sigmoid: + model4 += [nn.Sigmoid()] + + self.model4 = nn.Sequential(*model4) + + def forward(self, x, cond=None): + out = self.model0(x) + out = self.model1(out) + out = self.model2(out) + out = self.model3(out) + out = self.model4(out) + + return out + + +class CannyAnnotator: + def __init__(self, cfg, device=None): + input_nc = cfg.get('INPUT_NC', 3) + output_nc = cfg.get('OUTPUT_NC', 1) + n_residual_blocks = cfg.get('N_RESIDUAL_BLOCKS', 3) + sigmoid = cfg.get('SIGMOID', True) + pretrained_model = cfg['PRETRAINED_MODEL'] + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device + self.model = ContourInference(input_nc, output_nc, n_residual_blocks, + sigmoid) + self.model.load_state_dict(torch.load(pretrained_model, weights_only=True)) + self.model = self.model.eval().requires_grad_(False).to(self.device) + + @torch.no_grad() + @torch.inference_mode() + @torch.autocast('cuda', enabled=False) + def forward(self, image): + is_batch = False if len(image.shape) == 3 else True + image = convert_to_torch(image) + if len(image.shape) == 3: + image = rearrange(image, 'h w c -> 1 c h w') + image = image.float().div(255).to(self.device) + contour_map = self.model(image) + contour_map = (contour_map.squeeze(dim=1) * 255.0).clip( + 0, 255).cpu().numpy().astype(np.uint8) + contour_map = contour_map[..., None].repeat(3, -1) + contour_map = 255 - contour_map #.where( image >= 127.5,0,1) + contour_map[ contour_map > 127.5] = 255 + contour_map[ contour_map <= 127.5] = 0 + if not is_batch: + contour_map = contour_map.squeeze() + return contour_map + + +class CannyVideoAnnotator(CannyAnnotator): + def forward(self, frames): + ret_frames = [] + for frame in frames: + anno_frame = super().forward(np.array(frame)) + 
ret_frames.append(anno_frame) + return ret_frames \ No newline at end of file diff --git a/preprocessing/face_preprocessor.py b/preprocessing/face_preprocessor.py new file mode 100644 index 0000000..bdef48b --- /dev/null +++ b/preprocessing/face_preprocessor.py @@ -0,0 +1,259 @@ +import os +import cv2 +import requests +import torch +import numpy as np +import PIL.Image as Image +import PIL.ImageOps +# from insightface.app import FaceAnalysis +# from facexlib.parsing import init_parsing_model +from torchvision.transforms.functional import normalize +from typing import Union, Optional +from models.hyvideo.data_kits.face_align import AlignImage + + +def _img2tensor(img: np.ndarray, bgr2rgb: bool = True) -> torch.Tensor: + if bgr2rgb: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = img.astype(np.float32) / 255.0 + img = np.transpose(img, (2, 0, 1)) + return torch.from_numpy(img) + + +def _pad_to_square(img: np.ndarray, pad_color: int = 255) -> np.ndarray: + h, w, _ = img.shape + if h == w: + return img + + if h > w: + pad_size = (h - w) // 2 + padded_img = cv2.copyMakeBorder( + img, + 0, + 0, + pad_size, + h - w - pad_size, + cv2.BORDER_CONSTANT, + value=[pad_color] * 3, + ) + else: + pad_size = (w - h) // 2 + padded_img = cv2.copyMakeBorder( + img, + pad_size, + w - h - pad_size, + 0, + 0, + cv2.BORDER_CONSTANT, + value=[pad_color] * 3, + ) + + return padded_img + + +class FaceProcessor: + def __init__(self): + self.align_instance = AlignImage("cuda", det_path="ckpts/det_align/detface.pt") + self.align_instance.facedet.model.to("cpu") + + + def process( + self, + image: Union[str, PIL.Image.Image], + resize_to: int = 512, + border_thresh: int = 10, + face_crop_scale: float = 1.5, + remove_bg= False, + # area=1.25 + ) -> PIL.Image.Image: + + image_pil = PIL.ImageOps.exif_transpose(image).convert("RGB") + w, h = image_pil.size + self.align_instance.facedet.model.to("cuda") + _, _, bboxes_list = self.align_instance(np.array(image_pil)[:,:,[2,1,0]], maxface=True) + self.align_instance.facedet.model.to("cpu") + + try: + bboxSrc = bboxes_list[0] + except: + bboxSrc = [0, 0, w, h] + x1, y1, ww, hh = bboxSrc + x2, y2 = x1 + ww, y1 + hh + # ww, hh = (x2-x1) * area, (y2-y1) * area + # center = [(x2+x1)//2, (y2+y1)//2] + # x1 = max(center[0] - ww//2, 0) + # y1 = max(center[1] - hh//2, 0) + # x2 = min(center[0] + ww//2, w) + # y2 = min(center[1] + hh//2, h) + + frame = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR) + h, w, _ = frame.shape + image_to_process = None + + is_close_to_border = ( + x1 <= border_thresh + and y1 <= border_thresh + and x2 >= w - border_thresh + and y2 >= h - border_thresh + ) + + if is_close_to_border: + # print( + # "[Info] Face is close to border, padding original image to square." 
+ # ) + image_to_process = _pad_to_square(frame, pad_color=255) + else: + cx, cy = (x1 + x2) // 2, (y1 + y2) // 2 + side = int(max(x2 - x1, y2 - y1) * face_crop_scale) + half = side // 2 + + left = int(max(cx - half, 0)) + top = int(max(cy - half, 0)) + right = int(min(cx + half, w)) + bottom = int(min(cy + half, h)) + + cropped_face = frame[top:bottom, left:right] + image_to_process = _pad_to_square(cropped_face, pad_color=255) + + image_resized = cv2.resize( + image_to_process, (resize_to, resize_to), interpolation=cv2.INTER_LANCZOS4 # .INTER_AREA + ) + + face_tensor = _img2tensor(image_resized).to("cpu") + + from shared.utils.utils import remove_background, convert_tensor_to_image + if remove_bg: + face_tensor = remove_background(face_tensor) + img_out = Image.fromarray(face_tensor.clone().mul_(255).permute(1,2,0).to(torch.uint8).cpu().numpy()) + return img_out + + +# class FaceProcessor2: +# def __init__(self, antelopv2_path=".", device: Optional[torch.device] = None): +# if device is None: +# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# else: +# self.device = device + +# providers = ( +# ["CUDAExecutionProvider"] +# if self.device.type == "cuda" +# else ["CPUExecutionProvider"] +# ) +# self.app = FaceAnalysis( +# name="antelopev2", root=antelopv2_path, providers=providers +# ) +# self.app.prepare(ctx_id=0, det_size=(640, 640)) + +# self.parsing_model = init_parsing_model( +# model_name="bisenet", device=self.device +# ) +# self.parsing_model.eval() + +# print("FaceProcessor initialized successfully.") + +# def process( +# self, +# image: Union[str, PIL.Image.Image], +# resize_to: int = 512, +# border_thresh: int = 10, +# face_crop_scale: float = 1.5, +# extra_input: bool = False, +# ) -> PIL.Image.Image: +# if isinstance(image, str): +# if image.startswith("http://") or image.startswith("https://"): +# image = PIL.Image.open(requests.get(image, stream=True, timeout=10).raw) +# elif os.path.isfile(image): +# image = PIL.Image.open(image) +# else: +# raise ValueError( +# f"Input string is not a valid URL or file path: {image}" +# ) +# elif not isinstance(image, PIL.Image.Image): +# raise TypeError( +# "Input must be a file path, a URL, or a PIL.Image.Image object." +# ) + +# image = PIL.ImageOps.exif_transpose(image).convert("RGB") + +# frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + +# faces = self.app.get(frame) +# h, w, _ = frame.shape +# image_to_process = None + +# if not faces: +# print( +# "[Warning] No face detected. Using the whole image, padded to square." +# ) +# image_to_process = _pad_to_square(frame, pad_color=255) +# else: +# largest_face = max( +# faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]) +# ) +# x1, y1, x2, y2 = map(int, largest_face.bbox) + +# is_close_to_border = ( +# x1 <= border_thresh +# and y1 <= border_thresh +# and x2 >= w - border_thresh +# and y2 >= h - border_thresh +# ) + +# if is_close_to_border: +# print( +# "[Info] Face is close to border, padding original image to square." 
+# ) +# image_to_process = _pad_to_square(frame, pad_color=255) +# else: +# cx, cy = (x1 + x2) // 2, (y1 + y2) // 2 +# side = int(max(x2 - x1, y2 - y1) * face_crop_scale) +# half = side // 2 + +# left = max(cx - half, 0) +# top = max(cy - half, 0) +# right = min(cx + half, w) +# bottom = min(cy + half, h) + +# cropped_face = frame[top:bottom, left:right] +# image_to_process = _pad_to_square(cropped_face, pad_color=255) + +# image_resized = cv2.resize( +# image_to_process, (resize_to, resize_to), interpolation=cv2.INTER_AREA +# ) + +# face_tensor = ( +# _img2tensor(image_resized, bgr2rgb=True).unsqueeze(0).to(self.device) +# ) +# with torch.no_grad(): +# normalized_face = normalize(face_tensor, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) +# parsing_out = self.parsing_model(normalized_face)[0] +# parsing_mask = parsing_out.argmax(dim=1, keepdim=True) + +# background_mask_np = (parsing_mask.squeeze().cpu().numpy() == 0).astype( +# np.uint8 +# ) +# white_background = np.ones_like(image_resized, dtype=np.uint8) * 255 +# mask_3channel = cv2.cvtColor(background_mask_np * 255, cv2.COLOR_GRAY2BGR) +# result_img_bgr = np.where(mask_3channel == 255, white_background, image_resized) +# result_img_rgb = cv2.cvtColor(result_img_bgr, cv2.COLOR_BGR2RGB) +# img_white_bg = PIL.Image.fromarray(result_img_rgb) +# if extra_input: +# # 2. Create image with transparent background (new logic) +# # Create an alpha channel: 255 for foreground (not background), 0 for background +# alpha_channel = (parsing_mask.squeeze().cpu().numpy() != 0).astype( +# np.uint8 +# ) * 255 + +# # Convert the resized BGR image to RGB +# image_resized_rgb = cv2.cvtColor(image_resized, cv2.COLOR_BGR2RGB) + +# # Stack RGB channels with the new alpha channel +# rgba_image = np.dstack((image_resized_rgb, alpha_channel)) + +# # Create PIL image from the RGBA numpy array +# img_transparent_bg = PIL.Image.fromarray(rgba_image, "RGBA") + +# return img_white_bg, img_transparent_bg +# else: +# return img_white_bg diff --git a/preprocessing/matanyone/app.py b/preprocessing/matanyone/app.py index c8ea190..151d9be 100644 --- a/preprocessing/matanyone/app.py +++ b/preprocessing/matanyone/app.py @@ -10,16 +10,18 @@ from PIL import Image import cv2 import torch +import torch.nn.functional as F import numpy as np import gradio as gr from .tools.painter import mask_painter from .tools.interact_tools import SamControler from .tools.misc import get_device from .tools.download_util import load_file_from_url - +from segment_anything.modeling.image_encoder import window_partition, window_unpartition, get_rel_pos, Block as image_encoder_block from .utils.get_default_model import get_matanyone_model from .matanyone.inference.inference_core import InferenceCore from .matanyone_wrapper import matanyone +from shared.utils.audio_video import save_video, save_image arg_device = "cuda" arg_sam_model_type="vit_h" @@ -27,7 +29,9 @@ arg_mask_save = False model_loaded = False model = None matanyone_model = None - +model_in_GPU = False +matanyone_in_GPU = False +bfloat16_supported = False # SAM generator class MaskGenerator(): def __init__(self, sam_checkpoint, device): @@ -65,7 +69,10 @@ def get_frames_from_image(image_input, image_state): Return [[0:nearest_frame], [nearest_frame:], nearest_frame] """ - load_sam() + + if image_input is None: + gr.Info("Please select an Image file") + return [gr.update()] * 17 user_name = time.time() frames = [image_input] * 2 # hardcode: mimic a video with 2 frames @@ -83,17 +90,21 @@ def get_frames_from_image(image_input, image_state): "fps": 
None } image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size) + set_image_encoder_patch() + select_SAM() model.samcontroler.sam_controler.reset_image() model.samcontroler.sam_controler.set_image(image_state["origin_images"][0]) + torch.cuda.empty_cache() return image_state, image_info, image_state["origin_images"][0], \ gr.update(visible=True, maximum=10, value=10), gr.update(visible=False, maximum=len(frames), value=len(frames)), \ gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True),\ - gr.update(visible=True), gr.update(visible=True), \ - gr.update(visible=True), gr.update(value="", visible=True), gr.update(visible=False), \ + gr.update(visible=True), gr.update(visible=False), \ + gr.update(visible=False), gr.update(value="", visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=True), \ gr.update(visible=True) + # extract frames from upload video def get_frames_from_video(video_input, video_state): """ @@ -103,8 +114,9 @@ def get_frames_from_video(video_input, video_state): Return [[0:nearest_frame], [nearest_frame:], nearest_frame] """ - - load_sam() + if video_input is None: + gr.Info("Please select a Video file") + return [gr.update()] * 18 while model == None: time.sleep(1) @@ -163,8 +175,11 @@ def get_frames_from_video(video_input, video_state): "audio": audio_path } video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size) + set_image_encoder_patch() + select_SAM() model.samcontroler.sam_controler.reset_image() model.samcontroler.sam_controler.set_image(video_state["origin_images"][0]) + torch.cuda.empty_cache() return video_state, video_info, video_state["origin_images"][0], \ gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), gr.update(visible=False, maximum=len(frames), value=len(frames)), \ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \ @@ -203,6 +218,98 @@ def get_end_number(track_pause_number_slider, video_state, interactive_state): return video_state["painted_images"][track_pause_number_slider],interactive_state + +def patched_forward(self, x: torch.Tensor) -> torch.Tensor: + def split_mlp(mlp, x, divide = 4): + x_shape = x.shape + x = x.view(-1, x.shape[-1]) + chunk_size = int(x.shape[0]/divide) + x_chunks = torch.split(x, chunk_size) + for i, x_chunk in enumerate(x_chunks): + mlp_chunk = mlp.lin1(x_chunk) + mlp_chunk = mlp.act(mlp_chunk) + x_chunk[...] 
= mlp.lin2(mlp_chunk) + return x.reshape(x_shape) + + def get_decomposed_rel_pos( q, rel_pos_h, rel_pos_w, q_size, k_size) -> torch.Tensor: + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + attn = torch.zeros(B, q_h, q_w, k_h, k_w, dtype=q.dtype, device=q.device) + attn += rel_h[:, :, :, :, None] + attn += rel_w[:, :, :, None, :] + return attn.view(B, q_h * q_w, k_h * k_w) + + def pay_attention(self, x: torch.Tensor, split_heads = 1) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + + if not bfloat16_supported: qkv = qkv.to(torch.float16) + + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + if split_heads == 1: + attn_mask = None + if self.use_rel_pos: + attn_mask = get_decomposed_rel_pos(q, self.rel_pos_h.to(q), self.rel_pos_w.to(q), (H, W), (H, W)) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=self.scale) + else: + chunk_size = self.num_heads // split_heads + x = torch.empty_like(q) + q_chunks = torch.split(q, chunk_size) + k_chunks = torch.split(k, chunk_size) + v_chunks = torch.split(v, chunk_size) + x_chunks = torch.split(x, chunk_size) + for x_chunk, q_chunk, k_chunk, v_chunk in zip(x_chunks, q_chunks, k_chunks, v_chunks): + attn_mask = None + if self.use_rel_pos: + attn_mask = get_decomposed_rel_pos(q_chunk, self.rel_pos_h.to(q), self.rel_pos_w.to(q), (H, W), (H, W)) + x_chunk[...] = F.scaled_dot_product_attention(q_chunk, k_chunk, v_chunk, attn_mask=attn_mask, scale=self.scale) + del x_chunk, q_chunk, k_chunk, v_chunk + del q, k, v, attn_mask + x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + if not bfloat16_supported: x = x.to(torch.bfloat16) + + return self.proj(x) + + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + x_shape = x.shape + + if x_shape[0] > 10: + chunk_size = int(x.shape[0]/4) + 1 + x_chunks = torch.split(x, chunk_size) + for i, x_chunk in enumerate(x_chunks): + x_chunk[...] = pay_attention(self.attn,x_chunk) + else: + x = pay_attention(self.attn,x, 4) + + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + x += shortcut + shortcut[...] 
= self.norm2(x) + # x += self.mlp(shortcut) + x += split_mlp(self.mlp, shortcut) + + return x + +def set_image_encoder_patch(): + if not hasattr(image_encoder_block, "patched"): #and False + image_encoder_block.forward = patched_forward + image_encoder_block.patched = True + # use sam to get the mask def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData ): # """ @@ -217,10 +324,13 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr else: coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1]) interactive_state["negative_click_times"] += 1 - + + select_SAM() # prompt for sam model + set_image_encoder_patch() model.samcontroler.sam_controler.reset_image() model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]]) + torch.cuda.empty_cache() prompt = get_prompt(click_state=click_state, click_input=coordinate) mask, logit, painted_image = model.first_frame_click( @@ -233,6 +343,7 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr video_state["logits"][video_state["select_frame_number"]] = logit video_state["painted_images"][video_state["select_frame_number"]] = painted_image + torch.cuda.empty_cache() return painted_image, video_state, interactive_state def add_multi_mask(video_state, interactive_state, mask_dropdown): @@ -267,17 +378,18 @@ def show_mask(video_state, interactive_state, mask_dropdown): return select_frame -def save_video(frames, output_path, fps): +# def save_video(frames, output_path, fps): - writer = imageio.get_writer( output_path, fps=fps, codec='libx264', quality=8) - for frame in frames: - writer.append_data(frame) - writer.close() +# writer = imageio.get_writer( output_path, fps=fps, codec='libx264', quality=8) +# for frame in frames: +# writer.append_data(frame) +# writer.close() - return output_path +# return output_path def mask_to_xyxy_box(mask): rows, cols = np.where(mask == 255) + if len(rows) == 0 or len(cols) == 0: return [] xmin = min(cols) xmax = max(cols) + 1 ymin = min(rows) @@ -313,7 +425,9 @@ def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_si # operation error if len(np.unique(template_mask))==1: template_mask[0][0]=1 + select_matanyone() foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size, n_warmup=refine_iter) + torch.cuda.empty_cache() foreground_mat = False @@ -344,13 +458,18 @@ def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_si bbox_info = mask_to_xyxy_box(alpha_output) h = alpha_output.shape[0] w = alpha_output.shape[1] - bbox_info = [str(int(bbox_info[0]/ w * 100 )), str(int(bbox_info[1]/ h * 100 )), str(int(bbox_info[2]/ w * 100 )), str(int(bbox_info[3]/ h * 100 )) ] - bbox_info = ":".join(bbox_info) + if len(bbox_info) == 0: + bbox_info = "" + else: + bbox_info = [str(int(bbox_info[0]/ w * 100 )), str(int(bbox_info[1]/ h * 100 )), str(int(bbox_info[2]/ w * 100 )), str(int(bbox_info[3]/ h * 100 )) ] + bbox_info = ":".join(bbox_info) alpha_output = Image.fromarray(alpha_output) - return foreground_output, alpha_output, bbox_info, gr.update(visible=True), gr.update(visible=True) + # return gr.update(value=foreground_output, visible= True), gr.update(value=alpha_output, visible= True), gr.update(value=bbox_info, visible= True), gr.update(visible=True), gr.update(visible=True) + + return foreground_output, alpha_output, gr.update(visible = True), 
gr.update(visible = True), gr.update(value=bbox_info, visible= True), gr.update(visible=True), gr.update(visible=True) # video matting -def video_matting(video_state, end_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size): +def video_matting(video_state,video_input, end_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size): matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg) # if interactive_state["track_end_number"]: # following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] @@ -376,17 +495,25 @@ def video_matting(video_state, end_slider, matting_type, interactive_state, mask # operation error if len(np.unique(template_mask))==1: template_mask[0][0]=1 + select_matanyone() foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size) + torch.cuda.empty_cache() output_frames = [] foreground_mat = matting_type == "Foreground" + new_alpha = [] if not foreground_mat: - new_alpha = [] for frame_alpha in alpha: frame_temp = frame_alpha.copy() frame_alpha[frame_temp > 127] = 0 frame_alpha[frame_temp <= 127] = 255 new_alpha.append(frame_alpha) - alpha = new_alpha + else: + for frame_alpha in alpha: + frame_alpha[frame_alpha > 127] = 255 + frame_alpha[frame_alpha <= 127] = 0 + new_alpha.append(frame_alpha) + alpha = new_alpha + # for frame_origin, frame_alpha in zip(following_frames, alpha): # if foreground_mat: # frame_alpha[frame_alpha > 127] = 255 @@ -408,10 +535,21 @@ def video_matting(video_state, end_slider, matting_type, interactive_state, mask file_name= video_state["video_name"] file_name = ".".join(file_name.split(".")[:-1]) - foreground_output = save_video(foreground, output_path="./mask_outputs/{}_fg.mp4".format(file_name), fps=fps) - # foreground_output = generate_video_from_frames(foreground, output_path="./results/{}_fg.mp4".format(video_state["video_name"]), fps=fps, audio_path=audio_path) # import video_input to name the output video - alpha_output = save_video(alpha, output_path="./mask_outputs/{}_alpha.mp4".format(file_name), fps=fps) - # alpha_output = generate_video_from_frames(alpha, output_path="./results/{}_alpha.mp4".format(video_state["video_name"]), fps=fps, gray2rgb=True, audio_path=audio_path) # import video_input to name the output video + + from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files + source_audio_tracks, audio_metadata = extract_audio_tracks(video_input) + output_fg_path = f"./mask_outputs/{file_name}_fg.mp4" + output_fg_temp_path = f"./mask_outputs/{file_name}_fg_tmp.mp4" + if len(source_audio_tracks) == 0: + foreground_output = save_video(foreground,output_fg_path , fps=fps, codec_type= video_output_codec) + else: + foreground_output_tmp = save_video(foreground, output_fg_temp_path , fps=fps, codec_type= video_output_codec) + combine_video_with_audio_tracks(output_fg_temp_path, source_audio_tracks, output_fg_path, audio_metadata=audio_metadata) + cleanup_temp_audio_files(source_audio_tracks) + os.remove(foreground_output_tmp) + foreground_output = output_fg_path + + alpha_output = save_video(alpha, "./mask_outputs/{}_alpha.mp4".format(file_name), fps=fps, codec_type= video_output_codec) return foreground_output, alpha_output, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True) @@ -494,21 
+632,42 @@ def restart(): gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False, choices=[], value=[]), "", gr.update(visible=False) -def load_sam(): - global model_loaded - global model - global matanyone_model - model.samcontroler.sam_controler.model.to(arg_device) +# def load_sam(): +# global model_loaded +# global model +# model.samcontroler.sam_controler.model.to(arg_device) + +# global matanyone_model +# matanyone_model.to(arg_device) + + +def select_matanyone(): + global matanyone_in_GPU, model_in_GPU + if matanyone_in_GPU: return + model.samcontroler.sam_controler.model.to("cpu") + model_in_GPU = False + torch.cuda.empty_cache() matanyone_model.to(arg_device) + matanyone_in_GPU = True + +def select_SAM(): + global matanyone_in_GPU, model_in_GPU + if model_in_GPU: return + matanyone_model.to("cpu") + matanyone_in_GPU = False + torch.cuda.empty_cache() + model.samcontroler.sam_controler.model.to(arg_device) + model_in_GPU = True def load_unload_models(selected): global model_loaded global model - global matanyone_model + global matanyone_model, matanyone_processor, matanyone_in_GPU , model_in_GPU, bfloat16_supported if selected: # print("Matanyone Tab Selected") if model_loaded: - load_sam() + pass + # load_sam() else: # args, defined in track_anything.py sam_checkpoint_url_dict = { @@ -526,21 +685,33 @@ def load_unload_models(selected): transfer_stream = torch.cuda.Stream() with torch.cuda.stream(transfer_stream): # initialize sams - model = MaskGenerator(sam_checkpoint, arg_device) + major, minor = torch.cuda.get_device_capability(arg_device) + if major < 8: + bfloat16_supported = False + else: + bfloat16_supported = True + + model = MaskGenerator(sam_checkpoint, "cpu") + model.samcontroler.sam_controler.model.to("cpu").to(torch.bfloat16).to(arg_device) + model_in_GPU = True from .matanyone.model.matanyone import MatAnyone matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone") # pipe ={"mat" : matanyone_model, "sam" :model.samcontroler.sam_controler.model } # offload.profile(pipe) - matanyone_model = matanyone_model.to(arg_device).eval() + matanyone_model = matanyone_model.to("cpu").eval() + matanyone_in_GPU = False matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg) model_loaded = True else: # print("Matanyone Tab UnSelected") import gc - model.samcontroler.sam_controler.model.to("cpu") - matanyone_model.to("cpu") + # model.samcontroler.sam_controler.model.to("cpu") + # matanyone_model.to("cpu") + model = matanyone_model = matanyone_processor = None + matanyone_in_GPU = model_in_GPU = False gc.collect() torch.cuda.empty_cache() + model_loaded = False def get_vmc_event_handler(): @@ -563,13 +734,10 @@ def export_image_mask(image_input, image_mask): return Image.fromarray(image_input), image_mask -def export_to_current_video_engine(model_type, foreground_video_output, alpha_video_output): +def export_to_current_video_engine( foreground_video_output, alpha_video_output): gr.Info("Original Video and Full Mask have been transferred") # return "MV#" + str(time.time()), foreground_video_output, alpha_video_output - if "custom_edit" in model_type and False: - return gr.update(), alpha_video_output - else: - return foreground_video_output, alpha_video_output + return foreground_video_output, alpha_video_output def teleport_to_video_tab(tab_state): @@ -578,17 +746,22 @@ def teleport_to_video_tab(tab_state): return gr.Tabs(selected="video_gen") -def 
display(tabs, tab_state, model_choice, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs): +def display(tabs, tab_state, server_config, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs): # my_tab.select(fn=load_unload_models, inputs=[], outputs=[]) + global image_output_codec, video_output_codec + + image_output_codec = server_config.get("image_output_codec", None) + video_output_codec = server_config.get("video_output_codec", None) media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/" # download assets - gr.Markdown("Mast Edition is provided by MatAnyone") + gr.Markdown("Mast Edition is provided by MatAnyone and VRAM optimized by DeepBeepMeep") gr.Markdown("If you have some trouble creating the perfect mask, be aware of these tips:") gr.Markdown("- Using the Matanyone Settings you can also define Negative Point Prompts to remove parts of the current selection.") gr.Markdown("- Sometime it is very hard to fit everything you want in a single mask, it may be much easier to combine multiple independent sub Masks before producing the Matting : each sub Mask is created by selecting an area of an image and by clicking the Add Mask button. Sub masks can then be enabled / disabled in the Matanyone settings.") + gr.Markdown("The Mask Generation time and the VRAM consumed are proportional to the number of frames and the resolution. So if relevant, you may reduce the number of frames in the Matanyone Settings. You will need for the moment to resize yourself the video if needed.") with gr.Column( visible=True): with gr.Row(): @@ -719,7 +892,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_image_input, v with gr.Row(visible= True): export_to_current_video_engine_btn = gr.Button("Export to Control Video Input and Video Mask Input", visible= False) - export_to_current_video_engine_btn.click( fn=export_to_current_video_engine, inputs= [model_choice, foreground_video_output, alpha_video_output], outputs= [vace_video_input, vace_video_mask]).then( #video_prompt_video_guide_trigger, + export_to_current_video_engine_btn.click( fn=export_to_current_video_engine, inputs= [foreground_video_output, alpha_video_output], outputs= [vace_video_input, vace_video_mask]).then( #video_prompt_video_guide_trigger, fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs]) @@ -768,7 +941,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_image_input, v inputs=[], outputs=[foreground_video_output, alpha_video_output]).then( fn=video_matting, - inputs=[video_state, end_selection_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size], + inputs=[video_state, video_input, end_selection_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size], outputs=[foreground_video_output, alpha_video_output,foreground_video_output, alpha_video_output, export_to_vace_video_14B_btn, export_to_current_video_engine_btn] ) @@ -909,7 +1082,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_image_input, v foreground_image_output = gr.Image(type="pil", label="Foreground Output", visible=False, elem_classes="image") alpha_image_output = gr.Image(type="pil", label="Mask", visible=False, elem_classes="image") with gr.Row(equal_height=True): - bbox_info = gr.Text(label ="Mask BBox Info (Left:Top:Right:Bottom)", interactive= False) + bbox_info = gr.Text(label ="Mask BBox Info (Left:Top:Right:Bottom)", visible = False, 
interactive= False) with gr.Row(): # with gr.Row(): export_image_btn = gr.Button(value="Add to current Reference Images", visible=False, elem_classes="new_button") @@ -972,7 +1145,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_image_input, v matting_button.click( fn=image_matting, inputs=[image_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, image_selection_slider], - outputs=[foreground_image_output, alpha_image_output,bbox_info, export_image_btn, export_image_mask_btn] + outputs=[foreground_image_output, alpha_image_output,foreground_image_output, alpha_image_output,bbox_info, export_image_btn, export_image_mask_btn] ) diff --git a/preprocessing/matanyone/matanyone/model/utils/memory_utils.py b/preprocessing/matanyone/matanyone/model/utils/memory_utils.py index e7dd5e7..8c857ea 100644 --- a/preprocessing/matanyone/matanyone/model/utils/memory_utils.py +++ b/preprocessing/matanyone/matanyone/model/utils/memory_utils.py @@ -34,24 +34,38 @@ def get_similarity(mk: torch.Tensor, uncert_mask = uncert_mask.expand(-1, 64, -1) qk = qk * uncert_mask qe = qe * uncert_mask - + # Behold the work of DeeBeepMeep the Code Butcher ! if qe is not None: # See XMem's appendix for derivation mk = mk.transpose(1, 2) a_sq = (mk.pow(2) @ qe) - two_ab = 2 * (mk @ (qk * qe)) + two_ab = mk @ (qk * qe) + two_ab *= 2 + two_ab.sub_(a_sq) + del a_sq b_sq = (qe * qk.pow(2)).sum(1, keepdim=True) - similarity = (-a_sq + two_ab - b_sq) + two_ab.sub_(b_sq) + similarity = two_ab + del b_sq, two_ab + # similarity = (-a_sq + two_ab - b_sq) else: # similar to STCN if we don't have the selection term a_sq = mk.pow(2).sum(1).unsqueeze(2) - two_ab = 2 * (mk.transpose(1, 2) @ qk) - similarity = (-a_sq + two_ab) + two_ab = mk.transpose(1, 2) @ qk + two_ab *= 2 + two_ab.sub_(a_sq) + del a_sq + similarity = two_ab + del two_ab + # similarity = (-a_sq + two_ab) if ms is not None: - similarity = similarity * ms / math.sqrt(CK) # B*N*HW + similarity *= ms + similarity /= math.sqrt(CK) + # similarity = similarity * ms / math.sqrt(CK) # B*N*HW else: - similarity = similarity / math.sqrt(CK) # B*N*HW + similarity /= math.sqrt(CK) + # similarity = similarity / math.sqrt(CK) # B*N*HW return similarity diff --git a/preprocessing/matanyone/matanyone_wrapper.py b/preprocessing/matanyone/matanyone_wrapper.py index 82fb773..292465a 100644 --- a/preprocessing/matanyone/matanyone_wrapper.py +++ b/preprocessing/matanyone/matanyone_wrapper.py @@ -47,9 +47,13 @@ def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10): frames = [] phas = [] + i = 0 for ti, frame_single in tqdm.tqdm(enumerate(frames_np)): image = to_tensor(frame_single).cuda().float() - + if i % 10 ==0: + pass + # torch.cuda.empty_cache() + i += 1 if ti == 0: output_prob = processor.step(image, mask, objects=objects) # encode given mask output_prob = processor.step(image, first_frame_pred=True) # clear past memory for warmup frames diff --git a/requirements.txt b/requirements.txt index 312dc94..767a68d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,49 +1,61 @@ -torch>=2.4.0 -torchvision>=0.19.0 -opencv-python>=4.9.0.80 -diffusers>=0.31.0 -transformers==4.51.3 -#transformers==4.46.3 # was needed by llamallava used by i2v hunyuan before patch +# Core AI stack +diffusers==0.34.0 +transformers==4.53.1 tokenizers>=0.20.3 accelerate>=1.1.1 tqdm imageio -easydict -ftfy -dashscope imageio-ffmpeg -# flash_attn -gradio==5.23.0 -numpy>=1.23.5,<2 einops -moviepy==1.0.3 -mmgp==3.5.1 -peft==0.17.0 -mutagen 
-pydantic==2.10.6 -decord -onnxruntime-gpu -rembg[gpu]==2.0.65 -matplotlib -timm -segment-anything -omegaconf -hydra-core -librosa==0.11.0 -loguru sentencepiece +open_clip_torch>=2.29.0 + +# Video & media +moviepy==1.0.3 av -opencv-python +ffmpeg-python pygame>=2.1.0 sounddevice>=0.4.0 -# rembg==2.0.65 -torchdiffeq >= 0.2.5 -tensordict >= 0.6.1 -open_clip_torch >= 2.29.0 -pyloudnorm -misaki soundfile -ffmpeg-python -pyannote.audio +mutagen +pyloudnorm +librosa==0.11.0 + +# UI & interaction +gradio==5.23.0 +dashscope +loguru + +# Vision & segmentation +opencv-python>=4.9.0.80 +segment-anything +rembg[gpu]==2.0.65 +onnxruntime-gpu +decord +timm + +# Config & orchestration +omegaconf +hydra-core +easydict +pydantic==2.10.6 + +# Math & modeling +torchdiffeq>=0.2.5 +tensordict>=0.6.1 +mmgp==3.5.10 +peft==0.17.0 +matplotlib + +# Utilities +ftfy +piexif +pynvml +misaki + +# Optional / commented out +# transformers==4.46.3 # for llamallava pre-patch +# rembg==2.0.65 # non-GPU fallback +# huggingface_hub[hf_xet] # slows down everything # num2words # spacy diff --git a/shared/RGB_factors.py b/shared/RGB_factors.py new file mode 100644 index 0000000..6e865fa --- /dev/null +++ b/shared/RGB_factors.py @@ -0,0 +1,267 @@ +# thanks Comfyui for the rgb factors (https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py) +def get_rgb_factors(model_family, model_type = None): + if model_family == "wan": + if model_type =="ti2v_2_2": + latent_channels = 48 + latent_dimensions = 3 + latent_rgb_factors = [ + [ 0.0119, 0.0103, 0.0046], + [-0.1062, -0.0504, 0.0165], + [ 0.0140, 0.0409, 0.0491], + [-0.0813, -0.0677, 0.0607], + [ 0.0656, 0.0851, 0.0808], + [ 0.0264, 0.0463, 0.0912], + [ 0.0295, 0.0326, 0.0590], + [-0.0244, -0.0270, 0.0025], + [ 0.0443, -0.0102, 0.0288], + [-0.0465, -0.0090, -0.0205], + [ 0.0359, 0.0236, 0.0082], + [-0.0776, 0.0854, 0.1048], + [ 0.0564, 0.0264, 0.0561], + [ 0.0006, 0.0594, 0.0418], + [-0.0319, -0.0542, -0.0637], + [-0.0268, 0.0024, 0.0260], + [ 0.0539, 0.0265, 0.0358], + [-0.0359, -0.0312, -0.0287], + [-0.0285, -0.1032, -0.1237], + [ 0.1041, 0.0537, 0.0622], + [-0.0086, -0.0374, -0.0051], + [ 0.0390, 0.0670, 0.2863], + [ 0.0069, 0.0144, 0.0082], + [ 0.0006, -0.0167, 0.0079], + [ 0.0313, -0.0574, -0.0232], + [-0.1454, -0.0902, -0.0481], + [ 0.0714, 0.0827, 0.0447], + [-0.0304, -0.0574, -0.0196], + [ 0.0401, 0.0384, 0.0204], + [-0.0758, -0.0297, -0.0014], + [ 0.0568, 0.1307, 0.1372], + [-0.0055, -0.0310, -0.0380], + [ 0.0239, -0.0305, 0.0325], + [-0.0663, -0.0673, -0.0140], + [-0.0416, -0.0047, -0.0023], + [ 0.0166, 0.0112, -0.0093], + [-0.0211, 0.0011, 0.0331], + [ 0.1833, 0.1466, 0.2250], + [-0.0368, 0.0370, 0.0295], + [-0.3441, -0.3543, -0.2008], + [-0.0479, -0.0489, -0.0420], + [-0.0660, -0.0153, 0.0800], + [-0.0101, 0.0068, 0.0156], + [-0.0690, -0.0452, -0.0927], + [-0.0145, 0.0041, 0.0015], + [ 0.0421, 0.0451, 0.0373], + [ 0.0504, -0.0483, -0.0356], + [-0.0837, 0.0168, 0.0055] + ] + else: + latent_channels = 16 + latent_dimensions = 3 + latent_rgb_factors = [ + [-0.1299, -0.1692, 0.2932], + [ 0.0671, 0.0406, 0.0442], + [ 0.3568, 0.2548, 0.1747], + [ 0.0372, 0.2344, 0.1420], + [ 0.0313, 0.0189, -0.0328], + [ 0.0296, -0.0956, -0.0665], + [-0.3477, -0.4059, -0.2925], + [ 0.0166, 0.1902, 0.1975], + [-0.0412, 0.0267, -0.1364], + [-0.1293, 0.0740, 0.1636], + [ 0.0680, 0.3019, 0.1128], + [ 0.0032, 0.0581, 0.0639], + [-0.1251, 0.0927, 0.1699], + [ 0.0060, -0.0633, 0.0005], + [ 0.3477, 0.2275, 0.2950], + [ 0.1984, 0.0913, 0.1861] + ] + + latent_rgb_factors_bias = 
[-0.1835, -0.0868, -0.3360] + + # latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761] + elif model_family =="flux": + scale_factor = 0.3611 + shift_factor = 0.1159 + latent_rgb_factors =[ + [-0.0346, 0.0244, 0.0681], + [ 0.0034, 0.0210, 0.0687], + [ 0.0275, -0.0668, -0.0433], + [-0.0174, 0.0160, 0.0617], + [ 0.0859, 0.0721, 0.0329], + [ 0.0004, 0.0383, 0.0115], + [ 0.0405, 0.0861, 0.0915], + [-0.0236, -0.0185, -0.0259], + [-0.0245, 0.0250, 0.1180], + [ 0.1008, 0.0755, -0.0421], + [-0.0515, 0.0201, 0.0011], + [ 0.0428, -0.0012, -0.0036], + [ 0.0817, 0.0765, 0.0749], + [-0.1264, -0.0522, -0.1103], + [-0.0280, -0.0881, -0.0499], + [-0.1262, -0.0982, -0.0778] + ] + latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851] + + elif model_family == "ltxv": + latent_channels = 128 + latent_dimensions = 3 + + latent_rgb_factors = [ + [ 1.1202e-02, -6.3815e-04, -1.0021e-02], + [ 8.6031e-02, 6.5813e-02, 9.5409e-04], + [-1.2576e-02, -7.5734e-03, -4.0528e-03], + [ 9.4063e-03, -2.1688e-03, 2.6093e-03], + [ 3.7636e-03, 1.2765e-02, 9.1548e-03], + [ 2.1024e-02, -5.2973e-03, 3.4373e-03], + [-8.8896e-03, -1.9703e-02, -1.8761e-02], + [-1.3160e-02, -1.0523e-02, 1.9709e-03], + [-1.5152e-03, -6.9891e-03, -7.5810e-03], + [-1.7247e-03, 4.6560e-04, -3.3839e-03], + [ 1.3617e-02, 4.7077e-03, -2.0045e-03], + [ 1.0256e-02, 7.7318e-03, 1.3948e-02], + [-1.6108e-02, -6.2151e-03, 1.1561e-03], + [ 7.3407e-03, 1.5628e-02, 4.4865e-04], + [ 9.5357e-04, -2.9518e-03, -1.4760e-02], + [ 1.9143e-02, 1.0868e-02, 1.2264e-02], + [ 4.4575e-03, 3.6682e-05, -6.8508e-03], + [-4.5681e-04, 3.2570e-03, 7.7929e-03], + [ 3.3902e-02, 3.3405e-02, 3.7454e-02], + [-2.3001e-02, -2.4877e-03, -3.1033e-03], + [ 5.0265e-02, 3.8841e-02, 3.3539e-02], + [-4.1018e-03, -1.1095e-03, 1.5859e-03], + [-1.2689e-01, -1.3107e-01, -2.1005e-01], + [ 2.6276e-02, 1.4189e-02, -3.5963e-03], + [-4.8679e-03, 8.8486e-03, 7.8029e-03], + [-1.6610e-03, -4.8597e-03, -5.2060e-03], + [-2.1010e-03, 2.3610e-03, 9.3796e-03], + [-2.2482e-02, -2.1305e-02, -1.5087e-02], + [-1.5753e-02, -1.0646e-02, -6.5083e-03], + [-4.6975e-03, 5.0288e-03, -6.7390e-03], + [ 1.1951e-02, 2.0712e-02, 1.6191e-02], + [-6.3704e-03, -8.4827e-03, -9.5483e-03], + [ 7.2610e-03, -9.9326e-03, -2.2978e-02], + [-9.1904e-04, 6.2882e-03, 9.5720e-03], + [-3.7178e-02, -3.7123e-02, -5.6713e-02], + [-1.3373e-01, -1.0720e-01, -5.3801e-02], + [-5.3702e-03, 8.1256e-03, 8.8397e-03], + [-1.5247e-01, -2.1437e-01, -2.1843e-01], + [ 3.1441e-02, 7.0335e-03, -9.7541e-03], + [ 2.1528e-03, -8.9817e-03, -2.1023e-02], + [ 3.8461e-03, -5.8957e-03, -1.5014e-02], + [-4.3470e-03, -1.2940e-02, -1.5972e-02], + [-5.4781e-03, -1.0842e-02, -3.0204e-03], + [-6.5347e-03, 3.0806e-03, -1.0163e-02], + [-5.0414e-03, -7.1503e-03, -8.9686e-04], + [-8.5851e-03, -2.4351e-03, 1.0674e-03], + [-9.0016e-03, -9.6493e-03, 1.5692e-03], + [ 5.0914e-03, 1.2099e-02, 1.9968e-02], + [ 1.3758e-02, 1.1669e-02, 8.1958e-03], + [-1.0518e-02, -1.1575e-02, -4.1307e-03], + [-2.8410e-02, -3.1266e-02, -2.2149e-02], + [ 2.9336e-03, 3.6511e-02, 1.8717e-02], + [-1.6703e-02, -1.6696e-02, -4.4529e-03], + [ 4.8818e-02, 4.0063e-02, 8.7410e-03], + [-1.5066e-02, -5.7328e-04, 2.9785e-03], + [-1.7613e-02, -8.1034e-03, 1.3086e-02], + [-9.2633e-03, 1.0803e-02, -6.3489e-03], + [ 3.0851e-03, 4.7750e-04, 1.2347e-02], + [-2.2785e-02, -2.3043e-02, -2.6005e-02], + [-2.4787e-02, -1.5389e-02, -2.2104e-02], + [-2.3572e-02, 1.0544e-03, 1.2361e-02], + [-7.8915e-03, -1.2271e-03, -6.0968e-03], + [-1.1478e-02, -1.2543e-03, 6.2679e-03], + [-5.4229e-02, 2.6644e-02, 6.3394e-03], + [ 4.4216e-03, 
-7.3338e-03, -1.0464e-02], + [-4.5013e-03, 1.6082e-03, 1.4420e-02], + [ 1.3673e-02, 8.8877e-03, 4.1253e-03], + [-1.0145e-02, 9.0072e-03, 1.5695e-02], + [-5.6234e-03, 1.1847e-03, 8.1261e-03], + [-3.7171e-03, -5.3538e-03, 1.2590e-03], + [ 2.9476e-02, 2.1424e-02, 3.0424e-02], + [-3.4925e-02, -2.4340e-02, -2.5316e-02], + [-3.4127e-02, -2.2406e-02, -1.0589e-02], + [-1.7342e-02, -1.3249e-02, -1.0719e-02], + [-2.1478e-03, -8.6051e-03, -2.9878e-03], + [ 1.2089e-03, -4.2391e-03, -6.8569e-03], + [ 9.0411e-04, -6.6886e-03, -6.7547e-05], + [ 1.6048e-02, -1.0057e-02, -2.8929e-02], + [ 1.2290e-03, 1.0163e-02, 1.8861e-02], + [ 1.7264e-02, 2.7257e-04, 1.3785e-02], + [-1.3482e-02, -3.6427e-03, 6.7481e-04], + [ 4.6782e-03, -5.2423e-03, 2.4467e-03], + [-5.9113e-03, -6.2244e-03, -1.8162e-03], + [ 1.5496e-02, 1.4582e-02, 1.9514e-03], + [ 7.4958e-03, 1.5886e-03, -8.2305e-03], + [ 1.9086e-02, 1.6360e-03, -3.9674e-03], + [-5.7021e-03, -2.7307e-03, -4.1066e-03], + [ 1.7450e-03, 1.4602e-02, 2.5794e-02], + [-8.2788e-04, 2.2902e-03, 4.5161e-03], + [ 1.1632e-02, 8.9193e-03, -7.2813e-03], + [ 7.5721e-03, 2.6784e-03, 1.1393e-02], + [ 5.1939e-03, 3.6903e-03, 1.4049e-02], + [-1.8383e-02, -2.2529e-02, -2.4477e-02], + [ 5.8842e-04, -5.7874e-03, -1.4770e-02], + [-1.6125e-02, -8.6101e-03, -1.4533e-02], + [ 2.0540e-02, 2.0729e-02, 6.4338e-03], + [ 3.3587e-03, -1.1226e-02, -1.6444e-02], + [-1.4742e-03, -1.0489e-02, 1.7097e-03], + [ 2.8130e-02, 2.3546e-02, 3.2791e-02], + [-1.8532e-02, -1.2842e-02, -8.7756e-03], + [-8.0533e-03, -1.0771e-02, -1.7536e-02], + [-3.9009e-03, 1.6150e-02, 3.3359e-02], + [-7.4554e-03, -1.4154e-02, -6.1910e-03], + [ 3.4734e-03, -1.1370e-02, -1.0581e-02], + [ 1.1476e-02, 3.9281e-03, 2.8231e-03], + [ 7.1639e-03, -1.4741e-03, -3.8066e-03], + [ 2.2250e-03, -8.7552e-03, -9.5719e-03], + [ 2.4146e-02, 2.1696e-02, 2.8056e-02], + [-5.4365e-03, -2.4291e-02, -1.7802e-02], + [ 7.4263e-03, 1.0510e-02, 1.2705e-02], + [ 6.2669e-03, 6.2658e-03, 1.9211e-02], + [ 1.6378e-02, 9.4933e-03, 6.6971e-03], + [ 1.7173e-02, 2.3601e-02, 2.3296e-02], + [-1.4568e-02, -9.8279e-03, -1.1556e-02], + [ 1.4431e-02, 1.4430e-02, 6.6362e-03], + [-6.8230e-03, 1.8863e-02, 1.4555e-02], + [ 6.1156e-03, 3.4700e-03, -2.6662e-03], + [-2.6983e-03, -5.9402e-03, -9.2276e-03], + [ 1.0235e-02, 7.4173e-03, -7.6243e-03], + [-1.3255e-02, 1.9322e-02, -9.2153e-04], + [ 2.4222e-03, -4.8039e-03, -1.5759e-02], + [ 2.6244e-02, 2.5951e-02, 2.0249e-02], + [ 1.5711e-02, 1.8498e-02, 2.7407e-03], + [-2.1714e-03, 4.7214e-03, -2.2443e-02], + [-7.4747e-03, 7.4166e-03, 1.4430e-02], + [-8.3906e-03, -7.9776e-03, 9.7927e-03], + [ 3.8321e-02, 9.6622e-03, -1.9268e-02], + [-1.4605e-02, -6.7032e-03, 3.9675e-03] + ] + latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512] + + elif model_family == "hunyuan": + latent_channels = 16 + latent_dimensions = 3 + scale_factor = 0.476986 + latent_rgb_factors = [ + [-0.0395, -0.0331, 0.0445], + [ 0.0696, 0.0795, 0.0518], + [ 0.0135, -0.0945, -0.0282], + [ 0.0108, -0.0250, -0.0765], + [-0.0209, 0.0032, 0.0224], + [-0.0804, -0.0254, -0.0639], + [-0.0991, 0.0271, -0.0669], + [-0.0646, -0.0422, -0.0400], + [-0.0696, -0.0595, -0.0894], + [-0.0799, -0.0208, -0.0375], + [ 0.1166, 0.1627, 0.0962], + [ 0.1165, 0.0432, 0.0407], + [-0.2315, -0.1920, -0.1355], + [-0.0270, 0.0401, -0.0821], + [-0.0616, -0.0997, -0.0727], + [ 0.0249, -0.0469, -0.1703] + ] + + latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761] + else: + latent_rgb_factors_bias = latent_rgb_factors = None + return latent_rgb_factors, latent_rgb_factors_bias \ No newline at end of file 
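The factor tables above exist for cheap latent previews: each latent channel contributes a fixed weight to R, G and B, so converting a latent frame to a preview image is just a matrix product plus the bias vector. Below is a minimal, illustrative sketch of that projection (not code from the diff itself; it assumes `shared/` is importable, a Wan-style latent shaped `[C, T, H, W]`, and the helper name is hypothetical):

```python
# Hypothetical preview helper built on get_rgb_factors (sketch only, not the repo's preview code).
import torch
from shared.RGB_factors import get_rgb_factors

def latent_frame_to_rgb_preview(latents: torch.Tensor, model_family="wan", model_type=None, frame=0):
    """Project one latent frame [C, H, W] to an approximate RGB image using the factor table."""
    factors, bias = get_rgb_factors(model_family, model_type)
    if factors is None:
        return None
    weight = torch.tensor(factors, dtype=latents.dtype, device=latents.device)  # [C, 3]
    offset = torch.tensor(bias, dtype=latents.dtype, device=latents.device)     # [3]
    frame_latent = latents[:, frame]                                            # [C, H, W]
    rgb = torch.einsum("chw,cr->rhw", frame_latent, weight) + offset[:, None, None]
    # Map roughly from [-1, 1] to [0, 255] for display
    return ((rgb.clamp(-1, 1) + 1) * 127.5).to(torch.uint8)
```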
diff --git a/wan/modules/attention.py b/shared/attention.py similarity index 100% rename from wan/modules/attention.py rename to shared/attention.py diff --git a/shared/extract_lora.py b/shared/extract_lora.py new file mode 100644 index 0000000..2a23582 --- /dev/null +++ b/shared/extract_lora.py @@ -0,0 +1,573 @@ +import torch +import torch.nn as nn +from typing import Dict, Tuple, Optional, Union +import warnings + +try: + from safetensors.torch import save_file as save_safetensors + SAFETENSORS_AVAILABLE = True +except ImportError: + SAFETENSORS_AVAILABLE = False + warnings.warn("safetensors not available. Install with: pip install safetensors") + +class LoRAExtractor: + """ + Extract LoRA tensors from the difference between original and fine-tuned models. + + LoRA (Low-Rank Adaptation) decomposes weight updates as ΔW = B @ A where: + - A (lora_down): [rank, input_dim] matrix (saved as diffusion_model.param_name.lora_down.weight) + - B (lora_up): [output_dim, rank] matrix (saved as diffusion_model.param_name.lora_up.weight) + + The decomposition uses SVD: ΔW = U @ S @ V^T ≈ (U @ S) @ V^T where: + - lora_up = U @ S (contains all singular values) + - lora_down = V^T (orthogonal matrix) + + Parameter handling based on name AND dimension: + - 2D weight tensors: LoRA decomposition (.lora_down.weight, .lora_up.weight) + - Any bias tensors: direct difference (.diff_b) + - Other weight tensors (1D, 3D, 4D): full difference (.diff) + + Progress tracking and test mode are available for format validation and debugging. + """ + + def __init__(self, rank: int = 128, threshold: float = 1e-6, test_mode: bool = False, show_reconstruction_errors: bool = False): + """ + Initialize LoRA extractor. + + Args: + rank: Target rank for LoRA decomposition (default: 128) + threshold: Minimum singular value threshold for decomposition + test_mode: If True, creates zero tensors without computation for format testing + show_reconstruction_errors: If True, calculates and displays reconstruction error for each LoRA pair + """ + self.rank = rank + self.threshold = threshold + self.test_mode = test_mode + self.show_reconstruction_errors = show_reconstruction_errors + + def extract_lora_from_state_dicts( + self, + original_state_dict: Dict[str, torch.Tensor], + finetuned_state_dict: Dict[str, torch.Tensor], + device: str = 'cpu', + show_progress: bool = True + ) -> Dict[str, torch.Tensor]: + """ + Extract LoRA tensors for all matching parameters between two state dictionaries. 
+ + Args: + original_state_dict: State dict of the original model + finetuned_state_dict: State dict of the fine-tuned model + device: Device to perform computations on + show_progress: Whether to display progress information + + Returns: + Dictionary mapping parameter names to their LoRA components: + - For 2D weight tensors: 'diffusion_model.layer.lora_down.weight', 'diffusion_model.layer.lora_up.weight' + - For any bias tensors: 'diffusion_model.layer.diff_b' + - For other weight tensors (1D, 3D, 4D): 'diffusion_model.layer.diff' + """ + lora_tensors = {} + + # Find common parameters and sort alphabetically for consistent processing order + common_keys = sorted(set(original_state_dict.keys()) & set(finetuned_state_dict.keys())) + total_params = len(common_keys) + processed_params = 0 + extracted_components = 0 + + if show_progress: + print(f"Starting LoRA extraction for {total_params} parameters on {device}...") + + # Pre-move threshold to device for faster comparisons + threshold_tensor = torch.tensor(self.threshold, device=device) + + for param_name in common_keys: + if show_progress: + processed_params += 1 + progress_pct = (processed_params / total_params) * 100 + print(f"[{processed_params:4d}/{total_params}] ({progress_pct:5.1f}%) Processing: {param_name}") + + # Move tensors to device once + original_tensor = original_state_dict[param_name] + finetuned_tensor = finetuned_state_dict[param_name] + + # Check if tensors have the same shape before moving to device + if original_tensor.shape != finetuned_tensor.shape: + if show_progress: + print(f" → Shape mismatch: {original_tensor.shape} vs {finetuned_tensor.shape}. Skipping.") + continue + + # Move to device and compute difference in one go for efficiency (skip in test mode) + if not self.test_mode: + if original_tensor.device != torch.device(device): + original_tensor = original_tensor.to(device, non_blocking=True) + if finetuned_tensor.device != torch.device(device): + finetuned_tensor = finetuned_tensor.to(device, non_blocking=True) + + # Compute difference on device + delta_tensor = finetuned_tensor - original_tensor + + # Fast GPU-based threshold check + max_abs_diff = torch.max(torch.abs(delta_tensor)) + if max_abs_diff <= threshold_tensor: + if show_progress: + print(f" → No significant changes detected (max diff: {max_abs_diff:.2e}), skipping") + continue + else: + # Test mode - create dummy delta tensor with original shape and dtype + delta_tensor = torch.zeros_like(original_tensor) + if device != 'cpu': + delta_tensor = delta_tensor.to(device) + + # Extract LoRA components based on tensor dimensionality + extracted_tensors = self._extract_lora_components(delta_tensor, param_name) + + if extracted_tensors: + lora_tensors.update(extracted_tensors) + extracted_components += len(extracted_tensors) + if show_progress: + # Show meaningful component names instead of just 'weight' + component_names = [] + for key in extracted_tensors.keys(): + if key.endswith('.lora_down.weight'): + component_names.append('lora_down') + elif key.endswith('.lora_up.weight'): + component_names.append('lora_up') + elif key.endswith('.diff_b'): + component_names.append('diff_b') + elif key.endswith('.diff'): + component_names.append('diff') + else: + component_names.append(key.split('.')[-1]) + print(f" → Extracted {len(extracted_tensors)} components: {component_names}") + + if show_progress: + print(f"\nExtraction completed!") + print(f"Processed: {processed_params}/{total_params} parameters") + print(f"Extracted: {extracted_components} LoRA 
components") + print(f"LoRA rank: {self.rank}") + + # Summary by type + lora_down_count = sum(1 for k in lora_tensors.keys() if k.endswith('.lora_down.weight')) + lora_up_count = sum(1 for k in lora_tensors.keys() if k.endswith('.lora_up.weight')) + diff_b_count = sum(1 for k in lora_tensors.keys() if k.endswith('.diff_b')) + diff_count = sum(1 for k in lora_tensors.keys() if k.endswith('.diff')) + + print(f"Summary: {lora_down_count} lora_down, {lora_up_count} lora_up, {diff_b_count} diff_b, {diff_count} diff") + + return lora_tensors + + def _extract_lora_components( + self, + delta_tensor: torch.Tensor, + param_name: str + ) -> Optional[Dict[str, torch.Tensor]]: + """ + Extract LoRA components from a delta tensor. + + Args: + delta_tensor: Difference between fine-tuned and original tensor + param_name: Name of the parameter (for generating output keys) + + Returns: + Dictionary with modified parameter names as keys and tensors as values + """ + # Determine if this is a weight or bias parameter from the original name + is_weight = 'weight' in param_name.lower() + is_bias = 'bias' in param_name.lower() + + # Remove .weight or .bias suffix from parameter name + base_name = param_name + if base_name.endswith('.weight'): + base_name = base_name[:-7] # Remove '.weight' + elif base_name.endswith('.bias'): + base_name = base_name[:-5] # Remove '.bias' + + # Add diffusion_model prefix + base_name = f"diffusion_model.{base_name}" + + if self.test_mode: + # Fast test mode - create zero tensors without computation + if delta_tensor.dim() == 2 and is_weight: + # 2D weight tensor -> LoRA decomposition + output_dim, input_dim = delta_tensor.shape + rank = min(self.rank, min(input_dim, output_dim)) + return { + f"{base_name}.lora_down.weight": torch.zeros(rank, input_dim, dtype=delta_tensor.dtype, device=delta_tensor.device), + f"{base_name}.lora_up.weight": torch.zeros(output_dim, rank, dtype=delta_tensor.dtype, device=delta_tensor.device) + } + elif is_bias: + # Any bias tensor (1D, 2D, etc.) -> .diff_b + return {f"{base_name}.diff_b": torch.zeros_like(delta_tensor)} + else: + # Any weight tensor that's not 2D, or other tensors -> .diff + return {f"{base_name}.diff": torch.zeros_like(delta_tensor)} + + # Normal mode - check dimensions AND parameter type + if delta_tensor.dim() == 2 and is_weight: + # 2D weight tensor (linear layer weight) - apply SVD decomposition + return self._decompose_2d_tensor(delta_tensor, base_name) + + elif is_bias: + # Any bias tensor (regardless of dimension) - save as .diff_b + return {f"{base_name}.diff_b": delta_tensor.clone()} + + else: + # Any other tensor (weight tensors that are 1D, 3D, 4D, or unknown tensors) - save as .diff + return {f"{base_name}.diff": delta_tensor.clone()} + + def _decompose_2d_tensor(self, delta_tensor: torch.Tensor, base_name: str) -> Dict[str, torch.Tensor]: + """ + Decompose a 2D tensor using SVD on GPU for maximum performance. 
+ + Args: + delta_tensor: 2D tensor to decompose (output_dim × input_dim) + base_name: Base name for the parameter (already processed, with diffusion_model prefix) + + Returns: + Dictionary with lora_down and lora_up tensors: + - lora_down: [rank, input_dim] + - lora_up: [output_dim, rank] + """ + # Store original dtype and device + dtype = delta_tensor.dtype + device = delta_tensor.device + + # Perform SVD in float32 for numerical stability, but keep on same device + delta_float = delta_tensor.float() if delta_tensor.dtype != torch.float32 else delta_tensor + U, S, Vt = torch.linalg.svd(delta_float, full_matrices=False) + + # Determine effective rank (number of significant singular values) + # Use GPU-accelerated operations + significant_mask = S > self.threshold + effective_rank = min(self.rank, torch.sum(significant_mask).item()) + effective_rank = self.rank + + if effective_rank == 0: + warnings.warn(f"No significant singular values found for {base_name}") + effective_rank = 1 + + # Create LoRA matrices with correct SVD decomposition + # Standard approach: put all singular values in lora_up, leave lora_down as V^T + # This ensures: lora_up @ lora_down = (U @ S) @ V^T = U @ S @ V^T = ΔW ✓ + + lora_up = U[:, :effective_rank] * S[:effective_rank].unsqueeze(0) # [output_dim, rank] + lora_down = Vt[:effective_rank, :] # [rank, input_dim] + + # Convert back to original dtype (keeping on same device) + lora_up = lora_up.to(dtype) + lora_down = lora_down.to(dtype) + + # Calculate and display reconstruction error if requested + if self.show_reconstruction_errors: + with torch.no_grad(): + # Reconstruct the original delta tensor + reconstructed = lora_up @ lora_down + + # Calculate various error metrics + mse_error = torch.mean((delta_tensor - reconstructed) ** 2).item() + max_error = torch.max(torch.abs(delta_tensor - reconstructed)).item() + + # Relative error + original_norm = torch.norm(delta_tensor).item() + relative_error = (torch.norm(delta_tensor - reconstructed).item() / original_norm * 100) if original_norm > 0 else 0 + + # Cosine similarity + delta_flat = delta_tensor.flatten() + reconstructed_flat = reconstructed.flatten() + if torch.norm(delta_flat) > 0 and torch.norm(reconstructed_flat) > 0: + cosine_sim = torch.nn.functional.cosine_similarity( + delta_flat.unsqueeze(0), + reconstructed_flat.unsqueeze(0) + ).item() + else: + cosine_sim = 0.0 + + # Extract parameter name for display (remove diffusion_model prefix) + display_name = base_name[16:] if base_name.startswith('diffusion_model.') else base_name + + print(f" LoRA Error [{display_name}]: MSE={mse_error:.2e}, Max={max_error:.2e}, Rel={relative_error:.2f}%, Cos={cosine_sim:.4f}, Rank={effective_rank}") + + return { + f"{base_name}.lora_down.weight": lora_down, + f"{base_name}.lora_up.weight": lora_up + } + + def verify_reconstruction( + self, + lora_tensors: Dict[str, torch.Tensor], + original_deltas: Dict[str, torch.Tensor] + ) -> Dict[str, float]: + """ + Verify the quality of LoRA reconstruction for 2D tensors. 
+
+        Args:
+            lora_tensors: Dictionary with LoRA tensors (flat structure with diffusion_model prefix)
+            original_deltas: Dictionary with original delta tensors (without prefix)
+
+        Returns:
+            Dictionary mapping parameter names to reconstruction errors
+        """
+        reconstruction_errors = {}
+
+        # Group LoRA components by base parameter name
+        lora_pairs = {}
+        for key, tensor in lora_tensors.items():
+            if key.endswith('.lora_down.weight'):
+                base_name = key[:-17]  # Remove '.lora_down.weight'
+                # Remove diffusion_model prefix for matching with original_deltas
+                if base_name.startswith('diffusion_model.'):
+                    original_key = base_name[16:]  # Remove 'diffusion_model.'
+                else:
+                    original_key = base_name
+                if base_name not in lora_pairs:
+                    lora_pairs[base_name] = {'original_key': original_key}
+                lora_pairs[base_name]['lora_down'] = tensor
+            elif key.endswith('.lora_up.weight'):
+                base_name = key[:-15]  # Remove '.lora_up.weight'
+                # Remove diffusion_model prefix for matching with original_deltas
+                if base_name.startswith('diffusion_model.'):
+                    original_key = base_name[16:]  # Remove 'diffusion_model.'
+                else:
+                    original_key = base_name
+                if base_name not in lora_pairs:
+                    lora_pairs[base_name] = {'original_key': original_key}
+                lora_pairs[base_name]['lora_up'] = tensor
+
+        # Verify reconstruction for each complete LoRA pair
+        for base_name, components in lora_pairs.items():
+            if 'lora_down' in components and 'lora_up' in components and 'original_key' in components:
+                original_key = components['original_key']
+                if original_key in original_deltas:
+                    lora_down = components['lora_down']
+                    lora_up = components['lora_up']
+                    original_delta = original_deltas[original_key]
+
+                    # Get effective rank from the actual tensor dimensions
+                    effective_rank = min(lora_up.shape[1], lora_down.shape[0])
+
+                    # Reconstruct: ΔW = lora_up @ lora_down (no additional scaling needed since it's built into lora_up)
+                    reconstructed = lora_up @ lora_down
+
+                    # Compute reconstruction error
+                    mse_error = torch.mean((original_delta - reconstructed) ** 2).item()
+                    reconstruction_errors[base_name] = mse_error
+
+        return reconstruction_errors
+
+def compute_reconstruction_errors(
+    original_tensor: torch.Tensor,
+    reconstructed_tensor: torch.Tensor,
+    target_tensor: torch.Tensor
+) -> Dict[str, float]:
+    """
+    Compute various error metrics between original, reconstructed, and target tensors.
+
+    Args:
+        original_tensor: Original tensor before fine-tuning
+        reconstructed_tensor: Reconstructed tensor from LoRA (original + LoRA_reconstruction)
+        target_tensor: Target tensor (fine-tuned)
+
+    Returns:
+        Dictionary with error metrics
+    """
+    # Ensure all tensors are on the same device and have the same shape
+    device = original_tensor.device
+    reconstructed_tensor = reconstructed_tensor.to(device)
+    target_tensor = target_tensor.to(device)
+
+    # Compute differences
+    delta_original = target_tensor - original_tensor  # True fine-tuning difference
+    delta_reconstructed = reconstructed_tensor - original_tensor  # LoRA reconstructed difference
+    reconstruction_error = target_tensor - reconstructed_tensor  # Final reconstruction error
+
+    # Compute various error metrics
+    errors = {}
+
+    # Mean Squared Error (MSE)
+    errors['mse_delta'] = torch.mean((delta_original - delta_reconstructed) ** 2).item()
+    errors['mse_final'] = torch.mean(reconstruction_error ** 2).item()
+
+    # Mean Absolute Error (MAE)
+    errors['mae_delta'] = torch.mean(torch.abs(delta_original - delta_reconstructed)).item()
+    errors['mae_final'] = torch.mean(torch.abs(reconstruction_error)).item()
+
+    # Relative errors (as percentages)
+    original_norm = torch.norm(original_tensor).item()
+    target_norm = torch.norm(target_tensor).item()
+    delta_norm = torch.norm(delta_original).item()
+
+    if original_norm > 0:
+        errors['relative_error_original'] = (torch.norm(reconstruction_error).item() / original_norm) * 100
+    if target_norm > 0:
+        errors['relative_error_target'] = (torch.norm(reconstruction_error).item() / target_norm) * 100
+    if delta_norm > 0:
+        errors['relative_error_delta'] = (torch.norm(delta_original - delta_reconstructed).item() / delta_norm) * 100
+
+    # Cosine similarity (higher is better, 1.0 = perfect)
+    delta_flat = delta_original.flatten()
+    reconstructed_flat = delta_reconstructed.flatten()
+
+    if torch.norm(delta_flat) > 0 and torch.norm(reconstructed_flat) > 0:
+        cosine_sim = torch.nn.functional.cosine_similarity(
+            delta_flat.unsqueeze(0),
+            reconstructed_flat.unsqueeze(0)
+        ).item()
+        errors['cosine_similarity'] = cosine_sim
+    else:
+        errors['cosine_similarity'] = 0.0
+
+    # Signal-to-noise ratio (SNR) in dB
+    if errors['mse_final'] > 0:
+        signal_power = torch.mean(target_tensor ** 2)
+        errors['snr_db'] = (10 * torch.log10(signal_power / errors['mse_final'])).item()
+    else:
+        errors['snr_db'] = float('inf')
+
+    return errors
+
+# Example usage and utility functions
+def load_and_extract_lora(
+    original_model_path: str,
+    finetuned_model_path: str,
+    rank: int = 128,
+    device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
+    show_progress: bool = True,
+    test_mode: bool = False,
+    show_reconstruction_errors: bool = False
+) -> Dict[str, torch.Tensor]:
+    """
+    Convenience function to load models and extract LoRA tensors with GPU acceleration.
+ + Args: + original_model_path: Path to original model state dict + finetuned_model_path: Path to fine-tuned model state dict + rank: Target LoRA rank (default: 128) + device: Device for computation (defaults to GPU if available) + show_progress: Whether to display progress information + test_mode: If True, creates zero tensors without computation for format testing + show_reconstruction_errors: If True, calculates and displays reconstruction error for each LoRA pair + + Returns: + Dictionary of LoRA tensors with modified parameter names as keys + """ + # Load state dictionaries directly to CPU first (safetensors loads to CPU by default) + if show_progress: + print(f"Loading original model from: {original_model_path}") + original_state_dict = torch.load(original_model_path, map_location='cpu') + + if show_progress: + print(f"Loading fine-tuned model from: {finetuned_model_path}") + finetuned_state_dict = torch.load(finetuned_model_path, map_location='cpu') + + # Handle nested state dicts (if wrapped in 'model' key or similar) + if 'state_dict' in original_state_dict: + original_state_dict = original_state_dict['state_dict'] + if 'state_dict' in finetuned_state_dict: + finetuned_state_dict = finetuned_state_dict['state_dict'] + + # Extract LoRA tensors with GPU acceleration + extractor = LoRAExtractor(rank=rank, test_mode=test_mode, show_reconstruction_errors=show_reconstruction_errors) + lora_tensors = extractor.extract_lora_from_state_dicts( + original_state_dict, + finetuned_state_dict, + device=device, + show_progress=show_progress + ) + + return lora_tensors + +def save_lora_tensors(lora_tensors: Dict[str, torch.Tensor], save_path: str): + """Save extracted LoRA tensors to disk.""" + torch.save(lora_tensors, save_path) + print(f"LoRA tensors saved to {save_path}") + +def save_lora_safetensors(lora_tensors: Dict[str, torch.Tensor], save_path: str, rank: int = None): + """Save extracted LoRA tensors as safetensors format with metadata.""" + if not SAFETENSORS_AVAILABLE: + raise ImportError("safetensors not available. 
Install with: pip install safetensors") + + # Ensure all tensors are contiguous for safetensors + contiguous_tensors = {k: v.contiguous() if v.is_floating_point() else v.contiguous() + for k, v in lora_tensors.items()} + + # Add rank as metadata if provided + metadata = {} + if rank is not None: + metadata["rank"] = str(rank) + + save_safetensors(contiguous_tensors, save_path, metadata=metadata if metadata else None) + print(f"LoRA tensors saved as safetensors to {save_path}") + if metadata: + print(f"Metadata: {metadata}") + +def analyze_lora_tensors(lora_tensors: Dict[str, torch.Tensor]): + """Analyze the extracted LoRA tensors.""" + print(f"Extracted LoRA tensors ({len(lora_tensors)} components):") + + # Group by type for better organization + lora_down_tensors = {k: v for k, v in lora_tensors.items() if k.endswith('.lora_down.weight')} + lora_up_tensors = {k: v for k, v in lora_tensors.items() if k.endswith('.lora_up.weight')} + diff_b_tensors = {k: v for k, v in lora_tensors.items() if k.endswith('.diff_b')} + diff_tensors = {k: v for k, v in lora_tensors.items() if k.endswith('.diff')} + + if lora_down_tensors: + print(f"\nLinear LoRA down matrices ({len(lora_down_tensors)}):") + for name, tensor in lora_down_tensors.items(): + print(f" {name}: {tensor.shape}") + + if lora_up_tensors: + print(f"\nLinear LoRA up matrices ({len(lora_up_tensors)}):") + for name, tensor in lora_up_tensors.items(): + print(f" {name}: {tensor.shape}") + + if diff_b_tensors: + print(f"\nBias differences ({len(diff_b_tensors)}):") + for name, tensor in diff_b_tensors.items(): + print(f" {name}: {tensor.shape}") + + if diff_tensors: + print(f"\nFull weight differences ({len(diff_tensors)}):") + print(" (Includes conv, modulation, and other multi-dimensional tensors)") + for name, tensor in diff_tensors.items(): + print(f" {name}: {tensor.shape}") + +# Example usage +if __name__ == "__main__": + + + from safetensors.torch import load_file as load_safetensors + + # Load original and fine-tuned models from safetensors files + + original_state_dict = load_safetensors("ckpts/wan2.2_text2video_14B_high_mbf16.safetensors") + finetuned_state_dict = load_safetensors("ckpts/wan2.2_text2video_14B_low_mbf16.safetensors") + + # original_state_dict = load_safetensors("ckpts/flux1-dev_bf16.safetensors") + # finetuned_state_dict = load_safetensors("ckpts/flux1-schnell_bf16.safetensors") + + print(f"Loaded original model with {len(original_state_dict)} parameters") + print(f"Loaded fine-tuned model with {len(finetuned_state_dict)} parameters") + + # extractor_test = LoRAExtractor(test_mode=True) + + extractor_test = LoRAExtractor(show_reconstruction_errors=True, rank=128) + + lora_tensors_test = extractor_test.extract_lora_from_state_dicts( + original_state_dict, + finetuned_state_dict, + device='cuda', + show_progress=True + ) + + print("\nTest mode tensor keys (first 10):") + for i, key in enumerate(sorted(lora_tensors_test.keys())): + if i < 10: + print(f" {key}: {lora_tensors_test[key].shape}") + elif i == 10: + print(f" ... and {len(lora_tensors_test) - 10} more") + break + + # Always save as extracted_lora.safetensors for easier testing + save_lora_safetensors(lora_tensors_test, "extracted_lora.safetensors") + diff --git a/shared/match_archi.py b/shared/match_archi.py new file mode 100644 index 0000000..7d535d5 --- /dev/null +++ b/shared/match_archi.py @@ -0,0 +1,64 @@ +import re + +def match_nvidia_architecture(conditions_dict, architecture): + """ + Match Nvidia architecture against condition dictionary. 
+ + Args: + conditions_dict: dict with condition strings as keys, parameters as values + architecture: int representing architecture (e.g., 89 for Ada Lovelace) + + Returns: + list of matched parameters + + Condition syntax: + - Operators: '<', '>', '<=', '>=', '=' (or no operator for equality) + - OR: '+' between conditions (e.g., '<=50+>89') + - AND: '&' between conditions (e.g., '>=70&<90') + - Examples: + * '<89': architectures below Ada (89) + * '>=75': architectures 75 and above + * '89': exactly Ada architecture + * '<=50+>89': Maxwell (50) and below OR above Ada + * '>=70&<90': Ampere range (70-89) + """ + + def eval_condition(cond, arch): + """Evaluate single condition against architecture""" + cond = cond.strip() + if not cond: + return False + + # Parse operator and value using regex + match = re.match(r'(>=|<=|>|<|=?)(\d+)', cond) + if not match: + return False + + op, val = match.groups() + val = int(val) + + # Handle operators + if op in ('', '='): + return arch == val + elif op == '>=': + return arch >= val + elif op == '<=': + return arch <= val + elif op == '>': + return arch > val + elif op == '<': + return arch < val + return False + + def matches_condition(condition_str, arch): + """Check if architecture matches full condition string""" + # Split by '+' for OR conditions, then by '&' for AND conditions + return any( + all(eval_condition(and_cond, arch) for and_cond in or_cond.split('&')) + for or_cond in condition_str.split('+') + if or_cond.strip() + ) + + # Return all parameters where conditions match + return [params for condition, params in conditions_dict.items() + if matches_condition(condition, architecture)] \ No newline at end of file diff --git a/wan/modules/sage2_core.py b/shared/sage2_core.py similarity index 100% rename from wan/modules/sage2_core.py rename to shared/sage2_core.py diff --git a/wan/utils/__init__.py b/shared/utils/__init__.py similarity index 100% rename from wan/utils/__init__.py rename to shared/utils/__init__.py diff --git a/shared/utils/audio_video.py b/shared/utils/audio_video.py new file mode 100644 index 0000000..b24530d --- /dev/null +++ b/shared/utils/audio_video.py @@ -0,0 +1,421 @@ +import subprocess +import tempfile, os +import ffmpeg +import torchvision.transforms.functional as TF +import torch.nn.functional as F +import cv2 +import tempfile +import imageio +import binascii +import torchvision +import torch +from PIL import Image +import os.path as osp +import json + +def rand_name(length=8, suffix=''): + name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') + if suffix: + if not suffix.startswith('.'): + suffix = '.' + suffix + name += suffix + return name + + + +def extract_audio_tracks(source_video, verbose=False, query_only=False): + """ + Extract all audio tracks from a source video into temporary AAC files. + + Returns: + Tuple: + - List of temp file paths for extracted audio tracks + - List of corresponding metadata dicts: + {'codec', 'sample_rate', 'channels', 'duration', 'language'} + where 'duration' is set to container duration (for consistency). 
+ """ + probe = ffmpeg.probe(source_video) + audio_streams = [s for s in probe['streams'] if s['codec_type'] == 'audio'] + container_duration = float(probe['format'].get('duration', 0.0)) + + if not audio_streams: + if query_only: return 0 + if verbose: print(f"No audio track found in {source_video}") + return [], [] + + if query_only: + return len(audio_streams) + + if verbose: + print(f"Found {len(audio_streams)} audio track(s), container duration = {container_duration:.3f}s") + + file_paths = [] + metadata = [] + + for i, stream in enumerate(audio_streams): + fd, temp_path = tempfile.mkstemp(suffix=f'_track{i}.aac', prefix='audio_') + os.close(fd) + + file_paths.append(temp_path) + metadata.append({ + 'codec': stream.get('codec_name'), + 'sample_rate': int(stream.get('sample_rate', 0)), + 'channels': int(stream.get('channels', 0)), + 'duration': container_duration, + 'language': stream.get('tags', {}).get('language', None) + }) + + ffmpeg.input(source_video).output( + temp_path, + **{f'map': f'0:a:{i}', 'acodec': 'aac', 'b:a': '128k'} + ).overwrite_output().run(quiet=not verbose) + + return file_paths, metadata + + + +def combine_and_concatenate_video_with_audio_tracks( + save_path_tmp, video_path, + source_audio_tracks, new_audio_tracks, + source_audio_duration, audio_sampling_rate, + new_audio_from_start=False, + source_audio_metadata=None, + audio_bitrate='128k', + audio_codec='aac', + verbose = False +): + inputs, filters, maps, idx = ['-i', video_path], [], ['-map', '0:v'], 1 + metadata_args = [] + sources = source_audio_tracks or [] + news = new_audio_tracks or [] + + duplicate_source = len(sources) == 1 and len(news) > 1 + N = len(news) if source_audio_duration == 0 else max(len(sources), len(news)) or 1 + + for i in range(N): + s = (sources[i] if i < len(sources) + else sources[0] if duplicate_source else None) + n = news[i] if len(news) == N else (news[0] if news else None) + + if source_audio_duration == 0: + if n: + inputs += ['-i', n] + filters.append(f'[{idx}:a]apad=pad_dur=100[aout{i}]') + idx += 1 + else: + filters.append(f'anullsrc=r={audio_sampling_rate}:cl=mono,apad=pad_dur=100[aout{i}]') + else: + if s: + inputs += ['-i', s] + meta = source_audio_metadata[i] if source_audio_metadata and i < len(source_audio_metadata) else {} + needs_filter = ( + meta.get('codec') != audio_codec or + meta.get('sample_rate') != audio_sampling_rate or + meta.get('channels') != 1 or + meta.get('duration', 0) < source_audio_duration + ) + if needs_filter: + filters.append( + f'[{idx}:a]aresample={audio_sampling_rate},aformat=channel_layouts=mono,' + f'apad=pad_dur={source_audio_duration},atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}]') + else: + filters.append( + f'[{idx}:a]apad=pad_dur={source_audio_duration},atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}]') + if lang := meta.get('language'): + metadata_args += ['-metadata:s:a:' + str(i), f'language={lang}'] + idx += 1 + else: + filters.append( + f'anullsrc=r={audio_sampling_rate}:cl=mono,atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}]') + + if n: + inputs += ['-i', n] + start = '0' if new_audio_from_start else source_audio_duration + filters.append( + f'[{idx}:a]aresample={audio_sampling_rate},aformat=channel_layouts=mono,' + f'atrim=start={start},asetpts=PTS-STARTPTS[n{i}]') + filters.append(f'[s{i}][n{i}]concat=n=2:v=0:a=1[aout{i}]') + idx += 1 + else: + filters.append(f'[s{i}]apad=pad_dur=100[aout{i}]') + + maps += ['-map', f'[aout{i}]'] + + cmd = ['ffmpeg', '-y', *inputs, + '-filter_complex', 
';'.join(filters), # ✅ Only change made + *maps, *metadata_args, + '-c:v', 'copy', + '-c:a', audio_codec, + '-b:a', audio_bitrate, + '-ar', str(audio_sampling_rate), + '-ac', '1', + '-shortest', save_path_tmp] + + if verbose: + print(f"ffmpeg command: {cmd}") + try: + subprocess.run(cmd, check=True, capture_output=True, text=True) + except subprocess.CalledProcessError as e: + raise Exception(f"FFmpeg error: {e.stderr}") + + +def combine_video_with_audio_tracks(target_video, audio_tracks, output_video, + audio_metadata=None, verbose=False): + if not audio_tracks: + if verbose: print("No audio tracks to combine."); return False + + dur = float(next(s for s in ffmpeg.probe(target_video)['streams'] + if s['codec_type'] == 'video')['duration']) + if verbose: print(f"Video duration: {dur:.3f}s") + + cmd = ['ffmpeg', '-y', '-i', target_video] + for path in audio_tracks: + cmd += ['-i', path] + + cmd += ['-map', '0:v'] + for i in range(len(audio_tracks)): + cmd += ['-map', f'{i+1}:a'] + + for i, meta in enumerate(audio_metadata or []): + if (lang := meta.get('language')): + cmd += ['-metadata:s:a:' + str(i), f'language={lang}'] + + cmd += ['-c:v', 'copy', '-c:a', 'copy', '-t', str(dur), output_video] + + result = subprocess.run(cmd, capture_output=not verbose, text=True) + if result.returncode != 0: + raise Exception(f"FFmpeg error:\n{result.stderr}") + if verbose: + print(f"Created {output_video} with {len(audio_tracks)} audio track(s)") + return True + + +def cleanup_temp_audio_files(audio_tracks, verbose=False): + """ + Clean up temporary audio files. + + Args: + audio_tracks: List of audio file paths to delete + verbose: Enable verbose output (default: False) + + Returns: + Number of files successfully deleted + """ + deleted_count = 0 + + for audio_path in audio_tracks: + try: + if os.path.exists(audio_path): + os.unlink(audio_path) + deleted_count += 1 + if verbose: + print(f"Cleaned up {audio_path}") + except PermissionError: + print(f"Warning: Could not delete {audio_path} (file may be in use)") + except Exception as e: + print(f"Warning: Error deleting {audio_path}: {e}") + + if verbose and deleted_count > 0: + print(f"Successfully deleted {deleted_count} temporary audio file(s)") + + return deleted_count + + +def save_video(tensor, + save_file=None, + fps=30, + codec_type='libx264_8', + container='mp4', + nrow=8, + normalize=True, + value_range=(-1, 1), + retry=5): + """Save tensor as video with configurable codec and container options.""" + + suffix = f'.{container}' + cache_file = osp.join('/tmp', rand_name(suffix=suffix)) if save_file is None else save_file + if not cache_file.endswith(suffix): + cache_file = osp.splitext(cache_file)[0] + suffix + + # Configure codec parameters + codec_params = _get_codec_params(codec_type, container) + + # Process and save + error = None + for _ in range(retry): + try: + if torch.is_tensor(tensor): + # Preprocess tensor + tensor = tensor.clamp(min(value_range), max(value_range)) + tensor = torch.stack([ + torchvision.utils.make_grid(u, nrow=nrow, normalize=normalize, value_range=value_range) + for u in tensor.unbind(2) + ], dim=1).permute(1, 2, 3, 0) + tensor = (tensor * 255).type(torch.uint8).cpu() + arrays = tensor.numpy() + else: + arrays = tensor + + # Write video (silence ffmpeg logs) + writer = imageio.get_writer(cache_file, fps=fps, ffmpeg_log_level='error', **codec_params) + for frame in arrays: + writer.append_data(frame) + + writer.close() + return cache_file + + except Exception as e: + error = e + print(f"error saving {save_file}: {e}") 
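For orientation, here is a minimal, hypothetical usage sketch of `save_video` (the expected input shape `[batch, channels, frames, height, width]` is inferred from the tensor branch above, which unbinds dim 2 as the frame axis; the output path and fps are placeholders):

```python
# Illustrative only: write a fake 33-frame clip with the x265 CRF-28 preset.
import torch
from shared.utils.audio_video import save_video  # assumes shared/ is on the import path

fake_clip = torch.rand(1, 3, 33, 256, 256) * 2 - 1   # values in the default value_range (-1, 1)
path = save_video(
    fake_clip,
    save_file="outputs/preview.mp4",   # extension is forced to match `container`
    fps=16,
    codec_type="libx265_28",           # one of the presets handled by _get_codec_params below
    container="mp4",
)
print("wrote", path)
```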
+ + +def _get_codec_params(codec_type, container): + """Get codec parameters based on codec type and container.""" + if codec_type == 'libx264_8': + return {'codec': 'libx264', 'quality': 8, 'pixelformat': 'yuv420p'} + elif codec_type == 'libx264_10': + return {'codec': 'libx264', 'quality': 10, 'pixelformat': 'yuv420p'} + elif codec_type == 'libx265_28': + return {'codec': 'libx265', 'pixelformat': 'yuv420p', 'output_params': ['-crf', '28', '-x265-params', 'log-level=none','-hide_banner', '-nostats']} + elif codec_type == 'libx265_8': + return {'codec': 'libx265', 'pixelformat': 'yuv420p', 'output_params': ['-crf', '8', '-x265-params', 'log-level=none','-hide_banner', '-nostats']} + elif codec_type == 'libx264_lossless': + if container == 'mkv': + return {'codec': 'ffv1', 'pixelformat': 'rgb24'} + else: # mp4 + return {'codec': 'libx264', 'output_params': ['-crf', '0'], 'pixelformat': 'yuv444p'} + else: # libx264 + return {'codec': 'libx264', 'pixelformat': 'yuv420p'} + + + + +def save_image(tensor, + save_file, + nrow=8, + normalize=True, + value_range=(-1, 1), + quality='jpeg_95', # 'jpeg_95', 'jpeg_85', 'jpeg_70', 'jpeg_50', 'webp_95', 'webp_85', 'webp_70', 'webp_50', 'png', 'webp_lossless' + retry=5): + """Save tensor as image with configurable format and quality.""" + + # Get format and quality settings + format_info = _get_format_info(quality) + + # Rename file extension to match requested format + save_file = osp.splitext(save_file)[0] + format_info['ext'] + + # Save image + error = None + for _ in range(retry): + try: + tensor = tensor.clamp(min(value_range), max(value_range)) + + if format_info['use_pil']: + # Use PIL for WebP and advanced options + grid = torchvision.utils.make_grid(tensor, nrow=nrow, normalize=normalize, value_range=value_range) + # Convert to PIL Image + grid = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() + img = Image.fromarray(grid) + img.save(save_file, **format_info['params']) + else: + # Use torchvision for JPEG and PNG + torchvision.utils.save_image( + tensor, save_file, nrow=nrow, normalize=normalize, + value_range=value_range, **format_info['params'] + ) + break + except Exception as e: + error = e + continue + else: + print(f'cache_image failed, error: {error}', flush=True) + + return save_file + + +def _get_format_info(quality): + """Get format extension and parameters.""" + formats = { + # JPEG with PIL (so 'quality' works) + 'jpeg_95': {'ext': '.jpg', 'params': {'quality': 95}, 'use_pil': True}, + 'jpeg_85': {'ext': '.jpg', 'params': {'quality': 85}, 'use_pil': True}, + 'jpeg_70': {'ext': '.jpg', 'params': {'quality': 70}, 'use_pil': True}, + 'jpeg_50': {'ext': '.jpg', 'params': {'quality': 50}, 'use_pil': True}, + + # PNG with torchvision + 'png': {'ext': '.png', 'params': {}, 'use_pil': False}, + + # WebP with PIL (for quality control) + 'webp_95': {'ext': '.webp', 'params': {'quality': 95}, 'use_pil': True}, + 'webp_85': {'ext': '.webp', 'params': {'quality': 85}, 'use_pil': True}, + 'webp_70': {'ext': '.webp', 'params': {'quality': 70}, 'use_pil': True}, + 'webp_50': {'ext': '.webp', 'params': {'quality': 50}, 'use_pil': True}, + 'webp_lossless': {'ext': '.webp', 'params': {'lossless': True}, 'use_pil': True}, + } + return formats.get(quality, formats['jpeg_95']) + + +from PIL import Image, PngImagePlugin + +def _enc_uc(s): + try: return b"ASCII\0\0\0" + s.encode("ascii") + except UnicodeEncodeError: return b"UNICODE\0" + s.encode("utf-16le") + +def _dec_uc(b): + if not isinstance(b, (bytes, 
bytearray)): + try: b = bytes(b) + except Exception: return None + if b.startswith(b"ASCII\0\0\0"): return b[8:].decode("ascii", "ignore") + if b.startswith(b"UNICODE\0"): return b[8:].decode("utf-16le", "ignore") + return b.decode("utf-8", "ignore") + +def save_image_metadata(image_path, metadata_dict, **save_kwargs): + try: + j = json.dumps(metadata_dict, ensure_ascii=False) + ext = os.path.splitext(image_path)[1].lower() + with Image.open(image_path) as im: + if ext == ".png": + pi = PngImagePlugin.PngInfo(); pi.add_text("comment", j) + im.save(image_path, pnginfo=pi, **save_kwargs); return True + if ext in (".jpg", ".jpeg"): + im.save(image_path, comment=j.encode("utf-8"), **save_kwargs); return True + if ext == ".webp": + import piexif + exif = {"0th":{}, "Exif":{piexif.ExifIFD.UserComment:_enc_uc(j)}, "GPS":{}, "1st":{}, "thumbnail":None} + im.save(image_path, format="WEBP", exif=piexif.dump(exif), **save_kwargs); return True + raise ValueError("Unsupported format") + except Exception as e: + print(f"Error saving metadata: {e}"); return False + +def read_image_metadata(image_path): + try: + ext = os.path.splitext(image_path)[1].lower() + with Image.open(image_path) as im: + if ext == ".png": + val = (getattr(im, "text", {}) or {}).get("comment") or im.info.get("comment") + return json.loads(val) if val else None + if ext in (".jpg", ".jpeg"): + val = im.info.get("comment") + if isinstance(val, (bytes, bytearray)): val = val.decode("utf-8", "ignore") + if val: + try: return json.loads(val) + except Exception: pass + exif = getattr(im, "getexif", lambda: None)() + if exif: + uc = exif.get(37510) # UserComment + s = _dec_uc(uc) if uc else None + if s: + try: return json.loads(s) + except Exception: pass + return None + if ext == ".webp": + exif_bytes = Image.open(image_path).info.get("exif") + if not exif_bytes: return None + import piexif + uc = piexif.load(exif_bytes).get("Exif", {}).get(piexif.ExifIFD.UserComment) + s = _dec_uc(uc) if uc else None + return json.loads(s) if s else None + return None + except Exception as e: + print(f"Error reading metadata: {e}"); return None \ No newline at end of file diff --git a/wan/utils/basic_flowmatch.py b/shared/utils/basic_flowmatch.py similarity index 100% rename from wan/utils/basic_flowmatch.py rename to shared/utils/basic_flowmatch.py diff --git a/wan/utils/cammmaster_tools.py b/shared/utils/cammmaster_tools.py similarity index 97% rename from wan/utils/cammmaster_tools.py rename to shared/utils/cammmaster_tools.py index 6e255a0..b93ebba 100644 --- a/wan/utils/cammmaster_tools.py +++ b/shared/utils/cammmaster_tools.py @@ -40,7 +40,7 @@ def get_relative_pose(cam_params): def get_camera_embedding(cam_type, num_frames=81): # load camera - tgt_camera_path = "wan/camera_extrinsics.json" + tgt_camera_path = "models/wan/camera_extrinsics.json" with open(tgt_camera_path, 'r') as file: cam_data = json.load(file) diff --git a/wan/utils/fm_solvers.py b/shared/utils/fm_solvers.py similarity index 100% rename from wan/utils/fm_solvers.py rename to shared/utils/fm_solvers.py diff --git a/wan/utils/fm_solvers_unipc.py b/shared/utils/fm_solvers_unipc.py similarity index 100% rename from wan/utils/fm_solvers_unipc.py rename to shared/utils/fm_solvers_unipc.py diff --git a/shared/utils/loras_mutipliers.py b/shared/utils/loras_mutipliers.py new file mode 100644 index 0000000..58cc9a9 --- /dev/null +++ b/shared/utils/loras_mutipliers.py @@ -0,0 +1,126 @@ +def preparse_loras_multipliers(loras_multipliers): + if isinstance(loras_multipliers, list): + return 
[multi.strip(" \r\n") if isinstance(multi, str) else multi for multi in loras_multipliers] + + loras_multipliers = loras_multipliers.strip(" \r\n") + loras_mult_choices_list = loras_multipliers.replace("\r", "").split("\n") + loras_mult_choices_list = [multi.strip() for multi in loras_mult_choices_list if len(multi)>0 and not multi.startswith("#")] + loras_multipliers = " ".join(loras_mult_choices_list) + return loras_multipliers.split(" ") + +def expand_slist(slists_dict, mult_no, num_inference_steps, model_switch_step, model_switch_step2 ): + def expand_one(slist, num_inference_steps): + if not isinstance(slist, list): slist = [slist] + new_slist= [] + if num_inference_steps <=0: + return new_slist + inc = len(slist) / num_inference_steps + pos = 0 + for i in range(num_inference_steps): + new_slist.append(slist[ int(pos)]) + pos += inc + return new_slist + + phase1 = slists_dict["phase1"][mult_no] + phase2 = slists_dict["phase2"][mult_no] + phase3 = slists_dict["phase3"][mult_no] + shared = slists_dict["shared"][mult_no] + if shared: + if isinstance(phase1, float): return phase1 + return expand_one(phase1, num_inference_steps) + else: + if isinstance(phase1, float) and isinstance(phase2, float) and isinstance(phase3, float) and phase1 == phase2 and phase2 == phase3: return phase1 + return expand_one(phase1, model_switch_step) + expand_one(phase2, model_switch_step2 - model_switch_step) + expand_one(phase3, num_inference_steps - model_switch_step2) + +def parse_loras_multipliers(loras_multipliers, nb_loras, num_inference_steps, merge_slist = None, nb_phases = 2, model_switch_step = None, model_switch_step2 = None): + if model_switch_step is None: + model_switch_step = num_inference_steps + if model_switch_step2 is None: + model_switch_step2 = num_inference_steps + def is_float(element: any) -> bool: + if element is None: + return False + try: + float(element) + return True + except ValueError: + return False + loras_list_mult_choices_nums = [] + slists_dict = { "model_switch_step": model_switch_step} + slists_dict = { "model_switch_step2": model_switch_step2} + slists_dict["phase1"] = phase1 = [1.] * nb_loras + slists_dict["phase2"] = phase2 = [1.] * nb_loras + slists_dict["phase3"] = phase3 = [1.] 
* nb_loras + slists_dict["shared"] = shared = [False] * nb_loras + + if isinstance(loras_multipliers, list) or len(loras_multipliers) > 0: + list_mult_choices_list = preparse_loras_multipliers(loras_multipliers)[:nb_loras] + for i, mult in enumerate(list_mult_choices_list): + current_phase = phase1 + if isinstance(mult, str): + mult = mult.strip() + phase_mult = mult.split(";") + shared_phases = len(phase_mult) <=1 + if not shared_phases and len(phase_mult) != nb_phases : + return "", "", f"if the ';' syntax is used for one Lora multiplier, the multipliers for its {nb_phases} denoising phases should be specified for this multiplier" + for phase_no, mult in enumerate(phase_mult): + if phase_no == 1: + current_phase = phase2 + elif phase_no == 2: + current_phase = phase3 + if "," in mult: + multlist = mult.split(",") + slist = [] + for smult in multlist: + if not is_float(smult): + return "", "", f"Lora sub value no {i+1} ({smult}) in Multiplier definition '{multlist}' is invalid in Phase {phase_no+1}" + slist.append(float(smult)) + else: + if not is_float(mult): + return "", "", f"Lora Multiplier no {i+1} ({mult}) is invalid" + slist = float(mult) + if shared_phases: + phase1[i] = phase2[i] = phase3[i] = slist + shared[i] = True + else: + current_phase[i] = slist + else: + phase1[i] = phase2[i] = phase3[i] = float(mult) + shared[i] = True + + if merge_slist is not None: + slists_dict["phase1"] = phase1 = merge_slist["phase1"] + phase1 + slists_dict["phase2"] = phase2 = merge_slist["phase2"] + phase2 + slists_dict["phase3"] = phase3 = merge_slist["phase3"] + phase3 + slists_dict["shared"] = shared = merge_slist["shared"] + shared + + loras_list_mult_choices_nums = [ expand_slist(slists_dict, i, num_inference_steps, model_switch_step, model_switch_step2 ) for i in range(len(phase1)) ] + loras_list_mult_choices_nums = [ slist[0] if isinstance(slist, list) else slist for slist in loras_list_mult_choices_nums ] + + return loras_list_mult_choices_nums, slists_dict, "" + +def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_step = None, phase_switch_step2 = None ): + from mmgp import offload + sz = len(slists_dict["phase1"]) + slists = [ expand_slist(slists_dict, i, num_inference_steps, phase_switch_step, phase_switch_step2 ) for i in range(sz) ] + nos = [str(l) for l in range(sz)] + offload.activate_loras(trans, nos, slists ) + + + +def get_model_switch_steps(timesteps, total_num_steps, guide_phases, model_switch_phase, switch_threshold, switch2_threshold ): + model_switch_step = model_switch_step2 = None + for i, t in enumerate(timesteps): + if guide_phases >=2 and model_switch_step is None and t <= switch_threshold: model_switch_step = i + if guide_phases >=3 and model_switch_step2 is None and t <= switch2_threshold: model_switch_step2 = i + if model_switch_step is None: model_switch_step = total_num_steps + if model_switch_step2 is None: model_switch_step2 = total_num_steps + phases_description = "" + if guide_phases > 1: + phases_description = "Denoising Steps: " + phases_description += f" Phase 1 = None" if model_switch_step == 0 else f" Phase 1 = 1:{ min(model_switch_step,total_num_steps) }" + if model_switch_step < total_num_steps: + phases_description += f", Phase 2 = None" if model_switch_step == model_switch_step2 else f", Phase 2 = {model_switch_step +1}:{ min(model_switch_step2,total_num_steps) }" + if guide_phases > 2 and model_switch_step2 < total_num_steps: + phases_description += f", Phase 3 = {model_switch_step2 +1}:{ total_num_steps}" + return 
model_switch_step, model_switch_step2, phases_description diff --git a/wan/utils/motion.py b/shared/utils/motion.py similarity index 100% rename from wan/utils/motion.py rename to shared/utils/motion.py diff --git a/wan/utils/notification_sound.py b/shared/utils/notification_sound.py similarity index 52% rename from wan/utils/notification_sound.py rename to shared/utils/notification_sound.py index 47d2e3d..26d1966 100644 --- a/wan/utils/notification_sound.py +++ b/shared/utils/notification_sound.py @@ -9,15 +9,21 @@ import threading import time import numpy as np +os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1" -def generate_notification_beep(volume=50, sample_rate=44100): +_cached_waveforms = {} +_sample_rate = 44100 +_mixer_initialized = False +_mixer_lock = threading.Lock() + +def _generate_notification_beep(volume=50, sample_rate=_sample_rate): """Generate pleasant C major chord notification sound""" if volume == 0: return np.array([]) - + volume = max(0, min(100, volume)) - - # Volume curve mapping: 25%->50%, 50%->75%, 75%->100%, 100%->105% + + # Volume curve mapping if volume <= 25: volume_mapped = (volume / 25.0) * 0.5 elif volume <= 50: @@ -25,211 +31,191 @@ def generate_notification_beep(volume=50, sample_rate=44100): elif volume <= 75: volume_mapped = 0.75 + ((volume - 50) / 25.0) * 0.25 else: - volume_mapped = 1.0 + ((volume - 75) / 25.0) * 0.05 # Only 5% boost instead of 15% - + volume_mapped = 1.0 + ((volume - 75) / 25.0) * 0.05 + volume = volume_mapped - + # C major chord frequencies - freq_c = 261.63 # C4 - freq_e = 329.63 # E4 - freq_g = 392.00 # G4 - + freq_c, freq_e, freq_g = 261.63, 329.63, 392.00 duration = 0.8 t = np.linspace(0, duration, int(sample_rate * duration), False) - + # Generate chord components - wave_c = np.sin(freq_c * 2 * np.pi * t) * 0.4 - wave_e = np.sin(freq_e * 2 * np.pi * t) * 0.3 - wave_g = np.sin(freq_g * 2 * np.pi * t) * 0.2 - - wave = wave_c + wave_e + wave_g - - # Prevent clipping + wave = ( + np.sin(freq_c * 2 * np.pi * t) * 0.4 + + np.sin(freq_e * 2 * np.pi * t) * 0.3 + + np.sin(freq_g * 2 * np.pi * t) * 0.2 + ) + + # Normalize max_amplitude = np.max(np.abs(wave)) if max_amplitude > 0: wave = wave / max_amplitude * 0.8 - + # ADSR envelope def apply_adsr_envelope(wave_data): length = len(wave_data) attack_time = int(0.2 * length) decay_time = int(0.1 * length) release_time = int(0.5 * length) - + envelope = np.ones(length) - + if attack_time > 0: envelope[:attack_time] = np.power(np.linspace(0, 1, attack_time), 3) - + if decay_time > 0: - start_idx = attack_time - end_idx = attack_time + decay_time + start_idx, end_idx = attack_time, attack_time + decay_time envelope[start_idx:end_idx] = np.linspace(1, 0.85, decay_time) - + if release_time > 0: start_idx = length - release_time envelope[start_idx:] = 0.85 * np.exp(-4 * np.linspace(0, 1, release_time)) - + return wave_data * envelope - + wave = apply_adsr_envelope(wave) - + # Simple low-pass filter def simple_lowpass_filter(signal, cutoff_ratio=0.8): window_size = max(3, int(len(signal) * 0.001)) if window_size % 2 == 0: window_size += 1 - + kernel = np.ones(window_size) / window_size - padded = np.pad(signal, window_size//2, mode='edge') - filtered = np.convolve(padded, kernel, mode='same') - return filtered[window_size//2:-window_size//2] - + padded = np.pad(signal, window_size // 2, mode="edge") + filtered = np.convolve(padded, kernel, mode="same") + return filtered[window_size // 2 : -window_size // 2] + wave = simple_lowpass_filter(wave) - - # Add reverb effect + + # Add reverb if len(wave) 
> sample_rate // 4: delay_samples = int(0.12 * sample_rate) reverb = np.zeros_like(wave) reverb[delay_samples:] = wave[:-delay_samples] * 0.08 wave = wave + reverb - - # Apply volume first, then normalize to prevent clipping + + # Apply volume & final normalize wave = wave * volume * 0.5 - - # Final normalization with safety margin max_amplitude = np.max(np.abs(wave)) - if max_amplitude > 0.85: # If approaching clipping threshold - wave = wave / max_amplitude * 0.85 # More conservative normalization - + if max_amplitude > 0.85: + wave = wave / max_amplitude * 0.85 + return wave +def _get_cached_waveform(volume): + """Return cached waveform for volume""" + if volume not in _cached_waveforms: + _cached_waveforms[volume] = _generate_notification_beep(volume) + return _cached_waveforms[volume] -def play_audio_with_pygame(audio_data, sample_rate=44100): - """Play audio using pygame backend""" + +def play_audio_with_pygame(audio_data, sample_rate=_sample_rate): + """Play audio with pygame backend""" + global _mixer_initialized try: import pygame - # Initialize pygame mixer only if not already initialized - if not pygame.mixer.get_init(): - pygame.mixer.pre_init(frequency=sample_rate, size=-16, channels=2, buffer=1024) - pygame.mixer.init() - else: - # Reinitialize with new settings if needed - current_freq, current_size, current_channels = pygame.mixer.get_init() - if current_freq != sample_rate or current_channels != 2: - pygame.mixer.quit() - pygame.mixer.pre_init(frequency=sample_rate, size=-16, channels=2, buffer=1024) + + with _mixer_lock: + if not _mixer_initialized: + pygame.mixer.pre_init(frequency=sample_rate, size=-16, channels=2, buffer=512) pygame.mixer.init() - - audio_int16 = (audio_data * 32767).astype(np.int16) - - # Convert mono to stereo - if len(audio_int16.shape) == 1: - stereo_data = np.column_stack((audio_int16, audio_int16)) - else: - stereo_data = audio_int16 - - sound = pygame.sndarray.make_sound(stereo_data) - sound.play() - pygame.time.wait(int(len(audio_data) / sample_rate * 1000) + 100) - # Don't quit mixer - this can interfere with Gradio server - # pygame.mixer.quit() - return True - + _mixer_initialized = True + + mixer_info = pygame.mixer.get_init() + if mixer_info is None or mixer_info[2] != 2: + return False + + audio_int16 = (audio_data * 32767).astype(np.int16) + if len(audio_int16.shape) > 1: + audio_int16 = audio_int16.flatten() + + stereo_data = np.zeros((len(audio_int16), 2), dtype=np.int16) + stereo_data[:, 0] = audio_int16 + stereo_data[:, 1] = audio_int16 + + sound = pygame.sndarray.make_sound(stereo_data) + pygame.mixer.stop() + sound.play() + + duration_ms = int(len(audio_data) / sample_rate * 1000) + 50 + pygame.time.wait(duration_ms) + + return True + except ImportError: return False except Exception as e: print(f"Pygame error: {e}") return False - -def play_audio_with_sounddevice(audio_data, sample_rate=44100): +def play_audio_with_sounddevice(audio_data, sample_rate=_sample_rate): """Play audio using sounddevice backend""" try: import sounddevice as sd sd.play(audio_data, sample_rate) sd.wait() return True - except ImportError: return False except Exception as e: print(f"Sounddevice error: {e}") return False - -def play_audio_with_winsound(audio_data, sample_rate=44100): +def play_audio_with_winsound(audio_data, sample_rate=_sample_rate): """Play audio using winsound backend (Windows only)""" if sys.platform != "win32": return False - try: - import winsound - import wave - import tempfile - import uuid - + import winsound, wave, tempfile, uuid + 
temp_dir = tempfile.gettempdir() temp_filename = os.path.join(temp_dir, f"notification_{uuid.uuid4().hex}.wav") - + try: - with wave.open(temp_filename, 'w') as wav_file: + with wave.open(temp_filename, "w") as wav_file: wav_file.setnchannels(1) wav_file.setsampwidth(2) wav_file.setframerate(sample_rate) - audio_int16 = (audio_data * 32767).astype(np.int16) wav_file.writeframes(audio_int16.tobytes()) - + winsound.PlaySound(temp_filename, winsound.SND_FILENAME) - + finally: - # Clean up temp file - for _ in range(3): - try: - if os.path.exists(temp_filename): - os.unlink(temp_filename) - break - except: - time.sleep(0.1) - + try: + if os.path.exists(temp_filename): + os.unlink(temp_filename) + except: + pass + return True - except ImportError: return False except Exception as e: print(f"Winsound error: {e}") return False - def play_notification_sound(volume=50): """Play notification sound with specified volume""" if volume == 0: return - - audio_data = generate_notification_beep(volume=volume) - + + audio_data = _get_cached_waveform(volume) if len(audio_data) == 0: return - - # Try audio backends in order - audio_backends = [ - play_audio_with_pygame, - play_audio_with_sounddevice, - play_audio_with_winsound, - ] - + + audio_backends = [play_audio_with_pygame, play_audio_with_sounddevice, play_audio_with_winsound] for backend in audio_backends: try: if backend(audio_data): return - except Exception as e: + except Exception: continue - - # Fallback: terminal beep - print(f"All audio backends failed, using terminal beep") - print('\a') + print("All audio backends failed, using terminal beep") + print("\a") def play_notification_async(volume=50): """Play notification sound asynchronously (non-blocking)""" @@ -238,24 +224,12 @@ def play_notification_async(volume=50): play_notification_sound(volume) except Exception as e: print(f"Error playing notification sound: {e}") - - sound_thread = threading.Thread(target=play_sound, daemon=True) - sound_thread.start() + threading.Thread(target=play_sound, daemon=True).start() def notify_video_completion(video_path=None, volume=50): """Notify about completed video generation""" play_notification_async(volume) - -if __name__ == "__main__": - print("Testing notification sounds with different volumes...") - print("Auto-detecting available audio backends...") - - volumes = [25, 50, 75, 100] - for vol in volumes: - print(f"Testing volume {vol}%:") - play_notification_sound(vol) - time.sleep(2) - - print("Test completed!") \ No newline at end of file +for vol in (25, 50, 75, 100): + _get_cached_waveform(vol) \ No newline at end of file diff --git a/wan/utils/prompt_extend.py b/shared/utils/prompt_extend.py similarity index 100% rename from wan/utils/prompt_extend.py rename to shared/utils/prompt_extend.py diff --git a/wan/utils/prompt_parser.py b/shared/utils/prompt_parser.py similarity index 94% rename from wan/utils/prompt_parser.py rename to shared/utils/prompt_parser.py index faaa1ca..46edec4 100644 --- a/wan/utils/prompt_parser.py +++ b/shared/utils/prompt_parser.py @@ -1,6 +1,6 @@ import re -def process_template(input_text): +def process_template(input_text, keep_comments = False): """ Process a text template with macro instructions and variable substitution. Supports multiple values for variables to generate multiple output versions. 
@@ -28,9 +28,12 @@ def process_template(input_text): line_number += 1 # Skip empty lines or comments - if not line or line.startswith('#'): + if not line: continue - + + if line.startswith('#') and not keep_comments: + continue + # Handle macro instructions if line.startswith('!'): # Process any accumulated template lines before starting a new macro @@ -106,13 +109,14 @@ def process_template(input_text): # Handle template lines else: - # Check for unknown variables in template line - var_references = re.findall(r'\{([^}]+)\}', line) - for var_ref in var_references: - if var_ref not in current_variables: - error_message = f"Unknown variable '{{{var_ref}}}' in template\nLine: '{orig_line}'" - return "", error_message - + if not line.startswith('#'): + # Check for unknown variables in template line + var_references = re.findall(r'\{([^}]+)\}', line) + for var_ref in var_references: + if var_ref not in current_variables: + error_message = f"Unknown variable '{{{var_ref}}}' in template\nLine: '{orig_line}'" + return "", error_message + # Add to current template lines current_template_lines.append(line) diff --git a/wan/utils/qwen_vl_utils.py b/shared/utils/qwen_vl_utils.py similarity index 100% rename from wan/utils/qwen_vl_utils.py rename to shared/utils/qwen_vl_utils.py diff --git a/shared/utils/stats.py b/shared/utils/stats.py new file mode 100644 index 0000000..2a94b33 --- /dev/null +++ b/shared/utils/stats.py @@ -0,0 +1,256 @@ +import gradio as gr +import signal +import sys +import time +import threading +import atexit +from contextlib import contextmanager +from collections import deque +import psutil +import pynvml + +# Initialize NVIDIA Management Library (NVML) for GPU monitoring +try: + pynvml.nvmlInit() + nvml_initialized = True +except pynvml.NVMLError: + print("Warning: Could not initialize NVML. GPU stats will not be available.") + nvml_initialized = False + +class SystemStatsApp: + def __init__(self): + self.running = False + self.active_generators = [] + self.setup_signal_handlers() + + def setup_signal_handlers(self): + # Handle different shutdown signals + signal.signal(signal.SIGINT, self.shutdown_handler) + signal.signal(signal.SIGTERM, self.shutdown_handler) + if hasattr(signal, 'SIGBREAK'): # Windows + signal.signal(signal.SIGBREAK, self.shutdown_handler) + + # Also register atexit handler as backup + atexit.register(self.cleanup) + + def shutdown_handler(self, signum, frame): + # print(f"\nReceived signal {signum}. Shutting down gracefully...") + self.cleanup() + sys.exit(0) + + def cleanup(self): + if not self.running: + print("Cleaning up streaming connections...") + self.running = False + # Give a moment for generators to stop + time.sleep(1) + + def get_system_stats(self, first = False, last_disk_io = psutil.disk_io_counters() ): + + # Set a reasonable maximum speed for the bar graph display. + # 100 MB/s will represent a 100% full bar. 
+ MAX_SSD_SPEED_MB_S = 100.0 + # Get CPU and RAM stats + if first : + cpu_percent = psutil.cpu_percent(interval=.01) + else: + cpu_percent = psutil.cpu_percent(interval=1) # This provides our 1-second delay + memory_info = psutil.virtual_memory() + ram_percent = memory_info.percent + ram_used_gb = memory_info.used / (1024**3) + ram_total_gb = memory_info.total / (1024**3) + + # Get new disk IO counters and calculate the read/write speed in MB/s + current_disk_io = psutil.disk_io_counters() + read_mb_s = (current_disk_io.read_bytes - last_disk_io.read_bytes) / (1024**2) + write_mb_s = (current_disk_io.write_bytes - last_disk_io.write_bytes) / (1024**2) + total_disk_speed = read_mb_s + write_mb_s + + # Update the last counters for the next loop + last_disk_io = current_disk_io + + # Calculate the bar height as a percentage of our defined max speed + ssd_bar_height = min(100.0, (total_disk_speed / MAX_SSD_SPEED_MB_S) * 100) + + # Get GPU stats if the library was initialized successfully + if nvml_initialized: + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(0) # Assuming GPU 0 + util = pynvml.nvmlDeviceGetUtilizationRates(handle) + gpu_percent = util.gpu + mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) + vram_percent = (mem_info.used / mem_info.total) * 100 + vram_used_gb = mem_info.used / (1024**3) + vram_total_gb = mem_info.total / (1024**3) + except pynvml.NVMLError: + # Handle cases where GPU might be asleep or driver issues + gpu_percent, vram_percent, vram_used_gb, vram_total_gb = 0, 0, 0, 0 + else: + # Set default values if NVML failed to load + gpu_percent, vram_percent, vram_used_gb, vram_total_gb = 0, 0, 0, 0 + + stats_html = f""" + + +
+            <!-- original styled-bar markup stripped during extraction; the f-string renders five meters with these values: -->
+            <!-- CPU: {cpu_percent:.1f}% -->
+            <!-- RAM {ram_percent:.1f}% ({ram_used_gb:.1f} / {ram_total_gb:.1f} GB) -->
+            <!-- SSD R/W {read_mb_s:.1f} / {write_mb_s:.1f} MB/s -->
+            <!-- GPU: {gpu_percent:.1f}% -->
+            <!-- VRAM {vram_percent:.1f}% ({vram_used_gb:.1f} / {vram_total_gb:.1f} GB) -->
+ """ + return stats_html, last_disk_io + + def streaming_html(self, state): + if "stats_running" in state: + return + state["stats_running"] = True + + self.running = True + last_disk_io = psutil.disk_io_counters() + i = 0 + import time + try: + while self.running: + i+= 1 + # if i % 2 == 0: + # print(f"time:{time.time()}") + html_content, last_disk_io = self.get_system_stats(False, last_disk_io) + yield html_content + # time.sleep(1) + + except GeneratorExit: + # print("Generator stopped gracefully") + return + except Exception as e: + print(f"Streaming error: {e}") + # finally: + # # Send final message indicating clean shutdown + final_html = """ +
+            <!-- final clean-shutdown notice markup stripped during extraction -->
+ """ + try: + yield final_html + except: + pass + + + def get_gradio_element(self): + self.system_stats_display = gr.HTML(self.get_system_stats(True)[0]) + self.restart_btn = gr.Button("restart stats",elem_id="restart_stats", visible= False) # False) + return self.system_stats_display + + def setup_events(self, main, state): + gr.on([main.load, self.restart_btn.click], + fn=self.streaming_html, + inputs = state, + outputs=self.system_stats_display, + show_progress=False + ) diff --git a/wan/utils/thread_utils.py b/shared/utils/thread_utils.py similarity index 100% rename from wan/utils/thread_utils.py rename to shared/utils/thread_utils.py diff --git a/wan/utils/utils.py b/shared/utils/utils.py similarity index 59% rename from wan/utils/utils.py rename to shared/utils/utils.py index 53f3b73..a55807a 100644 --- a/wan/utils/utils.py +++ b/shared/utils/utils.py @@ -1,6 +1,5 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import argparse -import binascii import os import os.path as osp import torchvision.transforms.functional as TF @@ -10,7 +9,6 @@ import tempfile import imageio import torch import decord -import torchvision from PIL import Image import numpy as np from rembg import remove, new_session @@ -18,8 +16,8 @@ import random import ffmpeg import os import tempfile - -__all__ = ['cache_video', 'cache_image', 'str2bool'] +import subprocess +import json @@ -34,21 +32,6 @@ def seed_everything(seed: int): if torch.backends.mps.is_available(): torch.mps.manual_seed(seed) -def expand_slist(slist, num_inference_steps ): - new_slist= [] - inc = len(slist) / num_inference_steps - pos = 0 - for i in range(num_inference_steps): - new_slist.append(slist[ int(pos)]) - pos += inc - return new_slist - -def update_loras_slists(trans, slists, num_inference_steps ): - from mmgp import offload - slists = [ expand_slist(slist, num_inference_steps ) if isinstance(slist, list) else slist for slist in slists ] - nos = [str(l) for l in range(len(slists))] - offload.activate_loras(trans, nos, slists ) - def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ): import math @@ -99,7 +82,7 @@ def get_video_info(video_path): cap = cv2.VideoCapture(video_path) # Get FPS - fps = cap.get(cv2.CAP_PROP_FPS) + fps = round(cap.get(cv2.CAP_PROP_FPS)) # Get resolution width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) @@ -109,13 +92,44 @@ def get_video_info(video_path): return fps, width, height, frame_count -def get_video_frame(file_name, frame_no): - decord.bridge.set_bridge('torch') - reader = decord.VideoReader(file_name) +def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, return_PIL = True) -> torch.Tensor: + """Extract nth frame from video as PyTorch tensor normalized to [-1, 1].""" + cap = cv2.VideoCapture(file_name) + + if not cap.isOpened(): + raise ValueError(f"Cannot open video: {file_name}") + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + # Handle out of bounds + if frame_no >= total_frames or frame_no < 0: + if return_last_if_missing: + frame_no = total_frames - 1 + else: + cap.release() + raise IndexError(f"Frame {frame_no} out of bounds (0-{total_frames-1})") + + # Get frame + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no) + ret, frame = cap.read() + cap.release() + + if not ret: + raise ValueError(f"Failed to read frame {frame_no}") + + # Convert BGR->RGB, reshape to (C,H,W), normalize to [-1,1] + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + if return_PIL: + return Image.fromarray(frame) 
+ else: + return (torch.from_numpy(frame).permute(2, 0, 1).float() / 127.5) - 1.0 +# def get_video_frame(file_name, frame_no): +# decord.bridge.set_bridge('torch') +# reader = decord.VideoReader(file_name) - frame = reader.get_batch([frame_no]).squeeze(0) - img = Image.fromarray(frame.numpy().astype(np.uint8)) - return img +# frame = reader.get_batch([frame_no]).squeeze(0) +# img = Image.fromarray(frame.numpy().astype(np.uint8)) +# return img def convert_image_to_video(image): if image is None: @@ -141,10 +155,12 @@ def convert_image_to_video(image): return temp_video.name def resize_lanczos(img, h, w): - img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) + img = (img + 1).float().mul_(127.5) + img = Image.fromarray(np.clip(img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) img = img.resize((w,h), resample=Image.Resampling.LANCZOS) - return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0) - + img = torch.from_numpy(np.array(img).astype(np.float32)).movedim(-1, 0) + img = img.div(127.5).sub_(1) + return img def remove_background(img, session=None): if session ==None: @@ -153,6 +169,10 @@ def remove_background(img, session=None): img = remove(img, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0) + +def convert_image_to_tensor(image): + return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0) + def convert_tensor_to_image(t, frame_no = -1): t = t[:, frame_no] if frame_no >= 0 else t return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy()) @@ -185,18 +205,19 @@ def get_outpainting_frame_location(final_height, final_width, outpainting_dims if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width return height, width, margin_top, margin_left -def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_into_canvas, block_size = 16): +def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16): if fit_into_canvas == None: - return height, width + # return image_height, image_width + return canvas_height, canvas_width if fit_into_canvas: - scale1 = min(canvas_height / height, canvas_width / width) - scale2 = min(canvas_width / height, canvas_height / width) + scale1 = min(canvas_height / image_height, canvas_width / image_width) + scale2 = min(canvas_width / image_height, canvas_height / image_width) scale = max(scale1, scale2) else: - scale = (canvas_height * canvas_width / (height * width))**(1/2) + scale = (canvas_height * canvas_width / (image_height * image_width))**(1/2) - new_height = round( height * scale / block_size) * block_size - new_width = round( width * scale / block_size) * block_size + new_height = round( image_height * scale / block_size) * block_size + new_width = round( image_width * scale / block_size) * block_size return new_height, new_width def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, ignore_first, fit_into_canvas = False ): @@ -229,84 +250,6 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg return output_list -def rand_name(length=8, suffix=''): - name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') - if suffix: - if not suffix.startswith('.'): - suffix = '.' 
+ suffix - name += suffix - return name - - -def cache_video(tensor, - save_file=None, - fps=30, - suffix='.mp4', - nrow=8, - normalize=True, - value_range=(-1, 1), - retry=5): - # cache file - cache_file = osp.join('/tmp', rand_name( - suffix=suffix)) if save_file is None else save_file - - # save to cache - error = None - for _ in range(retry): - try: - # preprocess - tensor = tensor.clamp(min(value_range), max(value_range)) - tensor = torch.stack([ - torchvision.utils.make_grid( - u, nrow=nrow, normalize=normalize, value_range=value_range) - for u in tensor.unbind(2) - ], - dim=1).permute(1, 2, 3, 0) - tensor = (tensor * 255).type(torch.uint8).cpu() - - # write video - writer = imageio.get_writer( - cache_file, fps=fps, codec='libx264', quality=8) - for frame in tensor.numpy(): - writer.append_data(frame) - writer.close() - return cache_file - except Exception as e: - error = e - continue - else: - print(f'cache_video failed, error: {error}', flush=True) - return None - - -def cache_image(tensor, - save_file, - nrow=8, - normalize=True, - value_range=(-1, 1), - retry=5): - # cache file - suffix = osp.splitext(save_file)[1] - if suffix.lower() not in [ - '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp' - ]: - suffix = '.png' - - # save to cache - error = None - for _ in range(retry): - try: - tensor = tensor.clamp(min(value_range), max(value_range)) - torchvision.utils.save_image( - tensor, - save_file, - nrow=nrow, - normalize=normalize, - value_range=value_range) - return save_file - except Exception as e: - error = e - continue def str2bool(v): @@ -445,137 +388,4 @@ def create_progress_hook(filename): return progress_hook(block_num, block_size, total_size, filename) return hook -import ffmpeg -import os -import tempfile - -def extract_audio_tracks(source_video, verbose=False, query_only= False): - """ - Extract all audio tracks from source video to temporary files. - - Args: - source_video: Path to video with audio to extract - verbose: Enable verbose output (default: False) - - Returns: - List of temporary audio file paths, or empty list if no audio tracks - """ - try: - # Check if source video has audio - probe = ffmpeg.probe(source_video) - audio_streams = [s for s in probe['streams'] if s['codec_type'] == 'audio'] - - if not audio_streams: - if query_only: return 0 - if verbose: - print(f"No audio track found in {source_video}") - return [] - if query_only: return len(audio_streams) - if verbose: - print(f"Found {len(audio_streams)} audio track(s)") - - # Create temporary audio files for each track - temp_audio_files = [] - for i in range(len(audio_streams)): - fd, temp_path = tempfile.mkstemp(suffix=f'_track{i}.aac', prefix='audio_') - os.close(fd) # Close file descriptor immediately - temp_audio_files.append(temp_path) - - # Extract each audio track - for i, temp_path in enumerate(temp_audio_files): - (ffmpeg - .input(source_video) - .output(temp_path, **{f'map': f'0:a:{i}', 'acodec': 'aac'}) - .overwrite_output() - .run(quiet=not verbose)) - - return temp_audio_files - - except ffmpeg.Error as e: - print(f"FFmpeg error during audio extraction: {e}") - return 0 if query_only else [] - except Exception as e: - print(f"Error during audio extraction: {e}") - return 0 if query_only else [] - -def combine_video_with_audio_tracks(target_video, audio_tracks, output_video, verbose=False): - """ - Combine video with audio tracks. Output duration matches video length exactly. 
- - Args: - target_video: Path to video to receive the audio - audio_tracks: List of audio file paths to combine - output_video: Path for the output video - verbose: Enable verbose output (default: False) - - Returns: - True if successful, False otherwise - """ - if not audio_tracks: - if verbose: - print("No audio tracks to combine") - return False - - try: - # Get video duration to ensure exact alignment - video_probe = ffmpeg.probe(target_video) - video_duration = float(video_probe['streams'][0]['duration']) - - if verbose: - print(f"Target video duration: {video_duration:.3f} seconds") - - # Combine target video with all audio tracks, force video duration - video = ffmpeg.input(target_video).video - audio_inputs = [ffmpeg.input(audio_path).audio for audio_path in audio_tracks] - - # Create output with video duration as master timing - inputs = [video] + audio_inputs - (ffmpeg - .output(*inputs, output_video, - vcodec='copy', - acodec='copy', - t=video_duration) # Force exact video duration - .overwrite_output() - .run(quiet=not verbose)) - - if verbose: - print(f"Successfully created {output_video} with {len(audio_tracks)} audio track(s) aligned to video duration") - return True - - except ffmpeg.Error as e: - print(f"FFmpeg error during video combination: {e}") - return False - except Exception as e: - print(f"Error during video combination: {e}") - return False - -def cleanup_temp_audio_files(audio_tracks, verbose=False): - """ - Clean up temporary audio files. - - Args: - audio_tracks: List of audio file paths to delete - verbose: Enable verbose output (default: False) - - Returns: - Number of files successfully deleted - """ - deleted_count = 0 - - for audio_path in audio_tracks: - try: - if os.path.exists(audio_path): - os.unlink(audio_path) - deleted_count += 1 - if verbose: - print(f"Cleaned up {audio_path}") - except PermissionError: - print(f"Warning: Could not delete {audio_path} (file may be in use)") - except Exception as e: - print(f"Warning: Error deleting {audio_path}: {e}") - - if verbose and deleted_count > 0: - print(f"Successfully deleted {deleted_count} temporary audio file(s)") - - return deleted_count diff --git a/wan/utils/vace_preprocessor.py b/shared/utils/vace_preprocessor.py similarity index 99% rename from wan/utils/vace_preprocessor.py rename to shared/utils/vace_preprocessor.py index 7fdb8c9..947767e 100644 --- a/wan/utils/vace_preprocessor.py +++ b/shared/utils/vace_preprocessor.py @@ -184,7 +184,7 @@ class VaceVideoProcessor(object): def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, canvas_height, canvas_width, h, w, fit_into_canvas, crop_box, rng, max_frames= 0, start_frame =0): - from wan.utils.utils import resample + from shared.utils.utils import resample target_fps = self.max_fps diff --git a/wan/distributed/__init__.py b/wan/distributed/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/wan/multitalk/multitalk_utils.py b/wan/multitalk/multitalk_utils.py deleted file mode 100644 index 8462390..0000000 --- a/wan/multitalk/multitalk_utils.py +++ /dev/null @@ -1,353 +0,0 @@ -import os -from einops import rearrange - -import torch -import torch.nn as nn - -from einops import rearrange, repeat -from functools import lru_cache -import imageio -import uuid -from tqdm import tqdm -import numpy as np -import subprocess -import soundfile as sf -import torchvision -import binascii -import os.path as osp - - -VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") -ASPECT_RATIO_627 = { - '0.26': ([320, 1216], 1), '0.38': ([384, 
1024], 1), '0.50': ([448, 896], 1), '0.67': ([512, 768], 1), - '0.82': ([576, 704], 1), '1.00': ([640, 640], 1), '1.22': ([704, 576], 1), '1.50': ([768, 512], 1), - '1.86': ([832, 448], 1), '2.00': ([896, 448], 1), '2.50': ([960, 384], 1), '2.83': ([1088, 384], 1), - '3.60': ([1152, 320], 1), '3.80': ([1216, 320], 1), '4.00': ([1280, 320], 1)} - - -ASPECT_RATIO_960 = { - '0.22': ([448, 2048], 1), '0.29': ([512, 1792], 1), '0.36': ([576, 1600], 1), '0.45': ([640, 1408], 1), - '0.55': ([704, 1280], 1), '0.63': ([768, 1216], 1), '0.76': ([832, 1088], 1), '0.88': ([896, 1024], 1), - '1.00': ([960, 960], 1), '1.14': ([1024, 896], 1), '1.31': ([1088, 832], 1), '1.50': ([1152, 768], 1), - '1.58': ([1216, 768], 1), '1.82': ([1280, 704], 1), '1.91': ([1344, 704], 1), '2.20': ([1408, 640], 1), - '2.30': ([1472, 640], 1), '2.67': ([1536, 576], 1), '2.89': ([1664, 576], 1), '3.62': ([1856, 512], 1), - '3.75': ([1920, 512], 1)} - - - -def torch_gc(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - - - -def split_token_counts_and_frame_ids(T, token_frame, world_size, rank): - - S = T * token_frame - split_sizes = [S // world_size + (1 if i < S % world_size else 0) for i in range(world_size)] - start = sum(split_sizes[:rank]) - end = start + split_sizes[rank] - counts = [0] * T - for idx in range(start, end): - t = idx // token_frame - counts[t] += 1 - - counts_filtered = [] - frame_ids = [] - for t, c in enumerate(counts): - if c > 0: - counts_filtered.append(c) - frame_ids.append(t) - return counts_filtered, frame_ids - - -def normalize_and_scale(column, source_range, target_range, epsilon=1e-8): - - source_min, source_max = source_range - new_min, new_max = target_range - - normalized = (column - source_min) / (source_max - source_min + epsilon) - scaled = normalized * (new_max - new_min) + new_min - return scaled - - -# @torch.compile -def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count, mode='mean', attn_bias=None): - - ref_k = ref_k.to(visual_q.dtype).to(visual_q.device) - scale = 1.0 / visual_q.shape[-1] ** 0.5 - visual_q = visual_q * scale - visual_q = visual_q.transpose(1, 2) - ref_k = ref_k.transpose(1, 2) - attn = visual_q @ ref_k.transpose(-2, -1) - - if attn_bias is not None: attn += attn_bias - - x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens - - x_ref_attn_maps = [] - ref_target_masks = ref_target_masks.to(visual_q.dtype) - x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype) - - for class_idx, ref_target_mask in enumerate(ref_target_masks): - ref_target_mask = ref_target_mask[None, None, None, ...] 
- x_ref_attnmap = x_ref_attn_map_source * ref_target_mask - x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens - x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1) # B, x_seqlens, H - - if mode == 'mean': - x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens - elif mode == 'max': - x_ref_attnmap = x_ref_attnmap.max(-1) # B, x_seqlens - - x_ref_attn_maps.append(x_ref_attnmap) - - del attn - del x_ref_attn_map_source - - return torch.concat(x_ref_attn_maps, dim=0) - - -def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count = 0): - """Args: - query (torch.tensor): B M H K - key (torch.tensor): B M H K - shape (tuple): (N_t, N_h, N_w) - ref_target_masks: [B, N_h * N_w] - """ - - N_t, N_h, N_w = shape - - x_seqlens = N_h * N_w - ref_k = ref_k[:, :x_seqlens] - if ref_images_count > 0 : - visual_q_shape = visual_q.shape - visual_q = visual_q.reshape(visual_q_shape[0], N_t, -1) - visual_q = visual_q[:, ref_images_count:] - visual_q = visual_q.reshape(visual_q_shape[0], -1, *visual_q_shape[-2:]) - - _, seq_lens, heads, _ = visual_q.shape - class_num, _ = ref_target_masks.shape - x_ref_attn_maps = torch.zeros(class_num, seq_lens, dtype=visual_q.dtype, device=visual_q.device) - - split_chunk = heads // split_num - - for i in range(split_num): - x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count) - x_ref_attn_maps += x_ref_attn_maps_perhead - - x_ref_attn_maps /= split_num - return x_ref_attn_maps - - -def rotate_half(x): - x = rearrange(x, "... (d r) -> ... d r", r=2) - x1, x2 = x.unbind(dim=-1) - x = torch.stack((-x2, x1), dim=-1) - return rearrange(x, "... d r -> ... (d r)") - - -class RotaryPositionalEmbedding1D(nn.Module): - - def __init__(self, - head_dim, - ): - super().__init__() - self.head_dim = head_dim - self.base = 10000 - - - @lru_cache(maxsize=32) - def precompute_freqs_cis_1d(self, pos_indices): - - freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim)) - freqs = freqs.to(pos_indices.device) - freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs) - freqs = repeat(freqs, "... n -> ... (n r)", r=2) - return freqs - - def forward(self, x, pos_indices): - """1D RoPE. - - Args: - query (torch.tensor): [B, head, seq, head_dim] - pos_indices (torch.tensor): [seq,] - Returns: - query with the same shape as input. - """ - freqs_cis = self.precompute_freqs_cis_1d(pos_indices) - - x_ = x.float() - - freqs_cis = freqs_cis.float().to(x.device) - cos, sin = freqs_cis.cos(), freqs_cis.sin() - cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') - x_ = (x_ * cos) + (rotate_half(x_) * sin) - - return x_.type_as(x) - - - -def rand_name(length=8, suffix=''): - name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') - if suffix: - if not suffix.startswith('.'): - suffix = '.' 
+ suffix - name += suffix - return name - -def cache_video(tensor, - save_file=None, - fps=30, - suffix='.mp4', - nrow=8, - normalize=True, - value_range=(-1, 1), - retry=5): - - # cache file - cache_file = osp.join('/tmp', rand_name( - suffix=suffix)) if save_file is None else save_file - - # save to cache - error = None - for _ in range(retry): - - # preprocess - tensor = tensor.clamp(min(value_range), max(value_range)) - tensor = torch.stack([ - torchvision.utils.make_grid( - u, nrow=nrow, normalize=normalize, value_range=value_range) - for u in tensor.unbind(2) - ], - dim=1).permute(1, 2, 3, 0) - tensor = (tensor * 255).type(torch.uint8).cpu() - - # write video - writer = imageio.get_writer(cache_file, fps=fps, codec='libx264', quality=10, ffmpeg_params=["-crf", "10"]) - for frame in tensor.numpy(): - writer.append_data(frame) - writer.close() - return cache_file - -def save_video_ffmpeg(gen_video_samples, save_path, vocal_audio_list, fps=25, quality=5, high_quality_save=False): - - def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None): - writer = imageio.get_writer( - save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params - ) - for frame in tqdm(frames, desc="Saving video"): - frame = np.array(frame) - writer.append_data(frame) - writer.close() - save_path_tmp = save_path + "-temp.mp4" - - if high_quality_save: - cache_video( - tensor=gen_video_samples.unsqueeze(0), - save_file=save_path_tmp, - fps=fps, - nrow=1, - normalize=True, - value_range=(-1, 1) - ) - else: - video_audio = (gen_video_samples+1)/2 # C T H W - video_audio = video_audio.permute(1, 2, 3, 0).cpu().numpy() - video_audio = np.clip(video_audio * 255, 0, 255).astype(np.uint8) # to [0, 255] - save_video(video_audio, save_path_tmp, fps=fps, quality=quality) - - - # crop audio according to video length - _, T, _, _ = gen_video_samples.shape - duration = T / fps - save_path_crop_audio = save_path + "-cropaudio.wav" - final_command = [ - "ffmpeg", - "-i", - vocal_audio_list[0], - "-t", - f'{duration}', - save_path_crop_audio, - ] - subprocess.run(final_command, check=True) - - save_path = save_path + ".mp4" - if high_quality_save: - final_command = [ - "ffmpeg", - "-y", - "-i", save_path_tmp, - "-i", save_path_crop_audio, - "-c:v", "libx264", - "-crf", "0", - "-preset", "veryslow", - "-c:a", "aac", - "-shortest", - save_path, - ] - subprocess.run(final_command, check=True) - os.remove(save_path_tmp) - os.remove(save_path_crop_audio) - else: - final_command = [ - "ffmpeg", - "-y", - "-i", - save_path_tmp, - "-i", - save_path_crop_audio, - "-c:v", - "libx264", - "-c:a", - "aac", - "-shortest", - save_path, - ] - subprocess.run(final_command, check=True) - os.remove(save_path_tmp) - os.remove(save_path_crop_audio) - - -class MomentumBuffer: - def __init__(self, momentum: float): - self.momentum = momentum - self.running_average = 0 - - def update(self, update_value: torch.Tensor): - new_average = self.momentum * self.running_average - self.running_average = update_value + new_average - - - -def project( - v0: torch.Tensor, # [B, C, T, H, W] - v1: torch.Tensor, # [B, C, T, H, W] - ): - dtype = v0.dtype - v0, v1 = v0.double(), v1.double() - v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3, -4]) - v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3, -4], keepdim=True) * v1 - v0_orthogonal = v0 - v0_parallel - return v0_parallel.to(dtype), v0_orthogonal.to(dtype) - - -def adaptive_projected_guidance( - diff: torch.Tensor, # [B, C, T, H, W] - pred_cond: torch.Tensor, # [B, C, T, H, W] - momentum_buffer: 
MomentumBuffer = None, - eta: float = 0.0, - norm_threshold: float = 55, - ): - if momentum_buffer is not None: - momentum_buffer.update(diff) - diff = momentum_buffer.running_average - if norm_threshold > 0: - ones = torch.ones_like(diff) - diff_norm = diff.norm(p=2, dim=[-1, -2, -3, -4], keepdim=True) - print(f"diff_norm: {diff_norm}") - scale_factor = torch.minimum(ones, norm_threshold / diff_norm) - diff = diff * scale_factor - diff_parallel, diff_orthogonal = project(diff, pred_cond) - normalized_update = diff_orthogonal + eta * diff_parallel - return normalized_update diff --git a/wan/trajectory_editor/templates/index.html b/wan/trajectory_editor/templates/index.html deleted file mode 100644 index 2ac8c78..0000000 --- a/wan/trajectory_editor/templates/index.html +++ /dev/null @@ -1,571 +0,0 @@ - - - - - - - Track Point Editor - - - -

[deleted file content: the 571-line "Track Point Editor" HTML page; markup stripped during extraction]
- - - - - diff --git a/wgp.py b/wgp.py index bdb7327..8cca2ad 100644 --- a/wgp.py +++ b/wgp.py @@ -13,14 +13,15 @@ from datetime import datetime import gradio as gr import random import json -import wan -from wan.utils import notification_sound -from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS -from wan.utils.utils import expand_slist, update_loras_slists -from wan.utils.utils import cache_video, convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video -from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files - -from wan.modules.attention import get_attention_modes, get_supported_attention_modes +import numpy as np +import importlib +from shared.utils import notification_sound +from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers +from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, get_video_frame +from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image +from shared.utils.audio_video import save_image_metadata, read_image_metadata +from shared.match_archi import match_nvidia_architecture +from shared.attention import get_attention_modes, get_supported_attention_modes from huggingface_hub import hf_hub_download, snapshot_download import torch import gc @@ -29,7 +30,7 @@ import math import typing import asyncio import inspect -from wan.utils import prompt_parser +from shared.utils import prompt_parser import base64 import io from PIL import Image @@ -45,15 +46,18 @@ from preprocessing.matanyone import app as matanyone_app from tqdm import tqdm import requests +# import torch._dynamo as dynamo +# dynamo.config.recompile_limit = 2000 # default is 256 +# dynamo.config.accumulated_recompile_limit = 2000 # or whatever limit you want global_queue_ref = [] AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 -target_mmgp_version = "3.5.1" -WanGP_version = "7.0" -settings_version = 2.22 -max_source_video_frames = 1000 +target_mmgp_version = "3.5.10" +WanGP_version = "8.2" +settings_version = 2.27 +max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None from importlib.metadata import version @@ -67,8 +71,30 @@ task_id = 0 vmc_event_handler = matanyone_app.get_vmc_event_handler() unique_id = 0 unique_id_lock = threading.Lock() -offloadobj = None -wan_model = None +gen_lock = threading.Lock() +offloadobj = enhancer_offloadobj = wan_model = None +reload_needed = True + +def clear_gen_cache(): + if "_cache" in offload.shared_state: + del offload.shared_state["_cache"] + +def release_model(): + global wan_model, offloadobj, reload_needed + clear_gen_cache() + offload.shared_state + if offloadobj is not None: + offloadobj.release() + offloadobj = None + torch.cuda.empty_cache() + gc.collect() + try: + torch._C._host_emptyCache() + except: + pass + reload_needed = True + else: + gc.collect() def get_unique_id(): global unique_id @@ -100,22 +126,25 @@ def download_ffmpeg(): os.rename(f, os.path.basename(f)) os.remove(zip_name) + def format_time(seconds): - if seconds < 60: - return f"{seconds:.1f}s" - elif seconds < 3600: - minutes = seconds / 
60 - return f"{minutes:.1f}m" + hours = int(seconds // 3600) + minutes = int((seconds % 3600) // 60) + secs = int(seconds % 60) + + if hours > 0: + return f"{hours}h {minutes:02d}m {secs:02d}s" + elif seconds >= 60: + return f"{minutes}m {secs:02d}s" else: - hours = int(seconds // 3600) - minutes = int((seconds % 3600) // 60) - return f"{hours}h {minutes}m" + return f"{seconds:.1f}s" + def pil_to_base64_uri(pil_image, format="png", quality=75): if pil_image is None: return None if isinstance(pil_image, str): - from wan.utils.utils import get_video_frame + from shared.utils.utils import get_video_frame pil_image = get_video_frame(pil_image, 0) buffer = io.BytesIO() @@ -157,7 +186,6 @@ def process_prompt_and_add_tasks(state, model_choice): return state["validate_success"] = 0 - model_filename = state["model_filename"] model_type = state["model_type"] inputs = get_model_settings(state, model_type) @@ -174,19 +202,21 @@ def process_prompt_and_add_tasks(state, model_choice): queue = gen.get("queue", []) return get_queue_table(queue) model_def = get_model_def(model_type) - image_outputs = model_def.get("image_outputs", False) + model_handler = get_model_handler(model_type) + image_outputs = inputs["image_mode"] == 1 + any_steps_skipping = model_def.get("tea_cache", False) or model_def.get("mag_cache", False) model_type = get_base_model_type(model_type) inputs["model_filename"] = model_filename mode = inputs["mode"] - if mode == "edit": + if mode.startswith("edit_"): edit_video_source =gen.get("edit_video_source", None) edit_overrides =gen.get("edit_overrides", None) _ , _ , _, frames_count = get_video_info(edit_video_source) if frames_count > max_source_video_frames: gr.Info(f"Post processing is not supported on videos longer than {max_source_video_frames} frames. 
Output Video will be truncated") # return - for k in ["image_start", "image_end", "image_refs", "video_guide", "audio_guide", "audio_guide2", "video_mask", "image_mask"]: + for k in ["image_start", "image_end", "image_refs", "video_guide", "audio_guide", "audio_guide2", "audio_source" , "video_mask", "image_mask"]: inputs[k] = None inputs.update(edit_overrides) del gen["edit_video_source"], gen["edit_overrides"] @@ -197,7 +227,7 @@ def process_prompt_and_add_tasks(state, model_choice): if len(spatial_upsampling) >0: prompt += ["Spatial Upsampling"] temporal_upsampling = inputs.get("temporal_upsampling","") if len(temporal_upsampling) >0: prompt += ["Temporal Upsampling"] - if image_outputs and len(temporal_upsampling) > 0: + if has_image_file_extension(edit_video_source) and len(temporal_upsampling) > 0: gr.Info("Temporal Upsampling can not be used with an Image") return film_grain_intensity = inputs.get("film_grain_intensity",0) @@ -205,14 +235,26 @@ def process_prompt_and_add_tasks(state, model_choice): # if film_grain_intensity >0: prompt += [f"Film Grain: intensity={film_grain_intensity}, saturation={film_grain_saturation}"] if film_grain_intensity >0: prompt += ["Film Grain"] MMAudio_setting = inputs.get("MMAudio_setting",0) - seed = inputs.get("seed",None) repeat_generation= inputs.get("repeat_generation",1) - if repeat_generation > 1 and (MMAudio_setting == 0 or seed != -1): - gr.Info("It is useless to generate more than one sample if you don't use MMAudio with a random seed") - return - if MMAudio_setting !=0: prompt += ["MMAudio"] + if mode =="edit_remux": + audio_source = inputs["audio_source"] + if MMAudio_setting== 1: + prompt += ["MMAudio"] + audio_source = None + inputs["audio_source"] = audio_source + else: + if audio_source is None: + gr.Info("You must provide a custom Audio") + return + prompt += ["Custom Audio"] + repeat_generation == 1 + + seed = inputs.get("seed",None) if len(prompt) == 0: - gr.Info("You must choose at least one Post Processing Method") + if mode=="edit_remux": + gr.Info("You must choose at least one Remux Method") + else: + gr.Info("You must choose at least one Post Processing Method") return inputs["prompt"] = ", ".join(prompt) add_video_task(**inputs) @@ -221,6 +263,11 @@ def process_prompt_and_add_tasks(state, model_choice): queue= gen.get("queue", []) return update_queue_data(queue) + if hasattr(model_handler, "validate_generative_settings"): + error = model_handler.validate_generative_settings(model_type, model_def, inputs) + if error is not None and len(error) > 0: + gr.Info(error) + return if inputs.get("cfg_star_switch", 0) != 0 and inputs.get("apg_switch", 0) != 0: gr.Info("Adaptive Progressive Guidance and Classifier Free Guidance Star can not be set at the same time") return @@ -257,6 +304,7 @@ def process_prompt_and_add_tasks(state, model_choice): force_fps = inputs["force_fps"] audio_guide = inputs["audio_guide"] audio_guide2 = inputs["audio_guide2"] + audio_source = inputs["audio_source"] video_guide = inputs["video_guide"] image_guide = inputs["image_guide"] video_mask = inputs["video_mask"] @@ -274,17 +322,42 @@ def process_prompt_and_add_tasks(state, model_choice): num_inference_steps= inputs["num_inference_steps"] skip_steps_cache_type= inputs["skip_steps_cache_type"] MMAudio_setting = inputs["MMAudio_setting"] + image_mode = inputs["image_mode"] + switch_threshold = inputs["switch_threshold"] + loras_multipliers = inputs["loras_multipliers"] + activated_loras = inputs["activated_loras"] + guidance_phases= inputs["guidance_phases"] + 
model_switch_phase = inputs["model_switch_phase"] + switch_threshold = inputs["switch_threshold"] + switch_threshold2 = inputs["switch_threshold2"] + - if skip_steps_cache_type == "mag": - if model_type in ["sky_df_1.3B", "sky_df_14B"]: - gr.Info("Mag Cache is not supported with Diffusion Forcing") + if len(loras_multipliers) > 0: + _, _, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, nb_phases= guidance_phases) + if len(errors) > 0: + gr.Info(f"Error parsing Loras Multipliers: {errors}") return + if guidance_phases == 3: + if switch_threshold < switch_threshold2: + gr.Info(f"Phase 1-2 Switch Noise Level ({switch_threshold}) should be Greater than Phase 2-3 Switch Noise Level ({switch_threshold2}). As a reminder, noise will gradually go down from 1000 to 0.") + return + else: + model_switch_phase = 1 + + if not any_steps_skipping: skip_steps_cache_type = "" + if not model_def.get("lock_inference_steps", False) and model_type in ["ltxv_13B"] and num_inference_steps < 20: + gr.Info("The minimum number of steps should be 20") + return + if skip_steps_cache_type == "mag": if num_inference_steps > 50: gr.Info("Mag Cache maximum number of steps is 50") return + + if image_mode == 1: + audio_prompt_type = "" if "B" in audio_prompt_type or "X" in audio_prompt_type: - from wan.multitalk.multitalk import parse_speakers_locations + from models.wan.multitalk.multitalk import parse_speakers_locations speakers_bboxes, error = parse_speakers_locations(speakers_locations) if len(error) > 0: gr.Info(error) @@ -306,6 +379,9 @@ def process_prompt_and_add_tasks(state, model_choice): else: frames_positions = None + if audio_source is not None and MMAudio_setting != 0: + gr.Info("MMAudio and Custom Audio Soundtrack can't not be used at the same time") + return if len(filter_letters(image_prompt_type, "VLG")) > 0 and len(keep_frames_video_source) > 0: if not is_integer(keep_frames_video_source) or int(keep_frames_video_source) == 0: gr.Info("The number of frames to keep must be a non null integer") @@ -338,20 +414,16 @@ def process_prompt_and_add_tasks(state, model_choice): if not "I" in video_prompt_type and not not "V" in video_prompt_type: gr.Info("To get good results with Multitalk and two people speaking, it is recommended to set a Reference Frame or a Control Video (potentially truncated) that contains the two people one on each side") - if len(filter_letters(image_prompt_type, "VL")) > 0 : - if "R" in audio_prompt_type: - gr.Info("Remuxing is not yet supported if there is a video source") - audio_prompt_type= audio_prompt_type.replace("R" ,"") - if "A" in audio_prompt_type: - gr.Info("Creating an Audio track is not yet supported if there is a video source") - return - - if model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_audio", "hunyuan_avatar"]: + if model_def.get("one_image_ref_needed", False): if image_refs == None : gr.Info("You must provide an Image Reference") return if len(image_refs) > 1: - gr.Info("Only one Image Reference (a person) is supported for the moment by Hunyuan Custom / Avatar") + gr.Info("Only one Image Reference (a person) is supported for the moment by this model") + return + if model_def.get("at_least_one_image_ref_needed", False): + if image_refs == None : + gr.Info("You must provide at least one Image Reference") return if "I" in video_prompt_type: @@ -367,19 +439,23 @@ def process_prompt_and_add_tasks(state, model_choice): image_refs = None if "V" in video_prompt_type: - if video_guide is None and image_guide is 
None: - if image_outputs: + if image_outputs: + if image_guide is None: gr.Info("You must provide a Control Image") - else: - gr.Info("You must provide a Control Video") - return - if "A" in video_prompt_type and not "U" in video_prompt_type: - if video_mask is None and image_mask is None: - if image_outputs: - gr.Info("You must provide a Image Mask") - else: - gr.Info("You must provide a Video Mask") return + else: + if video_guide is None: + gr.Info("You must provide a Control Video") + return + if "A" in video_prompt_type and not "U" in video_prompt_type: + if image_outputs: + if image_mask is None: + gr.Info("You must provide a Image Mask") + return + else: + if video_mask is None: + gr.Info("You must provide a Video Mask") + return else: video_mask = None image_mask = None @@ -388,7 +464,9 @@ def process_prompt_and_add_tasks(state, model_choice): gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start a Step no {int(num_inference_steps * (1. - denoising_strength))} ") else: denoising_strength = 1.0 - + if len(keep_frames_video_guide) > 0 and model_type in ["ltxv_13B"]: + gr.Info("Keep Frames for Control Video is not supported with LTX Video") + return _, error = parse_keep_frames_video_guide(keep_frames_video_guide, video_length) if len(error) > 0: gr.Info(f"Invalid Keep Frames property: {error}") @@ -401,6 +479,13 @@ def process_prompt_and_add_tasks(state, model_choice): keep_frames_video_guide = "" denoising_strength = 1.0 + if image_outputs: + video_guide = None + video_mask = None + else: + image_guide = None + image_mask = None + if "S" in image_prompt_type: if image_start == None or isinstance(image_start, list) and len(image_start) == 0: @@ -432,7 +517,7 @@ def process_prompt_and_add_tasks(state, model_choice): image_end = None - if test_any_sliding_window(model_type): + if test_any_sliding_window(model_type) and image_mode == 0: if video_length > sliding_window_size: full_video_length = video_length if video_source is None else video_length + sliding_window_overlap extra = "" if full_video_length == video_length else f" including {sliding_window_overlap} added for Video Continuation" @@ -471,6 +556,7 @@ def process_prompt_and_add_tasks(state, model_choice): "image_refs": image_refs, "audio_guide": audio_guide, "audio_guide2": audio_guide2, + "audio_source": audio_source, "video_guide": video_guide, "image_guide": image_guide, "video_mask": video_mask, @@ -482,7 +568,9 @@ def process_prompt_and_add_tasks(state, model_choice): "denoising_strength": denoising_strength, "image_prompt_type": image_prompt_type, "video_prompt_type": video_prompt_type, - "audio_prompt_type": audio_prompt_type, + "audio_prompt_type": audio_prompt_type, + "skip_steps_cache_type": skip_steps_cache_type, + "model_switch_phase": model_switch_phase, } if inputs["multi_prompts_gen_type"] == 0: @@ -592,7 +680,7 @@ def add_video_task(**inputs): "id": current_task_id, "params": inputs.copy(), "repeats": inputs["repeat_generation"], - "length": inputs["video_length"], + "length": inputs["video_length"], # !!! 
"steps": inputs["num_inference_steps"], "prompt": inputs["prompt"], "start_image_labels": start_image_labels, @@ -622,9 +710,12 @@ def move_up(queue, selected_indices): idx = idx[0] idx = int(idx) with lock: - if idx > 0: - idx += 1 + idx += 1 + if idx > 1: queue[idx], queue[idx-1] = queue[idx-1], queue[idx] + elif idx == 1: + queue[:] = queue[0:1] + queue[2:] + queue[1:2] + return update_queue_data(queue) def move_down(queue, selected_indices): @@ -638,6 +729,9 @@ def move_down(queue, selected_indices): idx += 1 if idx < len(queue)-1: queue[idx], queue[idx+1] = queue[idx+1], queue[idx] + elif idx == len(queue)-1: + queue[:] = queue[0:1] + queue[-1:] + queue[1:-1] + return update_queue_data(queue) def remove_task(queue, selected_indices): @@ -680,7 +774,7 @@ def save_queue_action(state): task_id_s = task.get('id', f"task_{task_index}") image_keys = ["image_start", "image_end", "image_refs", "image_guide", "image_mask"] - video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"] + video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2", "audio_source"] for key in image_keys: images_pil = params_copy.get(key) @@ -856,7 +950,7 @@ def load_queue_action(filepath, state, evt:gr.EventData): params['state'] = state image_keys = ["image_start", "image_end", "image_refs", "image_guide", "image_mask"] - video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"] + video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2", "audio_source"] loaded_pil_images = {} loaded_video_paths = {} @@ -1050,8 +1144,10 @@ def show_countdown_info_from_state(current_value: int): gr.Info(f"Quitting in {current_value}...") return current_value - 1 return current_value - +quitting_app = False def autosave_queue(): + global quitting_app + quitting_app = True global global_queue_ref if not global_queue_ref: print("Autosave: Queue is empty, nothing to save.") @@ -1076,7 +1172,7 @@ def autosave_queue(): task_id_s = task.get('id', f"task_{task_index}") image_keys = ["image_start", "image_end", "image_refs", "image_guide", "image_mask"] - video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"] + video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2", "audio_source" ] for key in image_keys: images_pil = params_copy.get(key) @@ -1335,6 +1431,12 @@ def _parse_args(): help="Path to a directory that contains flux images Loras" ) + parser.add_argument( + "--lora-dir-qwen", + type=str, + default="loras_qwen", + help="Path to a directory that contains qwen images Loras" + ) parser.add_argument( "--check-loras", @@ -1444,7 +1546,7 @@ def _parse_args(): "--perc-reserved-mem-max", type=float, default=0, - help="% of RAM allocated to Reserved RAM" + help="percent of RAM allocated to Reserved RAM" ) @@ -1554,7 +1656,8 @@ def _parse_args(): def get_lora_dir(model_type): model_family = get_model_family(model_type) - i2v = test_class_i2v(model_type) + base_model_type = get_base_model_type(model_type) + i2v = test_class_i2v(model_type) and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"] if model_family == "wan": lora_dir =args.lora_dir if i2v and len(lora_dir)==0: @@ -1567,6 +1670,10 @@ def get_lora_dir(model_type): lora_dir_1_3B = os.path.join(root_lora_dir, "1.3B") if os.path.isdir(lora_dir_1_3B ): return lora_dir_1_3B + elif base_model_type == "ti2v_2_2": + lora_dir_5B = os.path.join(root_lora_dir, "5B") + if os.path.isdir(lora_dir_5B ): + 
return lora_dir_5B else: lora_dir_14B = os.path.join(root_lora_dir, "14B") if os.path.isdir(lora_dir_14B ): @@ -1581,6 +1688,8 @@ def get_lora_dir(model_type): return args.lora_dir_hunyuan_i2v else: return args.lora_dir_hunyuan + elif model_family =="qwen": + return args.lora_dir_qwen else: raise Exception("loras unknown") @@ -1588,8 +1697,8 @@ attention_modes_installed = get_attention_modes() attention_modes_supported = get_supported_attention_modes() args = _parse_args() -major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None) -if major < 8: +gpu_major, gpu_minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None) +if gpu_major < 8: print("Switching to FP16 models when possible as GPU architecture doesn't support optimed BF16 Kernels") bfloat16_supported = False else: @@ -1640,10 +1749,10 @@ if not Path(server_config_filename).is_file(): "transformer_types": [], "transformer_quantization": "int8", "text_encoder_quantization" : "int8", - "save_path": "outputs", #os.path.join(os.getcwd(), + "save_path": "outputs", + "image_save_path": "outputs", "compile" : "", "metadata_type": "metadata", - "default_ui": "t2v", "boost" : 1, "clear_file_list" : 5, "vae_config": 0, @@ -1664,27 +1773,23 @@ for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion "sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors", "sky_reels2_diffusion_forcing_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_480p_14B_bf16.safetensors", "wan2.1_image2video_480p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_720p_14B_bf16.safetensors", "wan2.1_text2video_14B_bf16.safetensors", "wan2.1_text2video_14B_quanto_int8.safetensors", -"wan2.1_Vace_14B_mbf16.safetensors", "wan2.1_Vace_14B_quanto_mbf16_int8.safetensors", "wan2.1_FLF2V_720p_14B_quanto_int8.safetensors", "wan2.1_FLF2V_720p_14B_bf16.safetensors" +"wan2.1_Vace_14B_mbf16.safetensors", "wan2.1_Vace_14B_quanto_mbf16_int8.safetensors", "wan2.1_FLF2V_720p_14B_quanto_int8.safetensors", "wan2.1_FLF2V_720p_14B_bf16.safetensors", "wan2.1_FLF2V_720p_14B_fp16.safetensors", "wan2.1_Vace_1.3B_mbf16.safetensors", "wan2.1_text2video_1.3B_bf16.safetensors", +"ltxv_0.9.7_13B_dev_bf16.safetensors" ]: if Path(os.path.join("ckpts" , path)).is_file(): print(f"Removing old version of model '{path}'. A new version of this model will be downloaded next time you use it.") os.remove( os.path.join("ckpts" , path)) +for f, s in [("ckpts/Florence2/modeling_florence2.py", 127287)]: + try: + if os.path.isfile(f) and os.path.getsize(f) == s: + print(f"Removing old version of model '{f}'. 
A new version of this model will be downloaded next time you use it.") + os.remove(f) + except: pass + models_def = {} +family_handlers = ["models.wan.wan_handler", "models.wan.df_handler", "models.hyvideo.hunyuan_handler", "models.ltx_video.ltxv_handler", "models.flux.flux_handler", "models.qwen.qwen_handler"] -modules_files = { - "vace_14B" : ["ckpts/wan2.1_Vace_14B_module_mbf16.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mfp16_int8.safetensors"], - "fantasy": ["ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"], - "multitalk": ["ckpts/wan2.1_multitalk_14B_mbf16.safetensors", "ckpts/wan2.1_multitalk_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_multitalk_14B_quanto_mfp16_int8.safetensors"] -} - -# unused -base_types = ["multitalk", "fantasy", "vace_14B", "vace_multitalk_14B", - "t2v_1.3B", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B", - "recam_1.3B", "sky_df_1.3B", "sky_df_14B", - "i2v", "flf2v_720p", "fun_inp_1.3B", "fun_inp", "ltxv_13B", - "hunyuan", "hunyuan_i2v", "hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_avatar", - ] # only needed for imported old settings files model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B", @@ -1695,36 +1800,50 @@ model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", " "hunyuan" : "hunyuan_video_720", "hunyuan_i2v" : "hunyuan_video_i2v_720", "hunyuan_custom" : "hunyuan_video_custom_720", "hunyuan_custom_audio" : "hunyuan_video_custom_audio", "hunyuan_custom_edit" : "hunyuan_video_custom_edit", "hunyuan_avatar" : "hunyuan_video_avatar" } + +def map_family_handlers(family_handlers): + base_types_handlers, families_infos, models_eqv_map, models_comp_map = {}, {"unknown": (100, "Unknown")}, {}, {} + for path in family_handlers: + handler = importlib.import_module(path).family_handler + for model_type in handler.query_supported_types(): + if model_type in base_types_handlers: + prev = base_types_handlers[model_type].__name__ + raise Exception(f"Model type {model_type} supported by {prev} and {handler.__name__}") + base_types_handlers[model_type] = handler + families_infos.update(handler.query_family_infos()) + eq_map, comp_map = handler.query_family_maps() + models_eqv_map.update(eq_map); models_comp_map.update(comp_map) + return base_types_handlers, families_infos, models_eqv_map, models_comp_map + +model_types_handlers, families_infos, models_eqv_map, models_comp_map = map_family_handlers(family_handlers) + def get_base_model_type(model_type): model_def = get_model_def(model_type) if model_def == None: - return model_type if model_type in model_types else None + return model_type if model_type in model_types_handlers else None # return model_type else: return model_def["architecture"] +def get_model_handler(model_type): + base_model_type = get_base_model_type(model_type) + if base_model_type is None: + raise Exception(f"Unknown model type {model_type}") + model_handler = model_types_handlers.get(base_model_type, None) + if model_handler is None: + raise Exception(f"No model handler found for base model type {base_model_type}") + return model_handler + def are_model_types_compatible(imported_model_type, current_model_type): imported_base_model_type = get_base_model_type(imported_model_type) curent_base_model_type = get_base_model_type(current_model_type) if imported_base_model_type == curent_base_model_type: return True - eqv_map = { - "flf2v_720p" 
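# --- Minimal, self-contained sketch of the handler registry that
# map_family_handlers() builds above: every family module exposes a
# "family_handler" listing the base model types it supports, and registering the
# same type twice is treated as a configuration error. The toy handler class
# stands in for the real modules named in family_handlers (models.wan.wan_handler,
# models.qwen.qwen_handler, ...); only the registration logic mirrors the patch.
class ToyHandler:
    def __init__(self, name, supported_types):
        self.__name__ = name
        self.supported_types = supported_types

    def query_supported_types(self):
        return self.supported_types

def build_handler_registry(handlers):
    registry = {}
    for handler in handlers:
        for model_type in handler.query_supported_types():
            if model_type in registry:
                raise Exception(f"Model type {model_type} supported by "
                                f"{registry[model_type].__name__} and {handler.__name__}")
            registry[model_type] = handler
    return registry

if __name__ == "__main__":
    registry = build_handler_registry([ToyHandler("wan_handler", ["t2v", "i2v"]),
                                       ToyHandler("flux_handler", ["flux"])])
    print(sorted(registry))   # ['flux', 'i2v', 't2v']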
: "i2v", - "t2v_1.3B" : "t2v", - "sky_df_1.3B" : "sky_df_14B", - } - if imported_base_model_type in eqv_map: - imported_base_model_type = eqv_map[imported_base_model_type] - comp_map = { - "vace_14B" : [ "vace_multitalk_14B"], - "t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B"], - "i2v" : [ "fantasy", "multitalk", "flf2v_720p" ], - "fantasy": ["multitalk"], - "sky_df_14B": ["sky_df_1.3B"], - "hunyuan_custom": ["hunyuan_custom_edit", "hunyuan_custom_audio"], - } - comp_list= comp_map.get(imported_base_model_type, None) + if imported_base_model_type in models_eqv_map: + imported_base_model_type = models_eqv_map[imported_base_model_type] + + comp_list= models_comp_map.get(imported_base_model_type, None) if comp_list == None: return False return curent_base_model_type in comp_list @@ -1740,58 +1859,53 @@ def get_model_type(model_filename): return None # raise Exception("Unknown model:" + model_filename) -def get_model_family(model_type): - model_type = get_base_model_type(model_type) - if model_type == None: +def get_model_family(model_type, for_ui = False): + base_model_type = get_base_model_type(model_type) + if base_model_type is None: return "unknown" - if "hunyuan" in model_type : - return "hunyuan" - elif "ltxv" in model_type: - return "ltxv" - elif "flux" in model_type: - return "flux" - else: - return "wan" + + if for_ui : + model_def = get_model_def(model_type) + model_family = model_def.get("group", None) + if model_family is not None and model_family in families_infos: + return model_family + handler = model_types_handlers.get(base_model_type, None) + if handler is None: + return "unknown" + return handler.query_model_family() -def test_class_i2v(model_type): - model_type = get_base_model_type(model_type) - return model_type in ["i2v", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "hunyuan_i2v", "multitalk" ] +def test_class_i2v(model_type): + model_def = get_model_def(model_type) + return model_def.get("i2v_class", False) def test_vace_module(model_type): - model_type = get_base_model_type(model_type) - return model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B"] + model_def = get_model_def(model_type) + return model_def.get("vace_class", False) def test_any_sliding_window(model_type): - model_type = get_base_model_type(model_type) - return test_vace_module(model_type) or model_type in ["sky_df_1.3B", "sky_df_14B", "ltxv_13B", "multitalk", "t2v", "fantasy"] or test_class_i2v(model_type) + model_def = get_model_def(model_type) + return model_def.get("sliding_window", False) def get_model_min_frames_and_step(model_type): - model_type = get_base_model_type(model_type) - if model_type in ["sky_df_14B"]: - return 17, 20 - elif model_type in ["ltxv_13B"]: - return 17, 8 - elif test_vace_module(model_type): - return 17, 4 - else: - return 5, 4 - + mode_def = get_model_def(model_type) + frames_minimum = mode_def.get("frames_minimum", 5) + frames_steps = mode_def.get("frames_steps", 4) + return frames_minimum, frames_steps + def get_model_fps(model_type): - model_type = get_base_model_type(model_type) - if model_type in ["hunyuan_avatar", "hunyuan_custom_audio", "multitalk", "vace_multitalk_14B"]: - fps = 25 - elif model_type in ["sky_df_14B", "hunyuan", "hunyuan_i2v", "hunyuan_custom_edit", "hunyuan_custom"]: - fps = 24 - elif model_type in ["fantasy"]: - fps = 23 - elif model_type in ["ltxv_13B"]: - fps = 30 - else: - fps = 16 + mode_def = get_model_def(model_type) + fps= mode_def.get("fps", 16) return fps def 
get_computed_fps(force_fps, base_model_type , video_guide, video_source ): - if force_fps == "control" and video_guide != None: + if force_fps == "auto": + if video_source != None: + fps, _, _, _ = get_video_info(video_source) + elif video_guide != None: + fps, _, _, _ = get_video_info(video_guide) + else: + fps = get_model_fps(base_model_type) + elif force_fps == "control" and video_guide != None: fps, _, _, _ = get_video_info(video_guide) elif force_fps == "source" and video_source != None: fps, _, _, _ = get_video_info(video_source) @@ -1813,15 +1927,24 @@ def get_model_name(model_type, description_container = [""]): def get_model_record(model_name): return f"WanGP v{WanGP_version} by DeepBeepMeep - " + model_name -def get_model_recursive_prop(model_type, prop = "URLs", return_list = False, stack= []): +def get_model_recursive_prop(model_type, prop = "URLs", sub_prop_name = None, return_list = True, stack= []): model_def = models_def.get(model_type, None) if model_def != None: prop_value = model_def.get(prop, None) if prop_value == None: return [] + if sub_prop_name is not None: + if sub_prop_name == "_list": + if not isinstance(prop_value,list) or len(prop_value) != 1: + raise Exception(f"Sub property value for property {prop} of model type {model_type} should be a list of size 1") + prop_value = prop_value[0] + else: + if not isinstance(prop_value,dict) and not sub_prop_name in prop_value: + raise Exception(f"Invalid sub property value {sub_prop_name} for property {prop} of model type {model_type}") + prop_value = prop_value[sub_prop_name] if isinstance(prop_value, str): if len(stack) > 10: raise Exception(f"Circular Reference in Model {prop} dependencies: {stack}") - return get_model_recursive_prop(prop_value, prop = prop, stack = stack + [prop_value] ) + return get_model_recursive_prop(prop_value, prop = prop, sub_prop_name =sub_prop_name, stack = stack + [prop_value] ) else: return prop_value else: @@ -1831,19 +1954,35 @@ def get_model_recursive_prop(model_type, prop = "URLs", return_list = False, st raise Exception(f"Unknown model type '{model_type}'") -def get_model_filename(model_type, quantization ="int8", dtype_policy = "", is_module = False, stack=[]): - if is_module: - choices = modules_files.get(model_type, None) - if choices == None: raise Exception(f"Invalid Module Id '{model_type}'") - else: - model_def = models_def.get(model_type, None) - if model_def == None: return None - URLs = model_def["URLs"] - if isinstance(URLs, str): - if len(stack) > 10: raise Exception(f"Circular Reference in Model URLs dependencies: {stack}") - return get_model_filename(URLs, quantization=quantization, dtype_policy=dtype_policy, stack = stack + [URLs]) +def get_model_filename(model_type, quantization ="int8", dtype_policy = "", module_type = None, submodel_no = 1, stack=[]): + if module_type is not None: + base_model_type = get_base_model_type(model_type) + # model_type_handler = model_types_handlers[base_model_type] + # modules_files = model_type_handler.query_modules_files() if hasattr(model_type_handler, "query_modules_files") else {} + if isinstance(module_type, list): + URLs = module_type else: - choices = [ ("ckpts/" + os.path.basename(path) if path.startswith("http") else path) for path in URLs ] + if "#" not in module_type: + sub_prop_name = "_list" + else: + pos = module_type.rfind("#") + sub_prop_name = module_type[pos+1:] + module_type = module_type[:pos] + URLs = get_model_recursive_prop(module_type, "modules", sub_prop_name =sub_prop_name, return_list= False) + + # choices = 
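# --- Minimal sketch of the indirection handled by get_model_recursive_prop()
# above: a finetune definition may store a property such as "URLs" as a string
# naming another model type, in which case the lookup follows that reference,
# and the growing stack caps the recursion so circular definitions fail loudly.
# MODELS is a toy stand-in for the real models_def dictionary.
MODELS = {
    "base_model":  {"URLs": ["https://example.invalid/base_bf16.safetensors"]},
    "my_finetune": {"URLs": "base_model"},      # inherits the base model's files
}

def recursive_prop(model_type, prop="URLs", stack=()):
    value = MODELS[model_type].get(prop, [])
    if isinstance(value, str):                  # reference to another model type
        if len(stack) > 10:
            raise Exception(f"Circular Reference in Model {prop} dependencies: {list(stack)}")
        return recursive_prop(value, prop, stack + (value,))
    return value

if __name__ == "__main__":
    print(recursive_prop("my_finetune"))        # the base model's URL list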
modules_files.get(module_type, None) + # if choices == None: raise Exception(f"Invalid Module Id '{module_type}'") + else: + key_name = "URLs" if submodel_no <= 1 else f"URLs{submodel_no}" + + model_def = models_def.get(model_type, None) + if model_def == None: return "" + URLs = model_def[key_name] + if isinstance(URLs, str): + if len(stack) > 10: raise Exception(f"Circular Reference in Model {key_name} dependencies: {stack}") + return get_model_filename(URLs, quantization=quantization, dtype_policy=dtype_policy, submodel_no = submodel_no, stack = stack + [URLs]) + + choices = [ ("ckpts/" + os.path.basename(path) if path.startswith("http") else path) for path in URLs ] if len(quantization) == 0: quantization = "bf16" @@ -1888,8 +2027,11 @@ def get_settings_file_name(model_type): return os.path.join(args.settings, model_type + "_settings.json") def fix_settings(model_type, ui_defaults): - video_settings_version = ui_defaults.get("settings_version", 0) - model_type = get_base_model_type(model_type) + if model_type is None: return + + settings_version = ui_defaults.get("settings_version", 0) + model_def = get_model_def(model_type) + base_model_type = get_base_model_type(model_type) prompts = ui_defaults.get("prompts", "") if len(prompts) > 0: @@ -1900,43 +2042,47 @@ def fix_settings(model_type, ui_defaults): image_prompt_type = "S" if image_prompt_type == 0 else "SE" # if model_type == "flf2v_720p" and not "E" in image_prompt_type: # image_prompt_type = "SE" - if video_settings_version <= 2: + if settings_version <= 2: image_prompt_type = image_prompt_type.replace("G","") ui_defaults["image_prompt_type"] = image_prompt_type if "lset_name" in ui_defaults: del ui_defaults["lset_name"] - - - if model_type == None: return - audio_prompt_type = ui_defaults.get("audio_prompt_type", None) - if video_settings_version < 2.2: - if not model_type in ["vace_1.3B","vace_14B", "sky_df_1.3B", "sky_df_14B", "ltxv_13B"]: + if settings_version < 2.2: + if not base_model_type in ["vace_1.3B","vace_14B", "sky_df_1.3B", "sky_df_14B", "ltxv_13B"]: for p in ["sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames"]: if p in ui_defaults: del ui_defaults[p] if audio_prompt_type == None : - if any_audio_track(model_type): + if any_audio_track(base_model_type): audio_prompt_type ="A" ui_defaults["audio_prompt_type"] = audio_prompt_type video_prompt_type = ui_defaults.get("video_prompt_type", "") - if model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar", "phantom_14B", "phantom_1.3B", "flux_dev_kontext"]: + any_reference_image = model_def.get("reference_image", False) + if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar", "phantom_14B", "phantom_1.3B"] or any_reference_image: if not "I" in video_prompt_type: # workaround for settings corruption video_prompt_type += "I" - if model_type in ["hunyuan"]: + if base_model_type in ["hunyuan"]: video_prompt_type = video_prompt_type.replace("I", "") + if base_model_type in ["flux"] and settings_version < 2.23: + video_prompt_type = video_prompt_type.replace("K", "").replace("I", "KI") - remove_background_images_ref = ui_defaults.get("remove_background_images_ref", 0) - if video_settings_version < 2.22: + remove_background_images_ref = ui_defaults.get("remove_background_images_ref", None) + if settings_version < 2.22: if "I" in video_prompt_type: if remove_background_images_ref == 2: video_prompt_type = 
video_prompt_type.replace("I", "KI") if remove_background_images_ref != 0: remove_background_images_ref = 1 + if base_model_type in ["hunyuan_avatar"]: + remove_background_images_ref = 0 + if settings_version < 2.26: + if not "K" in video_prompt_type: video_prompt_type = video_prompt_type.replace("I", "KI") + if remove_background_images_ref is not None: ui_defaults["remove_background_images_ref"] = remove_background_images_ref ui_defaults["video_prompt_type"] = video_prompt_type @@ -1957,6 +2103,10 @@ def fix_settings(model_type, ui_defaults): del ui_defaults["tea_cache_start_step_perc"] ui_defaults["skip_steps_start_step_perc"] = tea_cache_start_step_perc + model_handler = get_model_handler(base_model_type) + if hasattr(model_handler, "fix_settings"): + model_handler.fix_settings(base_model_type, settings_version, model_def, ui_defaults) + def get_default_settings(model_type): def get_default_prompt(i2v): if i2v: @@ -1966,17 +2116,19 @@ def get_default_settings(model_type): i2v = test_class_i2v(model_type) defaults_filename = get_settings_file_name(model_type) if not Path(defaults_filename).is_file(): + model_def = get_model_def(model_type) + base_model_type = get_base_model_type(model_type) ui_defaults = { + "settings_version" : settings_version, "prompt": get_default_prompt(i2v), - "resolution": "1280x720" if "720" in model_type else "832x480", + "resolution": "1280x720" if "720" in base_model_type else "832x480", "video_length": 81, "num_inference_steps": 30, "seed": -1, "repeat_generation": 1, "multi_images_gen_type": 0, "guidance_scale": 5.0, - "embedded_guidance_scale" : 6.0, - "flow_shift": 7.0 if not "720" in model_type and i2v else 5.0, + "flow_shift": 7.0 if not "720" in base_model_type and i2v else 5.0, "negative_prompt": "", "activated_loras": [], "loras_multipliers": "", @@ -1988,87 +2140,11 @@ def get_default_settings(model_type): "slg_start_perc": 10, "slg_end_perc": 90 } - if model_type in ["fantasy"]: - ui_defaults["audio_guidance_scale"] = 5.0 - elif model_type in ["multitalk"]: - ui_defaults.update({ - "guidance_scale": 5.0, - "flow_shift": 7, # 11 for 720p - "audio_guidance_scale": 4, - "sliding_window_discard_last_frames" : 4, - "sample_solver" : "euler", - "adaptive_switch" : 1, - }) + model_handler = get_model_handler(model_type) + model_handler.update_default_settings(base_model_type, model_def, ui_defaults) - elif model_type in ["hunyuan","hunyuan_i2v"]: - ui_defaults.update({ - "guidance_scale": 7.0, - }) - - elif model_type in ["flux_dev_kontext"]: - ui_defaults.update({ - "video_prompt_type": "I", - }) - elif model_type in ["sky_df_1.3B", "sky_df_14B"]: - ui_defaults.update({ - "guidance_scale": 6.0, - "flow_shift": 8, - "sliding_window_discard_last_frames" : 0, - "resolution": "1280x720" if "720" in model_type else "960x544", - "sliding_window_size" : 121 if "720" in model_type else 97, - "RIFLEx_setting": 2, - "guidance_scale": 6, - "flow_shift": 8, - }) - - - elif model_type in ["phantom_1.3B", "phantom_14B"]: - ui_defaults.update({ - "guidance_scale": 7.5, - "flow_shift": 5, - "remove_background_images_ref": 0, - "video_prompt_type": "I", - # "resolution": "1280x720" - }) - - elif model_type in ["hunyuan_custom"]: - ui_defaults.update({ - "guidance_scale": 7.5, - "flow_shift": 13, - "resolution": "1280x720", - "video_prompt_type": "I", - }) - elif model_type in ["hunyuan_custom_audio"]: - ui_defaults.update({ - "guidance_scale": 7.5, - "flow_shift": 13, - "video_prompt_type": "I", - }) - elif model_type in ["hunyuan_custom_edit"]: - ui_defaults.update({ - 
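# --- Hedged sketch of the migration pattern fix_settings() follows above: each
# saved settings file carries a "settings_version", and older files are patched
# step by step before the current defaults apply. The two migrations shown
# (dropping "G" from image_prompt_type for versions <= 2, renaming the TeaCache
# field) mirror steps visible in this patch; the version constant is an assumption.
CURRENT_SETTINGS_VERSION = 2.26   # assumed to track the app's settings_version

def migrate_settings(ui_defaults):
    version = ui_defaults.get("settings_version", 0)
    if version <= 2:
        image_prompt_type = ui_defaults.get("image_prompt_type", "")
        if isinstance(image_prompt_type, str):
            ui_defaults["image_prompt_type"] = image_prompt_type.replace("G", "")
    if "tea_cache_start_step_perc" in ui_defaults:        # field renamed later on
        ui_defaults["skip_steps_start_step_perc"] = ui_defaults.pop("tea_cache_start_step_perc")
    ui_defaults["settings_version"] = CURRENT_SETTINGS_VERSION
    return ui_defaults

if __name__ == "__main__":
    print(migrate_settings({"settings_version": 1, "image_prompt_type": "SG"}))
    # {'settings_version': 2.26, 'image_prompt_type': 'S'}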
"guidance_scale": 7.5, - "flow_shift": 13, - "video_prompt_type": "MVAI", - "sliding_window_size": 129, - }) - elif model_type in ["hunyuan_avatar"]: - ui_defaults.update({ - "guidance_scale": 7.5, - "flow_shift": 5, - "skip_steps_start_step_perc": 25, - "video_length": 129, - "video_prompt_type": "I", - }) - elif model_type in ["vace_14B", "vace_multitalk_14B"]: - ui_defaults.update({ - "sliding_window_discard_last_frames": 0, - }) - - - model_def = get_model_def(model_type) - if model_def != None: - ui_defaults_update = model_def["settings"] - ui_defaults.update(ui_defaults_update) + ui_defaults_update = model_def.get("settings", None) + if ui_defaults_update is not None: ui_defaults.update(ui_defaults_update) if len(ui_defaults.get("prompt","")) == 0: ui_defaults["prompt"]= get_default_prompt(i2v) @@ -2092,9 +2168,15 @@ def get_default_settings(model_type): return ui_defaults -def set_default_model_def(model_def, model_type): - if model_type == "flux_dev_kontext": - model_def.update({"image_outputs": True}) +def init_model_def(model_type, model_def): + base_model_type = get_base_model_type(model_type) + family_handler = model_types_handlers.get(base_model_type, None) + if family_handler is None: + raise Exception(f"Unknown model type {model_type}") + default_model_def = family_handler.query_model_def(base_model_type, model_def) + if default_model_def is None: return model_def + default_model_def.update(model_def) + return default_model_def models_def_paths = glob.glob( os.path.join("defaults", "*.json") ) + glob.glob( os.path.join("finetunes", "*.json") ) @@ -2117,8 +2199,9 @@ for file_path in models_def_paths: existing_settings.update(settings) existing_model_def.update(model_def) else: - models_def[model_type] = model_def - set_default_model_def(model_def, model_type) + models_def[model_type] = model_def # partial def + model_def= init_model_def(model_type, model_def) + models_def[model_type] = model_def # replace with full def model_def["settings"] = settings model_types = models_def.keys() @@ -2139,6 +2222,7 @@ for model_type in transformer_types: transformer_types = new_transformer_types transformer_type = server_config.get("last_model_type", None) advanced = server_config.get("last_advanced_choice", False) +last_resolution = server_config.get("last_resolution_choice", None) if args.advanced: advanced = True if transformer_type != None and not transformer_type in model_types and not transformer_type in models_def: transformer_type = None @@ -2161,7 +2245,8 @@ if len(args.attention)> 0: else: raise Exception(f"Unknown attention mode '{args.attention}'") -profile = force_profile_no if force_profile_no >=0 else server_config["profile"] +default_profile = force_profile_no if force_profile_no >=0 else server_config["profile"] +loaded_profile = -1 compile = server_config.get("compile", "") boost = server_config.get("boost", 1) vae_config = server_config.get("vae_config", 0) @@ -2169,8 +2254,11 @@ if len(args.vae_config) > 0: vae_config = int(args.vae_config) reload_needed = False -default_ui = server_config.get("default_ui", "t2v") -save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs")) +save_path = server_config.get("save_path", os.path.join(os.getcwd(), "outputs")) +image_save_path = server_config.get("image_save_path", os.path.join(os.getcwd(), "outputs")) +if not "video_output_codec" in server_config: server_config["video_output_codec"]= "libx264_8" +if not "image_output_codec" in server_config: server_config["image_output_codec"]= "jpeg_95" + 
preload_model_policy = server_config.get("preload_model_policy", []) @@ -2197,11 +2285,83 @@ if args.compile: #args.fastest or compile="transformer" lock_ui_compile = True -def save_quantized_model(model, model_type, model_filename, dtype, config_file): + +def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_module = False, filter = None, no_fp16_main_model = True ): + model_def = get_model_def(model_type) + # To save module and quantized modules + # 1) set Transformer Model Quantization Type to 16 bits + # 2) insert in def module_source : path and "model_fp16.safetensors in URLs" + # 3) Generate (only quantized fp16 will be created) + # 4) replace in def module_source : path and "model_bf16.safetensors in URLs" + # 5) Generate (both bf16 and quantized bf16 will be created) + if model_def == None: return + if is_module: + url_key = "modules" + source_key = "module_source" + else: + url_key = "URLs" if submodel_no <=1 else "URLs" + str(submodel_no) + source_key = "source" + URLs= model_def.get(url_key, None) + if URLs is None: return + if isinstance(URLs, str): + print("Unable to save model for a finetune that references external files") + return + from mmgp import offload + dtypestr= "bf16" if dtype == torch.bfloat16 else "fp16" + if no_fp16_main_model: dtypestr = dtypestr.replace("fp16", "bf16") + model_filename = None + if is_module: + if not isinstance(URLs,list) or len(URLs) != 1: + print("Target Module files are missing") + return + URLs= URLs[0] + for url in URLs: + if "quanto" not in url and dtypestr in url: + model_filename = os.path.basename(url) + break + if model_filename is None: + print(f"No target filename with bf16 or fp16 in its name is mentioned in {url_key}") + return + + finetune_file = os.path.join(os.path.dirname(model_def["path"]) , model_type + ".json") + with open(finetune_file, 'r', encoding='utf-8') as reader: + saved_finetune_def = json.load(reader) + + update_model_def = False + model_filename = os.path.join("ckpts",model_filename) + quanto_dtypestr= "bf16" if dtype == torch.bfloat16 else "fp16" + if ("m" + dtypestr) in model_filename: + dtypestr = "m" + dtypestr + quanto_dtypestr = "m" + quanto_dtypestr + if not os.path.isfile(model_filename) and (not no_fp16_main_model or dtype == torch.bfloat16): + offload.save_model(model, model_filename, config_file_path=config_file, filter_sd=filter) + print(f"New model file '{model_filename}' had been created for finetune Id '{model_type}'.") + del saved_finetune_def["model"][source_key] + del model_def[source_key] + print(f"The 'source' entry has been removed in the '{finetune_file}' definition file.") + update_model_def = True + + if is_module: + quanto_filename = model_filename.replace(dtypestr, "quanto_" + quanto_dtypestr + "_int8" ) + if hasattr(model, "_quanto_map"): + print("unable to generate quantized module, the main model should at full 16 bits before quantization can be done") + elif not os.path.isfile(quanto_filename): + offload.save_model(model, quanto_filename, config_file_path=config_file, do_quantize= True, filter_sd=filter) + print(f"New quantized file '{quanto_filename}' had been created for finetune Id '{model_type}'.") + model_def[url_key][0].append(quanto_filename) + saved_finetune_def["model"][url_key][0].append(quanto_filename) + update_model_def = True + if update_model_def: + with open(finetune_file, "w", encoding="utf-8") as writer: + writer.write(json.dumps(saved_finetune_def, indent=4)) + +def save_quantized_model(model, model_type, model_filename, dtype, config_file, 
submodel_no = 1): if "quanto" in model_filename: return model_def = get_model_def(model_type) if model_def == None: return - URLs= model_def["URLs"] + url_key = "URLs" if submodel_no <=1 else "URLs" + str(submodel_no) + URLs= model_def.get(url_key, None) + if URLs is None: return if isinstance(URLs, str): print("Unable to create a quantized model for a finetune that references external files") return @@ -2229,7 +2389,7 @@ def save_quantized_model(model, model_type, model_filename, dtype, config_file) finetune_file = os.path.join(os.path.dirname(model_def["path"]) , model_type + ".json") with open(finetune_file, 'r', encoding='utf-8') as reader: saved_finetune_def = json.load(reader) - saved_finetune_def["model"]["URLs"] = URLs + saved_finetune_def["model"][url_key] = URLs with open(finetune_file, "w", encoding="utf-8") as writer: writer.write(json.dumps(saved_finetune_def, indent=4)) print(f"The '{finetune_file}' definition file has been automatically updated with the local path to the new quantized model.") @@ -2245,27 +2405,6 @@ def get_loras_preprocessor(transformer, model_type): return preprocessor_wrapper -def get_wan_text_encoder_filename(text_encoder_quantization): - text_encoder_filename = "ckpts/umt5-xxl/models_t5_umt5-xxl-enc-bf16.safetensors" - if text_encoder_quantization =="int8": - text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_int8") - return text_encoder_filename - -def get_ltxv_text_encoder_filename(text_encoder_quantization): - text_encoder_filename = "ckpts/T5_xxl_1.1/T5_xxl_1.1_enc_bf16.safetensors" - if text_encoder_quantization =="int8": - text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_bf16_int8") - return text_encoder_filename - -def get_hunyuan_text_encoder_filename(text_encoder_quantization): - if text_encoder_quantization =="int8": - text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_quanto_int8.safetensors" - else: - text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_fp16.safetensors" - - return text_encoder_filename - - def process_files_def(repoId, sourceFolderList, fileList): targetRoot = "ckpts/" for sourceFolder, files in zip(sourceFolderList,fileList ): @@ -2290,7 +2429,7 @@ def download_mmaudio(): } process_files_def(**enhancer_def) -def download_models(model_filename, model_type): +def download_models(model_filename = None, model_type= None, module_type = None, submodel_no = 1): def computeList(filename): if filename == None: return [] @@ -2301,16 +2440,16 @@ def download_models(model_filename, model_type): from urllib.request import urlretrieve - from wan.utils.utils import create_progress_hook + from shared.utils.utils import create_progress_hook shared_def = { "repoId" : "DeepBeepMeep/Wan2.1", - "sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "pyannote", "" ], + "sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "pyannote", "det_align", "" ], "fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["raft-things.pth"], ["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"], ["config.json", "pytorch_model.bin", "preprocessor_config.json"], - ["pyannote_model_wespeaker-voxceleb-resnet34-LM.bin", 
"pytorch_model_segmentation-3.0.bin"], [ "flownet.pkl" ] ] + ["pyannote_model_wespeaker-voxceleb-resnet34-LM.bin", "pytorch_model_segmentation-3.0.bin"], ["detface.pt"], [ "flownet.pkl" ] ] } process_files_def(**shared_def) @@ -2323,51 +2462,75 @@ def download_models(model_filename, model_type): } process_files_def(**enhancer_def) + elif server_config.get("enhancer_enabled", 0) == 2: + enhancer_def = { + "repoId" : "DeepBeepMeep/LTX_Video", + "sourceFolderList" : [ "Florence2", "llama-joycaption-beta-one-hf-llava" ], + "fileList" : [ ["config.json", "configuration_florence2.py", "model.safetensors", "modeling_florence2.py", "preprocessor_config.json", "processing_florence2.py", "tokenizer.json", "tokenizer_config.json"],["config.json", "llama_joycaption_quanto_bf16_int8.safetensors", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"] ] + } + process_files_def(**enhancer_def) + download_mmaudio() + if model_filename is None: return def download_file(url,filename): if url.startswith("https://huggingface.co/") and "/resolve/main/" in url: + base_dir = os.path.dirname(filename) url = url[len("https://huggingface.co/"):] url_parts = url.split("/resolve/main/") repoId = url_parts[0] onefile = os.path.basename(url_parts[-1]) sourceFolder = os.path.dirname(url_parts[-1]) if len(sourceFolder) == 0: - hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/") + hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/" if len(base_dir)==0 else base_dir) else: target_path = "ckpts/temp/" + sourceFolder if not os.path.exists(target_path): os.makedirs(target_path) hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/temp/", subfolder=sourceFolder) - shutil.move(os.path.join( "ckpts", "temp" , sourceFolder , onefile), "ckpts/") + shutil.move(os.path.join( "ckpts", "temp" , sourceFolder , onefile), "ckpts/" if len(base_dir)==0 else base_dir) shutil.rmtree("ckpts/temp") else: urlretrieve(url,filename, create_progress_hook(filename)) - model_family = get_model_family(model_type) + base_model_type = get_base_model_type(model_type) model_def = get_model_def(model_type) - if model_def != None and not model_type in modules_files: - if not os.path.isfile(model_filename ): - URLs = get_model_recursive_prop(model_type, "URLs") - if not isinstance(URLs, str): # dont download anything right now if a base type is referenced as the download will occur just after - use_url = model_filename - for url in URLs: - if os.path.basename(model_filename) in url: - use_url = url - break - if not url.startswith("http"): - raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it. 
Please add an URL in the model definition file.") - try: - download_file(use_url, model_filename) - except Exception as e: - if os.path.isfile(model_filename): os.remove(model_filename) - raise Exception(f"URL '{use_url}' is invalid for Model '{model_filename}' : {str(e)}'") - model_filename = None + + source = model_def.get("source", None) + module_source = model_def.get("module_source", None) + model_type_handler = model_types_handlers[base_model_type] + + if source is not None and module_type is None or module_source is not None and module_type is not None: + model_filename = None + else: + if not os.path.isfile(model_filename): + if module_type is not None: + key_name = "modules" + URLs = module_type + if isinstance(module_type, str): + URLs = get_model_recursive_prop(module_type, key_name, sub_prop_name="_list", return_list= False) + else: + key_name = "URLs" if submodel_no <= 1 else f"URLs{submodel_no}" + URLs = get_model_recursive_prop(model_type, key_name, return_list= False) + if isinstance(URLs, str): + raise Exception("Missing model " + URLs) + use_url = model_filename + for url in URLs: + if os.path.basename(model_filename) in url: + use_url = url + break + if not url.startswith("http"): + raise Exception(f"Model '{model_filename}' in field '{key_name}' was not found locally and no URL was provided to download it. Please add an URL in the model definition file.") + try: + download_file(use_url, model_filename) + except Exception as e: + if os.path.isfile(model_filename): os.remove(model_filename) + raise Exception(f"{key_name} '{use_url}' is invalid for Model '{model_filename}' : {str(e)}'") + + model_filename = None preload_URLs = get_model_recursive_prop(model_type, "preload_URLs", return_list= True) - model_loras = get_model_recursive_prop(model_type, "loras", return_list= True) - - for url in preload_URLs + model_loras: + for url in preload_URLs: filename = "ckpts/" + url.split("/")[-1] if not os.path.isfile(filename ): if not url.startswith("http"): @@ -2377,54 +2540,19 @@ def download_models(model_filename, model_type): except Exception as e: if os.path.isfile(filename): os.remove(filename) raise Exception(f"Preload URL '{url}' is invalid: {str(e)}'") - if model_family == "wan": - text_encoder_filename = get_wan_text_encoder_filename(text_encoder_quantization) - model_files = { - "repoId" : "DeepBeepMeep/Wan2.1", - "sourceFolderList" : ["xlm-roberta-large", "umt5-xxl", "" ], - "fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(model_filename) ] - } - elif model_family == "ltxv": - text_encoder_filename = get_ltxv_text_encoder_filename(text_encoder_quantization) - model_files = { - "repoId" : "DeepBeepMeep/LTX_Video", - "sourceFolderList" : ["T5_xxl_1.1", "" ], - "fileList" : [ ["added_tokens.json", "special_tokens_map.json", "spiece.model", "tokenizer_config.json"] + computeList(text_encoder_filename), ["ltxv_0.9.7_VAE.safetensors", "ltxv_0.9.7_spatial_upscaler.safetensors", "ltxv_scheduler.json"] + computeList(model_filename) ] - } - elif model_family == "hunyuan": - text_encoder_filename = get_hunyuan_text_encoder_filename(text_encoder_quantization) - model_files = { - "repoId" : "DeepBeepMeep/HunyuanVideo", - "sourceFolderList" : [ 
"llava-llama-3-8b", "clip_vit_large_patch14", "whisper-tiny" , "det_align", "" ], - "fileList" :[ ["config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "preprocessor_config.json"] + computeList(text_encoder_filename) , - ["config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"], - ["config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json"], - ["detface.pt"], - [ "hunyuan_video_720_quanto_int8_map.json", "hunyuan_video_custom_VAE_fp32.safetensors", "hunyuan_video_custom_VAE_config.json", "hunyuan_video_VAE_fp32.safetensors", "hunyuan_video_VAE_config.json" , "hunyuan_video_720_quanto_int8_map.json" ] + computeList(model_filename) - ] - } - elif model_family == "flux": - text_encoder_filename = get_ltxv_text_encoder_filename(text_encoder_quantization) - model_files = [ - { - "repoId" : "DeepBeepMeep/Flux", - "sourceFolderList" : [""], - "fileList" : [ ["flux_vae.safetensors"] ] - }, - { - "repoId" : "DeepBeepMeep/LTX_Video", - "sourceFolderList" : ["T5_xxl_1.1"], - "fileList" : [ ["added_tokens.json", "special_tokens_map.json", "spiece.model", "tokenizer_config.json"] + computeList(text_encoder_filename) ] - }, - { - "repoId" : "DeepBeepMeep/HunyuanVideo", - "sourceFolderList" : [ "clip_vit_large_patch14", ], - "fileList" :[ - ["config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"], - ] - } - ] - + + model_loras = get_model_recursive_prop(model_type, "loras", return_list= True) + for url in model_loras: + filename = os.path.join(get_lora_dir(model_type), url.split("/")[-1]) + if not os.path.isfile(filename ): + if not url.startswith("http"): + raise Exception(f"Lora '{filename}' was not found in the Loras Folder and no URL was provided to download it. 
Please add an URL in the model definition file.") + try: + download_file(url, filename) + except Exception as e: + if os.path.isfile(filename): os.remove(filename) + raise Exception(f"Lora URL '{url}' is invalid: {str(e)}'") + model_files = model_type_handler.query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization) if not isinstance(model_files, list): model_files = [model_files] for one_repo in model_files: process_files_def(**one_repo) @@ -2475,7 +2603,6 @@ def extract_preset(model_type, lset_name, loras): return loras_choices, loras_mult_choices, prompt, full_prompt, error - def setup_loras(model_type, transformer, lora_dir, lora_preselected_preset, split_linear_modules_map = None): loras =[] loras_names = [] @@ -2520,127 +2647,97 @@ def setup_loras(model_type, transformer, lora_dir, lora_preselected_preset, spl print(error[:200]) return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset +def get_transformer_model(model, submodel_no = 1): + if submodel_no > 1: + model_key = f"model{submodel_no}" + if not hasattr(model, model_key): return None -def load_wan_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): - if test_class_i2v(base_model_type): - cfg = WAN_CONFIGS['i2v-14B'] - else: - cfg = WAN_CONFIGS['t2v-14B'] - # cfg = WAN_CONFIGS['t2v-1.3B'] - if base_model_type in ("sky_df_1.3B", "sky_df_14B"): - model_factory = wan.DTT2V - else: - model_factory = wan.WanAny2V - - wan_model = model_factory( - config=cfg, - checkpoint_dir="ckpts", - model_filename=model_filename, - model_type = model_type, - model_def = model_def, - base_model_type=base_model_type, - text_encoder_filename= get_wan_text_encoder_filename(text_encoder_quantization), - quantizeTransformer = quantizeTransformer, - dtype = dtype, - VAE_dtype = VAE_dtype, - mixed_precision_transformer = mixed_precision_transformer, - save_quantized = save_quantized - ) - - pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model } - if hasattr(wan_model, "clip"): - pipe["text_encoder_2"] = wan_model.clip.model - return wan_model, pipe - -def load_ltxv_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): - from ltx_video.ltxv import LTXV - - ltxv_model = LTXV( - model_filepath = model_filename, - text_encoder_filepath = get_ltxv_text_encoder_filename(text_encoder_quantization), - model_def = model_def, - dtype = dtype, - # quantizeTransformer = quantizeTransformer, - VAE_dtype = VAE_dtype, - mixed_precision_transformer = mixed_precision_transformer - ) - - pipeline = ltxv_model.pipeline - pipe = {"transformer" : pipeline.video_pipeline.transformer, "vae" : pipeline.vae, "text_encoder" : pipeline.video_pipeline.text_encoder, "latent_upsampler" : pipeline.latent_upsampler} - - return ltxv_model, pipe - - -def load_flux_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): - from flux.flux_main import model_factory - - flux_model = model_factory( - checkpoint_dir="ckpts", - model_filename=model_filename, - model_type = model_type, 
- base_model_type=base_model_type, - text_encoder_filename= get_ltxv_text_encoder_filename(text_encoder_quantization), - quantizeTransformer = quantizeTransformer, - dtype = dtype, - VAE_dtype = VAE_dtype, - mixed_precision_transformer = mixed_precision_transformer, - save_quantized = save_quantized - ) - - pipe = { "transformer": flux_model.model, "vae" : flux_model.vae, "text_encoder" : flux_model.clip, "text_encoder_2" : flux_model.t5} - - return flux_model, pipe - -def load_hunyuan_model(model_filename, model_type = None, base_model_type = None, model_def = None, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): - from hyvideo.hunyuan import HunyuanVideoSampler - - hunyuan_model = HunyuanVideoSampler.from_pretrained( - model_filepath = model_filename, - model_type = model_type, - base_model_type = base_model_type, - text_encoder_filepath = get_hunyuan_text_encoder_filename(text_encoder_quantization), - dtype = dtype, - quantizeTransformer = quantizeTransformer, - VAE_dtype = VAE_dtype, - mixed_precision_transformer = mixed_precision_transformer, - save_quantized = save_quantized - ) - - pipe = { "transformer" : hunyuan_model.model, "text_encoder" : hunyuan_model.text_encoder, "text_encoder_2" : hunyuan_model.text_encoder_2, "vae" : hunyuan_model.vae } - - if hunyuan_model.wav2vec != None: - pipe["wav2vec"] = hunyuan_model.wav2vec - - - # if hunyuan_model.align_instance != None: - # pipe["align_instance"] = hunyuan_model.align_instance.facedet.model - - - from hyvideo.modules.models import get_linear_split_map - - split_linear_modules_map = get_linear_split_map() - hunyuan_model.model.split_linear_modules_map = split_linear_modules_map - offload.split_linear_modules(hunyuan_model.model, split_linear_modules_map ) - - - return hunyuan_model, pipe - -def get_transformer_model(model): if hasattr(model, "model"): - return model.model + if submodel_no > 1: + return getattr(model, f"model{submodel_no}") + else: + return model.model elif hasattr(model, "transformer"): return model.transformer else: raise Exception("no transformer found") +def init_pipe(pipe, kwargs, override_profile): + preload =int(args.preload) + if preload == 0: + preload = server_config.get("preload_in_VRAM", 0) -def load_models(model_type): - global transformer_type + kwargs["extraModelsToQuantize"]= None + profile = override_profile if override_profile != -1 else default_profile + if profile in (2, 4, 5): + budgets = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100 if preload == 0 else preload, "*" : max(1000 if profile==5 else 3000 , preload) } + if "transformer2" in pipe: + budgets["transformer2"] = 100 if preload == 0 else preload + kwargs["budgets"] = budgets + elif profile == 3: + kwargs["budgets"] = { "*" : "70%" } + + if "transformer2" in pipe: + if profile in [3,4]: + kwargs["pinnedMemory"] = ["transformer", "transformer2"] + + return profile + +def reset_prompt_enhancer(): + global prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer, enhancer_offloadobj + prompt_enhancer_image_caption_model = None + prompt_enhancer_image_caption_processor = None + prompt_enhancer_llm_model = None + prompt_enhancer_llm_tokenizer = None + if enhancer_offloadobj is not None: + enhancer_offloadobj.release() + enhancer_offloadobj = None + +def setup_prompt_enhancer(pipe, kwargs): + global prompt_enhancer_image_caption_model, 
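# --- Sketch of the budget selection done in init_pipe() above before
# offload.profile() is called: memory profiles 2, 4 and 5 cap each large
# component individually (honouring a configured preload amount), profile 3
# reserves a global percentage, and a second transformer in dual-model pipelines
# gets its own budget plus pinned memory on profiles 3 and 4. The numbers mirror
# this patch; the helper itself is illustrative.
def build_offload_kwargs(profile, preload_mb=0, has_second_transformer=False):
    kwargs = {}
    if profile in (2, 4, 5):
        per_model = 100 if preload_mb == 0 else preload_mb
        budgets = {"transformer": per_model,
                   "text_encoder": per_model,
                   "*": max(1000 if profile == 5 else 3000, preload_mb)}
        if has_second_transformer:
            budgets["transformer2"] = per_model
        kwargs["budgets"] = budgets
    elif profile == 3:
        kwargs["budgets"] = {"*": "70%"}
    if has_second_transformer and profile in (3, 4):
        kwargs["pinnedMemory"] = ["transformer", "transformer2"]
    return kwargs

if __name__ == "__main__":
    print(build_offload_kwargs(5, preload_mb=0, has_second_transformer=True))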
prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer + model_no = server_config.get("enhancer_enabled", 0) + if model_no != 0: + from transformers import ( AutoModelForCausalLM, AutoProcessor, AutoTokenizer, LlamaForCausalLM ) + prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( "ckpts/Florence2", trust_remote_code=True) + prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( "ckpts/Florence2", trust_remote_code=True) + pipe["prompt_enhancer_image_caption_model"] = prompt_enhancer_image_caption_model + prompt_enhancer_image_caption_model._model_dtype = torch.float + # def preprocess_sd(sd, map): + # new_sd ={} + # for k, v in sd.items(): + # k = "model." + k.replace(".model.", ".") + # if "lm_head.weight" in k: k = "lm_head.weight" + # new_sd[k] = v + # return new_sd, map + # prompt_enhancer_llm_model = offload.fast_load_transformers_model("c:/temp/joy/model-00001-of-00004.safetensors", modelClass= LlavaForConditionalGeneration, defaultConfigPath="ckpts/llama-joycaption-beta-one-hf-llava/config.json", preprocess_sd=preprocess_sd) + # offload.save_model(prompt_enhancer_llm_model, "joy_llava_quanto_int8.safetensors", do_quantize= True) + + if model_no == 1: + budget = 5000 + prompt_enhancer_llm_model = offload.fast_load_transformers_model("ckpts/Llama3_2/Llama3_2_quanto_bf16_int8.safetensors") + prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained("ckpts/Llama3_2") + else: + budget = 10000 + prompt_enhancer_llm_model = offload.fast_load_transformers_model("ckpts/llama-joycaption-beta-one-hf-llava/llama_joycaption_quanto_bf16_int8.safetensors") + prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained("ckpts/llama-joycaption-beta-one-hf-llava") + pipe["prompt_enhancer_llm_model"] = prompt_enhancer_llm_model + if not "budgets" in kwargs: kwargs["budgets"] = {} + kwargs["budgets"]["prompt_enhancer_llm_model"] = budget + else: + reset_prompt_enhancer() + + + +def load_models(model_type, override_profile = -1): + global transformer_type, loaded_profile base_model_type = get_base_model_type(model_type) model_def = get_model_def(model_type) - preload =int(args.preload) save_quantized = args.save_quantized and model_def != None model_filename = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy) + if "URLs2" in model_def: + model_filename2 = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy, submodel_no=2) # !!!! 
+ else: + model_filename2 = None modules = get_model_recursive_prop(model_type, "modules", return_list= True) if save_quantized and "quanto" in model_filename: save_quantized = False @@ -2658,76 +2755,58 @@ def load_models(model_type): transformer_dtype = torch.bfloat16 if "bf16" in model_filename or "BF16" in model_filename else transformer_dtype transformer_dtype = torch.float16 if "fp16" in model_filename or"FP16" in model_filename else transformer_dtype perc_reserved_mem_max = args.perc_reserved_mem_max - if preload == 0: - preload = server_config.get("preload_in_VRAM", 0) model_file_list = [model_filename] model_type_list = [model_type] - new_transformer_filename = model_file_list[-1] + module_type_list = [None] + model_submodel_no_list = [1] + if model_filename2 != None: + model_file_list += [model_filename2] + model_type_list += [model_type] + module_type_list += [None] + model_submodel_no_list += [2] for module_type in modules: - model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype, is_module= True)) - model_type_list.append(module_type) - for filename, file_model_type in zip(model_file_list, model_type_list): - download_models(filename, file_model_type) + model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, module_type= module_type)) + model_type_list.append(model_type) + module_type_list.append(module_type) + model_submodel_no_list.append(0) + for filename, file_model_type, file_module_type, submodel_no in zip(model_file_list, model_type_list, module_type_list, model_submodel_no_list): + download_models(filename, file_model_type, file_module_type, submodel_no) VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float mixed_precision_transformer = server_config.get("mixed_precision","0") == "1" transformer_type = None - for i, filename in enumerate(model_file_list): - if i==0: + for submodel_no, filename in zip(model_submodel_no_list, model_file_list): + if submodel_no>=1: print(f"Loading Model '{filename}' ...") - elif "_lora" not in filename: + else: print(f"Loading Module '{filename}' ...") - if model_family == "wan" : - wan_model, pipe = load_wan_model(model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) - elif model_family == "ltxv": - wan_model, pipe = load_ltxv_model(model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) - elif model_family == "flux": - wan_model, pipe = load_flux_model(model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) - elif model_family == "hunyuan": - wan_model, pipe = load_hunyuan_model(model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) - else: - raise Exception(f"Model '{new_transformer_filename}' not supported.") - wan_model._model_file_name = new_transformer_filename - kwargs = { 
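# --- Condensed sketch of how load_models() above assembles its download plan:
# the main checkpoint (submodel 1), an optional second checkpoint when the
# definition carries a "URLs2" entry (dual-transformer pipelines), and one file
# per module, all downloaded before the family handler's load_model() receives
# the whole list. resolve_filename stands in for get_model_filename(); it is not
# the real helper.
def build_file_plan(model_type, has_urls2, modules, resolve_filename):
    plan = [(resolve_filename(model_type, submodel_no=1), None, 1)]
    if has_urls2:
        plan.append((resolve_filename(model_type, submodel_no=2), None, 2))
    for module_type in modules:
        plan.append((resolve_filename(model_type, module_type=module_type), module_type, 0))
    return plan   # (filename, module_type, submodel_no) triples

if __name__ == "__main__":
    fake = lambda model_type, submodel_no=1, module_type=None: (
        f"ckpts/{model_type}_{module_type or submodel_no}.safetensors")
    for entry in build_file_plan("t2v", True, ["vace_14B"], fake):
        print(entry)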
"extraModelsToQuantize": None } - if profile in (2, 4, 5): - kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100 if preload == 0 else preload, "*" : max(1000 if profile==5 else 3000 , preload) } - elif profile == 3: - kwargs["budgets"] = { "*" : "70%" } - - global prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer - if server_config.get("enhancer_enabled", 0) == 1: - from transformers import ( AutoModelForCausalLM, AutoProcessor, AutoTokenizer, LlamaForCausalLM ) - prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( "ckpts/Florence2", trust_remote_code=True) - prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( "ckpts/Florence2", trust_remote_code=True) - prompt_enhancer_llm_model = offload.fast_load_transformers_model("ckpts/Llama3_2/Llama3_2_quanto_bf16_int8.safetensors") #, configKwargs= {"_attn_implementation" :"XXXsdpa"} - prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained("ckpts/Llama3_2") - pipe["prompt_enhancer_image_caption_model"] = prompt_enhancer_image_caption_model - pipe["prompt_enhancer_llm_model"] = prompt_enhancer_llm_model - prompt_enhancer_image_caption_model._model_dtype = torch.float - if "budgets" in kwargs: - kwargs["budgets"]["prompt_enhancer_llm_model"] = 5000 - else: - prompt_enhancer_image_caption_model = None - prompt_enhancer_image_caption_processor = None - prompt_enhancer_llm_model = None - prompt_enhancer_llm_tokenizer = None + wan_model, pipe = model_types_handlers[base_model_type].load_model( + model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, text_encoder_quantization = text_encoder_quantization, + dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) - - offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs) + kwargs = {} + profile = init_pipe(pipe, kwargs, override_profile) + if server_config.get("enhancer_mode", 0) == 0: + setup_prompt_enhancer(pipe, kwargs) + loras_transformer = ["transformer"] + if "transformer2" in pipe: + loras_transformer += ["transformer2"] + offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = loras_transformer, coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs) if len(args.gpu) > 0: torch.set_default_device(args.gpu) transformer_type = model_type - return wan_model, offloadobj, pipe["transformer"] + loaded_profile = profile + return wan_model, offloadobj if not "P" in preload_model_policy: wan_model, offloadobj, transformer = None, None, None reload_needed = True else: - wan_model, offloadobj, transformer = load_models(transformer_type) + wan_model, offloadobj = load_models(transformer_type) if check_loras: + transformer = get_transformer_model(wan_model) setup_loras(transformer_type, transformer, get_lora_dir(transformer_type), "", None) exit() - del transformer gen_in_progress = False @@ -2743,13 +2822,16 @@ def generate_header(model_type, compile, attention_mode): get_model_name(model_type, description_container) model_filename = get_model_filename(model_type, transformer_quantization, 
transformer_dtype_policy) or "" description = description_container[0] - header = "
" + description + "
" - - header += "
Attention mode " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() ) + header = f"
{description}
" + overridden_attention = get_overridden_attention(model_type) + attn_mode = attention_mode if overridden_attention == None else overridden_attention + header += "
Attention mode " + (attn_mode if attn_mode!="auto" else "auto/" + get_auto_attention() ) if attention_mode not in attention_modes_installed: header += " -NOT INSTALLED-" elif attention_mode not in attention_modes_supported: header += " -NOT SUPPORTED-" + elif overridden_attention is not None and attention_mode != overridden_attention: + header += " -MODEL SPECIFIC-" header += "" if compile: @@ -2765,6 +2847,13 @@ def generate_header(model_type, compile, attention_mode): return header +def release_RAM(): + if gen_in_progress: + gr.Info("Unable to release RAM when a Generation is in Progress") + else: + release_model() + gr.Info("Models stored in RAM have been released") + def apply_changes( state, transformer_types_choices, transformer_dtype_policy_choice, @@ -2772,6 +2861,7 @@ def apply_changes( state, VAE_precision_choice, mixed_precision_choice, save_path_choice, + image_save_path_choice, attention_choice, compile_choice, profile_choice, @@ -2783,23 +2873,31 @@ def apply_changes( state, preload_model_policy_choice = 1, UI_theme_choice = "default", enhancer_enabled_choice = 0, + enhancer_mode_choice = 0, mmaudio_enabled_choice = 0, fit_canvas_choice = 0, preload_in_VRAM_choice = 0, depth_anything_v2_variant_choice = "vitl", - notification_sound_enabled_choice = 1, + notification_sound_enabled_choice = 0, notification_sound_volume_choice = 50, + max_frames_multiplier_choice = 1, + display_stats_choice = 0, + video_output_codec_choice = None, + image_output_codec_choice = None, + audio_output_codec_choice = None, + last_resolution_choice = None, ): if args.lock_config: - return + return "
Config Locked
",*[gr.update()]*4 if gen_in_progress: - return "
Unable to change config when a generation is in progress
", gr.update(), gr.update() + return "
Unable to change config when a generation is in progress", *[gr.update()]*4
     global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
     server_config = {
         "attention_mode" : attention_choice,
         "transformer_types": transformer_types_choices,
         "text_encoder_quantization" : text_encoder_quantization_choice,
         "save_path" : save_path_choice,
+        "image_save_path" : image_save_path_choice,
         "compile" : compile_choice,
         "profile" : profile_choice,
         "vae_config" : vae_config_choice,
@@ -2814,13 +2912,22 @@ def apply_changes( state,
         "UI_theme" : UI_theme_choice,
         "fit_canvas": fit_canvas_choice,
         "enhancer_enabled" : enhancer_enabled_choice,
+        "enhancer_mode" : enhancer_mode_choice,
         "mmaudio_enabled" : mmaudio_enabled_choice,
         "preload_in_VRAM" : preload_in_VRAM_choice,
         "depth_anything_v2_variant": depth_anything_v2_variant_choice,
         "notification_sound_enabled" : notification_sound_enabled_choice,
         "notification_sound_volume" : notification_sound_volume_choice,
+        "max_frames_multiplier" : max_frames_multiplier_choice,
+        "display_stats" : display_stats_choice,
+        "video_output_codec" : video_output_codec_choice,
+        "image_output_codec" : image_output_codec_choice,
+        "audio_output_codec" : audio_output_codec_choice,
         "last_model_type" : state["model_type"],
+        "last_model_per_family": state["last_model_per_family"],
         "last_advanced_choice": state["advanced"],
+        "last_resolution_choice": last_resolution_choice,
+        "last_resolution_per_group": state["last_resolution_per_group"],
     }
 
     if Path(server_config_filename).is_file():
@@ -2841,14 +2948,15 @@ def apply_changes( state,
         if v != v_old:
             changes.append(k)
 
-    global attention_mode, profile, compile, vae_config, boost, lora_dir, reload_needed, preload_model_policy, transformer_quantization, transformer_dtype_policy, transformer_types, text_encoder_quantization, save_path
+    global attention_mode, default_profile, compile, vae_config, boost, lora_dir, reload_needed, preload_model_policy, transformer_quantization, transformer_dtype_policy, transformer_types, text_encoder_quantization, save_path
     attention_mode = server_config["attention_mode"]
-    profile = server_config["profile"]
+    default_profile = server_config["profile"]
     compile = server_config["compile"]
     text_encoder_quantization = server_config["text_encoder_quantization"]
     vae_config = server_config["vae_config"]
     boost = server_config["boost"]
     save_path = server_config["save_path"]
+    image_save_path = server_config["image_save_path"]
     preload_model_policy = server_config["preload_model_policy"]
     transformer_quantization = server_config["transformer_quantization"]
     transformer_dtype_policy = server_config["transformer_dtype_policy"]
@@ -2856,27 +2964,20 @@ def apply_changes( state,
     transformer_types = server_config["transformer_types"]
     model_filename = get_model_filename(transformer_type, transformer_quantization, transformer_dtype_policy)
     state["model_filename"] = model_filename
-    if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant", "notification_sound_enabled", "notification_sound_volume", "mmaudio_enabled"] for change in changes ):
+    if "enhancer_enabled" in changes or "enhancer_mode" in changes:
+        reset_prompt_enhancer()
+    if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant",
+                      "notification_sound_enabled", "notification_sound_volume", "mmaudio_enabled", "max_frames_multiplier", "display_stats",
+                      "video_output_codec", "image_output_codec", "audio_output_codec"] for change in changes ):
+        model_family = gr.Dropdown()
         model_choice = gr.Dropdown()
     else:
         reload_needed = True
-        model_choice = generate_dropdown_model_list(transformer_type)
+        model_family, model_choice = generate_dropdown_model_list(transformer_type)
 
     header = generate_header(state["model_type"], compile=compile, attention_mode= attention_mode)
     mmaudio_enabled = server_config["mmaudio_enabled"] > 0
-    return "The new configuration has been succesfully applied", header, model_choice, gr.Row(visible= server_config["enhancer_enabled"] == 1), gr.Row(visible= mmaudio_enabled), gr.Column(visible= mmaudio_enabled)
-
-
-
-from moviepy.editor import ImageSequenceClip
-import numpy as np
-
-def save_video(final_frames, output_path, fps=24):
-    assert final_frames.ndim == 4 and final_frames.shape[3] == 3, f"invalid shape: {final_frames} (need t h w c)"
-    if final_frames.dtype != np.uint8:
-        final_frames = (final_frames * 255).astype(np.uint8)
-    ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False)
-
+    return "The new configuration has been succesfully applied", header, model_family, model_choice, get_unique_id()
 
 def get_gen_info(state):
     cache = state.get("gen", None)
@@ -2888,7 +2989,26 @@ def get_gen_info(state):
 def build_callback(state, pipe, send_cmd, status, num_inference_steps):
     gen = get_gen_info(state)
     gen["num_inference_steps"] = num_inference_steps
-    def callback(step_idx, latent, force_refresh, read_state = False, override_num_inference_steps = -1, pass_no = -1):
+    start_time = time.time()
+    def callback(step_idx = -1, latent = None, force_refresh = True, read_state = False, override_num_inference_steps = -1, pass_no = -1, denoising_extra =""):
+        in_pause = False
+        with gen_lock:
+            process_status = gen.get("process_status", None)
+            pause_msg = None
+            if process_status.startswith("request:"):
+                gen["process_status"] = "process:" + process_status[len("request:"):]
+                offloadobj.unload_all()
+                pause_msg = gen.get("pause_msg", "Unknown Pause")
+                in_pause = True
+
+        if in_pause:
+            send_cmd("progress", [0, pause_msg])
+            while True:
+                time.sleep(1)
+                with gen_lock:
+                    process_status = gen.get("process_status", None)
+                if process_status == "process:main": break
+
         refresh_id = gen.get("refresh", -1)
         if force_refresh or step_idx >= 0:
             pass
@@ -2925,9 +3045,14 @@ def build_callback(state, pipe, send_cmd, status, num_inference_steps):
                 phase = "Denoising Third Pass"
             else:
                 phase = f"Denoising {pass_no}th Pass"
-
+
+            if len(denoising_extra) > 0: phase += " | " + denoising_extra
+            gen["progress_phase"] = (phase, step_idx)
             status_msg = merge_status_context(status, phase)
+
+            elapsed_time = time.time() - start_time
+            status_msg = merge_status_context(status, f"{phase} | {format_time(elapsed_time)}")
             if step_idx >= 0:
                 progress_args = [(step_idx , num_inference_steps) , status_msg , num_inference_steps]
             else:
@@ -2963,10 +3088,10 @@ def refresh_gallery(state): #, msg
     # gen["last_msg"] = msg
     file_list = gen.get("file_list", None)
     choice = gen.get("selected",0)
+    header_text = gen.get("header_text", "")
     in_progress = "in_progress" in gen
-    if in_progress:
-        if gen.get("last_selected", True):
-            choice = max(len(file_list) - 1,0)
+    if gen.get("last_selected", True) and file_list is not None:
+        choice = max(len(file_list) - 1,0)
 
     queue = gen.get("queue", [])
     abort_interactive = not gen.get("abort", False)
@@ -2974,17 +3099,14 @@ def refresh_gallery(state): #, msg
         return gr.Gallery(selected_index=choice, value = file_list), gr.HTML("", visible= False), gr.Button(visible=True), gr.Button(visible=False), gr.Row(visible=False), gr.Row(visible=False), update_queue_data(queue), gr.Button(interactive= abort_interactive), gr.Button(visible= False)
     else:
         task = queue[0]
-        start_img_md = ""
-        end_img_md = ""
         prompt = task["prompt"]
         params = task["params"]
         model_type = params["model_type"]
         base_model_type = get_base_model_type(model_type)
         model_def = get_model_def(model_type)
-        is_image = model_def.get("image_outputs", False)
-        onemorewindow_visible = test_any_sliding_window(base_model_type) and not is_image
+        onemorewindow_visible = test_any_sliding_window(base_model_type) and params.get("image_mode",0) == 0 and not params.get("mode","").startswith("edit_")
         enhanced = False
-        if prompt.startswith("!enhanced!\n"):
+        if prompt.startswith("!enhanced!\n"):
             enhanced = True
             prompt = prompt[len("!enhanced!\n"):]
         if "\n" in prompt :
@@ -2997,6 +3119,9 @@ def refresh_gallery(state): #, msg
             prompt = "<BR>".join(prompts)
         if enhanced:
             prompt = "Enhanced:<BR>" + prompt
+
+        if len(header_text) > 0:
+            prompt = "" + header_text + "<BR><BR>" + prompt
         list_uri = []
         list_labels = []
         start_img_uri = task.get('start_image_data_base64')
@@ -3083,9 +3208,12 @@ def select_video(state, input_file_list, event_data: gr.EventData):
     gen = get_gen_info(state)
     file_list, file_settings_list = get_file_list(state, input_file_list)
 
-    if data!=None:
+    if data!=None and isinstance(data, dict):
         choice = data.get("index",0)
-        set_file_choice(gen, file_list, choice)
+    else:
+        choice = min(len(file_list)-1, gen.get("selected",0)) if len(file_list) > 0 else -1
+    set_file_choice(gen, file_list, choice)
+
     if len(file_list) > 0:
         configs = file_settings_list[choice]
@@ -3100,9 +3228,9 @@ def select_video(state, input_file_list, event_data: gr.EventData):
         if not has_video_file_extension(file_name):
             img = Image.open(file_name)
             width, height = img.size
-            configs = None
             is_image = True
-            nb_audio_tracks = 0
+            frames_count = fps = 1
+            nb_audio_tracks = 0
         else:
             fps, width, height, frames_count = get_video_info(file_name)
             is_image = False
@@ -3155,7 +3283,7 @@ def select_video(state, input_file_list, event_data: gr.EventData):
             values +=[video_creation_date]
             labels +=["Creation Date"]
         else:
-            video_prompt = configs.get("prompt", "")[:200]
+            video_prompt = configs.get("prompt", "")[:1024]
             video_video_prompt_type = configs.get("video_prompt_type", "")
             video_image_prompt_type = configs.get("image_prompt_type", "")
             video_audio_prompt_type = configs.get("audio_prompt_type", "")
@@ -3171,6 +3299,9 @@ def select_video(state, input_file_list, event_data: gr.EventData):
                 + [ v for s,v in map_video_prompt.items() if check(video_video_prompt_type,s)] \
                 + [ v for s,v in map_audio_prompt.items() if all_letters(video_audio_prompt_type,s)]
             video_model_type = configs.get("model_type", "t2v")
+            model_family = get_model_family(video_model_type)
+            model_def = get_model_def(video_model_type)
+            multiple_submodels = model_def.get("multiple_submodels", False)
             video_other_prompts = ", ".join(video_other_prompts)
             video_resolution = configs.get("resolution", "") + f" (real: {width}x{height})"
             video_length = configs.get("video_length", 0)
@@ -3178,22 +3309,44 @@ def select_video(state, input_file_list, event_data: gr.EventData):
             video_length_summary = f"{video_length} frames"
             video_window_no = configs.get("window_no", 0)
             if video_window_no > 0: video_length_summary +=f", Window no {video_window_no }"
-            video_length_summary += " ("
-            if video_length != frames_count: video_length_summary += f"real: {frames_count} frames, "
-            video_length_summary += f"{frames_count/fps:.1f}s, {round(fps)} fps)"
-            video_guidance_scale = configs.get("guidance_scale", 1)
-            video_NAG_scale = configs.get("NAG_scale", 1)
-            video_embedded_guidance_scale = configs.get("video_embedded_guidance_scale ", 1)
-            if get_model_family(video_model_type) == "hunyuan":
+            if is_image:
+                video_length_summary = configs.get("batch_size", 1)
+                video_length_label = "Number of Images"
+            else:
+                video_length_summary += " ("
+                video_length_label = "Video Length"
+                if video_length != frames_count: video_length_summary += f"real: {frames_count} frames, "
+                video_length_summary += f"{frames_count/fps:.1f}s, {round(fps)} fps)"
+            video_guidance_scale = configs.get("guidance_scale", None)
+            video_guidance2_scale = configs.get("guidance2_scale", None)
+            video_guidance3_scale = configs.get("guidance3_scale", None)
+            video_audio_guidance_scale = configs.get("audio_guidance_scale", None)
+            video_switch_threshold = configs.get("switch_threshold", 0)
+            video_switch_threshold2 = configs.get("switch_threshold2", 0)
+            video_model_switch_phase = configs.get("model_switch_phase", 1)
+            video_guidance_phases = configs.get("guidance_phases", 0)
+            video_embedded_guidance_scale = configs.get("embedded_guidance_scale", None)
+            video_guidance_label = "Guidance"
+            if model_def.get("embedded_guidance", False):
                 video_guidance_scale = video_embedded_guidance_scale
                 video_guidance_label = "Embedded Guidance Scale"
-            else:
-                video_guidance_label = "Guidance"
-            video_flow_shift = configs.get("flow_shift", 1)
+            elif video_guidance_phases > 0:
+                if video_guidance_phases == 1:
+                    video_guidance_scale = f"{video_guidance_scale}"
+                elif video_guidance_phases == 2:
+                    if multiple_submodels:
+                        video_guidance_scale = f"{video_guidance_scale} (High Noise), {video_guidance2_scale} (Low Noise) with Switch at Noise Level {video_switch_threshold}"
+                    else:
+                        video_guidance_scale = f"{video_guidance_scale}, {video_guidance2_scale} with Guidance Switch at Noise Level {video_switch_threshold}"
+                else:
+                    video_guidance_scale = f"{video_guidance_scale}, {video_guidance2_scale} & {video_guidance3_scale} with Switch at Noise Levels {video_switch_threshold} & {video_switch_threshold2}"
+                    if multiple_submodels:
+                        video_guidance_scale += f" + Model Switch at {video_switch_threshold if video_model_switch_phase ==1 else video_switch_threshold2}"
+            video_flow_shift = configs.get("flow_shift", None)
             video_video_guide_outpainting = configs.get("video_guide_outpainting", "")
             video_outpainting = ""
             if len(video_video_guide_outpainting) > 0 and not video_video_guide_outpainting.startswith("#") \
-                and (any_letters(video_video_prompt_type, "VFK") or any_letters(video_image_prompt_type, "VL")) :
+                and (any_letters(video_video_prompt_type, "VFK") ) :
                 video_video_guide_outpainting = video_video_guide_outpainting.split(" ")
                 video_outpainting = f"Top={video_video_guide_outpainting[0]}%, Bottom={video_video_guide_outpainting[1]}%, Left={video_video_guide_outpainting[2]}%, Right={video_video_guide_outpainting[3]}%"
             video_num_inference_steps = configs.get("num_inference_steps", 0)
@@ -3211,19 +3364,28 @@ def select_video(state, input_file_list, event_data: gr.EventData):
             if len(video_other_prompts) >0 :
                 values += [video_other_prompts]
                 labels += ["Other Prompts"]
-            if len(video_outpainting) >0 :
+            if len(video_outpainting) >0 and any_letters(video_image_prompt_type, "VFK"):
                 values += [video_outpainting]
-                labels += ["Outpainting"]
-            values += [video_resolution, video_length_summary, video_seed, video_guidance_scale, video_flow_shift, video_num_inference_steps]
-            labels += [ "Resolution", "Video Length", "Seed", video_guidance_label, "Flow Shift", "Num Inference steps"]
+                labels += ["Outpainting"]
+            video_sample_solver = configs.get("sample_solver", "")
+            if model_def.get("sample_solvers", None) is not None and len(video_sample_solver) > 0 :
+                values += [video_sample_solver]
+                labels += ["Sampler Solver"]
+            values += [video_resolution, video_length_summary, video_seed, video_guidance_scale, video_audio_guidance_scale, video_flow_shift, video_num_inference_steps]
+            labels += [ "Resolution", video_length_label, "Seed", video_guidance_label, "Audio Guidance Scale", "Shift Scale", "Num Inference steps"]
             video_negative_prompt = configs.get("negative_prompt", "")
             if len(video_negative_prompt) > 0:
                 values += [video_negative_prompt]
                 labels += ["Negative Prompt"]
-            video_NAG_scale = configs.get("NAG_scale", 1)
-            if video_NAG_scale > 1:
+            video_NAG_scale = configs.get("NAG_scale", None)
+            if video_NAG_scale is not None and video_NAG_scale > 1:
                 values += [video_NAG_scale]
-                labels += ["NAG Scale"]
+                labels += ["NAG Scale"]
+            video_apg_switch = configs.get("apg_switch", None)
+            if video_apg_switch is not None and video_apg_switch != 0:
+                values += ["on"]
+                labels += ["APG"]
+
             video_skip_steps_cache_type = configs.get("skip_steps_cache_type", "")
             video_skip_steps_multiplier = configs.get("skip_steps_multiplier", 0)
             video_skip_steps_cache_start_step_perc = configs.get("skip_steps_start_step_perc", 0)
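These `select_video` hunks, together with the one that follows, build the metadata table as two parallel lists: optional settings are appended as `values`/`labels` pairs, settings read with a `None` default stay `None` when absent, and the final hunk drops those entries so only settings actually stored with the file are shown. A small sketch of that pattern, using illustrative sample data rather than real file metadata:

```python
# Sketch of the parallel values/labels pattern used by select_video.
# The configs dict below is illustrative sample data, not real file metadata.
configs = {"seed": 42, "flow_shift": None, "num_inference_steps": 30}

values = [configs.get("seed"), configs.get("flow_shift"), configs.get("num_inference_steps")]
labels = ["Seed", "Shift Scale", "Num Inference steps"]

# Keep each label aligned with its value while dropping settings that are missing.
labels = [label for value, label in zip(values, labels) if value is not None]
values = [value for value in values if value is not None]

for label, value in zip(labels, values):
    print(f"{label}: {value}")  # -> "Seed: 42" and "Num Inference steps: 30"
```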
@@ -3245,6 +3407,8 @@ def select_video(state, input_file_list, event_data: gr.EventData):
                 labels +=["Nb Audio Tracks"]
             values += [ video_creation_date, video_generation_time ]
             labels += [ "Creation Date", "Generation Time" ]
+            labels = [label for value, label in zip(values, labels) if value is not None]
+            values = [value for value in values if value is not None]
             table_style = """