Merge branch 'main' into feature_add-cuda-docker-runner

This commit is contained in:
deepbeepmeep 2025-09-27 12:54:13 +02:00 committed by GitHub
commit b28cb446bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
66 changed files with 5894 additions and 3040 deletions

README.md (230 changed lines)

@@ -19,222 +19,96 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
### August 29 2025: WanGP v8.2 - Here Goes Your Weekend
## 🔥 Latest Updates :
### September 25 2025: WanGP v8.73 - ~~Here Are Two Three New Contenders in the Vace Arena !~~ The Never Ending Release
- **InfiniteTalk Video to Video**: this feature can be used for Video Dubbing. Keep in mind that it is a *Sparse Video to Video*, that is, internally only one image is used per Sliding Window. However, thanks to the new *Smooth Transition* mode, each new clip is connected to the previous one and all the camera work is done by InfiniteTalk. If you don't get any transition, increase the number of frames of a Sliding Window (81 frames recommended)
So in ~~today's~~ this release you will find two Vace wannabes that each cover only a subset of Vace features but offer some interesting advantages:
- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*. It does that very well. You can use this model to either *Replace* a person in a Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone*?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be, under the hood, an i2v-derived model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also, as a WanGP exclusivity, you will find support for *Outpainting*.
- **StandIn**: a very light model specialized in Identity Transfer. I have provided two versions of StandIn: a basic one derived from the text 2 video model and another based on Vace. If used with Vace, the last reference frame given to Vace will also be used for StandIn
In order to use Wan 2.2 Animate you will first need to stop by the *Mat Anyone* embedded tool to extract the *Video Mask* of the person whose motion you want to extract.
- **Flux USO**: a new Flux-derived *Image Editing tool*, but this one is specialized both in *Identity Transfer* and *Style Transfer*. Style has to be understood in its widest sense: give a reference picture of a person and another one of sushi and you will turn this person into sushi
With version WanGP 8.74, there is an extra option that allows you to apply *Relighting* when Replacing a person. Also, you can now Animate a person without providing a Video Mask to target the source of the motion (at the risk of being less precise)
### August 24 2025: WanGP v8.1 - the RAM Liberator
- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and ask it to change it (it is specialized in clothes changing) and voila ! The nice thing about it is that it is based on the *Wan 2.2 5B* model and therefore is very fast, especially if you use the *FastWan* finetune that is also part of the package.
- **Reserved RAM entirely freed when switching models**: you should get far fewer RAM-related out-of-memory errors. I have also added a button in *Configuration / Performance* that will release most of the RAM used by WanGP if you want to use another application without quitting WanGP
- **InfiniteTalk** support: an improved version of Multitalk that supposedly supports very long video generations based on an audio track. It exists in two flavors (*Single Speaker* and *Multi Speakers*) but doesn't seem to be compatible with Vace. One key new feature compared to Multitalk is that you can have different visual shots associated with the same audio: each Reference Frame you provide will be associated with a new Sliding Window. If only one Reference Frame is provided, it will be used for all windows. When Continuing a video, you can either continue the current shot (no Reference Frame) or add new shots (one or more Reference Frames).\
If you are not into audio, you can still use this model to generate infinitely long image2video, just select "no speaker". Last but not least, InfiniteTalk works with all the Loras accelerators.
- **Flux Chroma 1 HD** support: an uncensored Flux-based model, lighter than Flux (8.9B versus 12B), that can fit entirely in VRAM with only 16 GB of VRAM. Unfortunately it is not distilled and you will need CFG and a minimum of 20 steps
Also because I wanted to spoil you:
- **Qwen Edit Plus**: also known as the *Qwen Edit 25th September Update*, which is specialized in combining multiple Objects / People. There is also new support for *Pose transfer* & *Recolorisation*. All of this is made easy to use in WanGP. For now you will find only the quantized version since HF crashes when uploading the unquantized version.
### August 21 2025: WanGP v8.01 - the killer of seven
- **T2V Video 2 Video Masking**: ever wanted to apply a Lora, a process (for instance Upsampling) or a Text Prompt on only a (moving) part of a Source Video? Look no further, I have added *Masked Video 2 Video* (which also works in image2image) to the *Text 2 Video* models. As usual you just need to use *Matanyone* to create the mask.
- **Qwen Image Edit** : a Flux Kontext challenger (prompt-driven image editing). Best results (including Identity preservation) will be obtained at 720p. Beyond that you may get image outpainting and / or lose identity preservation. Below 720p prompt adherence will be worse. Qwen Image Edit works with Qwen Lora Lightning 4 steps. I have also unlocked all the resolutions for Qwen models. Bonus Zone: support for multiple image compositions, but identity preservation won't be as good.
- **On demand Prompt Enhancer** (needs to be enabled in Configuration Tab) that you can use to Enhance a Text Prompt before starting a Generation. You can refine the Enhanced Prompt or change the original Prompt.
- Choice of a **Non censored Prompt Enhancer**. Beware, this one is VRAM hungry and will require 12 GB of VRAM to work
- **Memory Profile customizable per model** : useful for instance to set Profile 3 (preload the model entirely in VRAM) with only Image Generation models, if you have 24 GB of VRAM. In that case Generation will be much faster because with Image generators (contrary to Video generators) a lot of time is otherwise wasted in offloading
- **Expert Guidance Mode**: change the Guidance during the generation up to 2 times. Very useful with Wan 2.2 Lightning to reduce the slow motion effect. The idea is to insert a CFG phase before the 2 accelerated phases that follow and have no Guidance. I have added the finetune *Wan2.2 Vace Lightning 3 Phases 14B* with a prebuilt configuration. Please note that it is an 8-step process although the Lightning lora is 4 steps. This expert guidance mode is also available with Wan 2.1.
*WanGP 8.01 update, improved Qwen Image Edit Identity Preservation*
### August 12 2025: WanGP v7.7777 - Lucky Day(s)
*Update 8.71*: fixed Fast Lucy Edit that didn't contain the lora
*Update 8.72*: shadow drop of Qwen Edit Plus
*Update 8.73*: Qwen Preview & InfiniteTalk Start image
*Update 8.74*: Animate Relighting / Nomask mode , t2v Masked Video to Video
This is your lucky day ! Thanks to new configuration options that let you store generated Videos and Images in lossless compressed formats, you will find that they in fact look twice as good without doing anything !
### September 15 2025: WanGP v8.6 - Attack of the Clones
Just kidding, they will be only marginally better, but at least this opens the way to professional editing.
- The long awaited **Vace for Wan 2.2** is at last here, or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fun Cocktail**)
Support:
- Video: x264, x264 lossless, x265
- Images: jpeg, png, webp, webp lossless
Generation Settings are stored in each of the above regardless of the format (that was the hard part).
- **First Frame / Last Frame for Vace** : Vace models are so powerful that they could do *First frame / Last frame* since day one using the *Injected Frames* feature. However this required computing by hand the location of each end frame since this feature expects frame positions. I made it easier to compute these locations by using the "L" alias:
Also you can now choose different output directories for images and videos.
For a video Gen from scratch *"1 L L L"* means the 4 Injected Frames will be injected like this: frame no 1 at the first position, the next frame at the end of the first window, then the following frame at the end of the next window, and so on ....
If you *Continue a Video*, you just need *"L L L"* since the first frame is the last frame of the *Source Video*. In any case remember that numeric frame positions (like "1") are aligned by default to the beginning of the source window, so low values such as 1 will be considered in the past unless you change this behaviour in *Sliding Window Tab / Control Video, Injected Frames alignment*.
Unexpected luck: fixed Lightning 8 steps for Qwen, and Lightning 4 steps for Wan 2.2; now you just need a 1x multiplier, no weird numbers.
*update 7.777 : oops, got a crash with FastWan ? Luck comes and goes, try a new update, maybe you will have a better chance this time*
*update 7.7777 : Sometimes good luck seems to last forever. For instance what if Qwen Lightning 4 steps could also work with WanGP ?*
- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors (Qwen Lightning 4 steps)
- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors (new improved version of Qwen Lightning 8 steps)
- **Qwen Edit Inpainting** now exists in two versions: the original version of the previous release and a Lora-based version. Each version has its pros and cons. For instance the Lora version also supports **Outpainting** ! However it tends to slightly change the original image even outside the outpainted area.
- **Better Lipsync with all the Audio to Video models**: you probably noticed that *Multitalk*, *InfiniteTalk* or *Hunyuan Avatar* had so-so lipsync when the audio provided contained some background music. The problem should be solved now thanks to automated background music removal done entirely by AI. Don't worry, you will still hear the music as it is added back in the generated Video.
### August 10 2025: WanGP v7.76 - Faster than the VAE ...
We have a funny one here today: FastWan 2.2 5B, the Fastest Video Generator, only 20s to generate 121 frames at 720p. The snag is that VAE is twice as slow...
Thanks to Kijai for extracting the Lora that is used to build the corresponding finetune.
### September 11 2025: WanGP v8.5/8.55 - Wanna be a Cropper or a Painter ?
*WanGP 7.76: fixed the mess I made of the i2v models (the loras path was wrong for Wan2.2 and Clip was broken)*
I have done some intensive internal refactoring of the generation pipeline to ease support of existing models and the addition of new ones. Nothing really visible, but this makes WanGP a little more future proof.
### August 9 2025: WanGP v7.74 - Qwen Rebirth part 2
Added support for the Qwen Lightning lora for an 8-step generation (https://huggingface.co/lightx2v/Qwen-Image-Lightning/blob/main/Qwen-Image-Lightning-8steps-V1.0.safetensors). The Lora is not normalized and you can use a multiplier around 0.1.
Otherwise in the news:
- **Cropped Input Image Prompts**: quite often the *Image Prompts* provided (*Start Image, Input Video, Reference Image, Control Video, ...*) don't match your requested *Output Resolution*. In that case I used the resolution you gave either as a *Pixels Budget* or as an *Outer Canvas* for the Generated Video. However on some occasions you really want the requested Output Resolution and nothing else. Besides, some models deliver much better Generations if you stick to one of their supported resolutions. In order to address this need I have added a new Output Resolution choice in the *Configuration Tab*: **Dimensions Correspond to the Output Width & Height as the Prompt Images will be Cropped to fit Exactly these dimensions**. In short, if needed the *Input Prompt Images* will be cropped (center-cropped for the moment). You will see this can make quite a difference for some models
Mag Cache support for all the Wan2.2 models. Don't forget to set guidance to 1 and 8 denoising steps; your gen will be 7x faster !
- *Qwen Edit* now has a new sub Tab called **Inpainting**, that lets you target with a brush which part of the *Image Prompt* you want to modify. This is quite convenient if you find that Qwen Edit usually modifies too many things. Of course, as there are more constraints for Qwen Edit, don't be surprised if it sometimes returns the original image unchanged. A piece of advice: describe in your *Text Prompt* where the parts that you want to modify are located (for instance *to the left of the man*, *top*, ...).
### August 8 2025: WanGP v7.73 - Qwen Rebirth
Ever wondered what impact not using Guidance has on a model that expects it ? Just look at Qwen Image in WanGP 7.71, whose outputs were erratic. Somehow I had convinced myself that Qwen was a distilled model. In fact Qwen was dying for a negative prompt. And in WanGP 7.72 there is at last one for him.
The mask inpainting is fully compatible with the *Matanyone Mask generator*: first generate an *Image Mask* with Matanyone, transfer it to the current Image Generator and modify the mask with the *Paint Brush*. Talking about Matanyone, I have fixed a bug that caused mask degradation with long videos (now WanGP Matanyone is as good as the original app and still requires 3 times less VRAM)
As Qwen is not so picky after all, I have also added a quantized text encoder which reduces the RAM requirements of Qwen by 10 GB (the quantized text encoder produced garbage before)
- This **Inpainting Mask Editor** has also been added to *Vace Image Mode*. Vace is probably still one of the best Image Editors today. Here is a very simple & efficient workflow that does marvels with Vace:
Select *Vace Cocktail > Control Image Process = Perform Inpainting & Area Processed = Masked Area > Upload a Control Image, then draw your mask directly on top of the image & enter a text Prompt that describes the expected change > Generate > Below the Video Gallery click 'To Control Image' > Keep on doing more changes*.
Unfortunately the Sage bug is still there for older GPU architectures. I have added an Sdpa fallback for these architectures.
For more sophisticated things the Vace Image Editor works very well too: try Image Outpainting, Pose transfer, ...
*7.73 update: there is still a Sage / Sage2 bug for GPUs before RTX40xx. I have added a detection mechanism that forces Sdpa attention if that's the case*
For the best quality I recommend setting in the *Quality Tab* the option: "*Generate a 9 Frames Long video...*"
**update 8.55**: Flux Festival
- **Inpainting Mode** also added for *Flux Kontext*
- **Flux SRPO** : new finetune with x3 better quality vs Flux Dev according to its authors. I have also created a *Flux SRPO USO* finetune which is certainly the best open source *Style Transfer* tool available
- **Flux UMO**: model specialized in combining multiple reference objects / people together. Works quite well at 768x768
### August 6 2025: WanGP v7.71 - Picky, picky
Good luck with finding your way through all the Flux models names !
This release comes with two new models :
- Qwen Image: a Commercial grade Image generator capable of injecting full sentences into the generated Image while still offering incredible visuals
- Wan 2.2 TextImage to Video 5B: the last Wan 2.2 needed if you want to complete your Wan 2.2 collection (loras for this folder can be stored in "\loras\5B" )
### September 5 2025: WanGP v8.4 - Take me to Outer Space
You have probably seen these short AI generated movies created using *Nano Banana* and the *First Frame - Last Frame* feature of *Kling 2.0*. The idea is to generate an image, modify a part of it with Nano Banana and give these two images to Kling, which will generate the Video between these two images; then use the previous Last Frame as the new First Frame, rinse and repeat and you get a full movie.
There is a catch though, they are very picky if you want to get good generations: first they both need lots of steps (50 ?) to show what they have to offer. Then for Qwen Image I had to hardcode the supported resolutions, because if you try anything else, you will get garbage. Likewise Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p.
I have made it easier to do just that with *Qwen Edit* and *Wan*:
- **End Frames can now be combined with Continue a Video** (and not just a Start Frame)
- **Multiple End Frames can be input**, each End Frame will be used for a different Sliding Window
*7.71 update: Added VAE Tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, for low VRAM during a whole gen.*
You can plan all your shots in advance (one shot = one Sliding Window): I recommend using Wan 2.2 Image to Image with multiple End Frames (one for each shot / Sliding Window), and a different Text Prompt for each shot / Sliding Window (remember to enable *Sliding Windows/Text Prompts Will be used for a new Sliding Window of the same Video Generation*)
The results can be quite impressive. However, Wan 2.1 & 2.2 Image 2 Image are restricted to a single overlap frame when using Sliding Windows, which means only one frame is reused for the motion. This may be insufficient if you are trying to connect two shots with fast movement.
### August 4 2025: WanGP v7.6 - Remuxed
This is where *InfiniteTalk* comes into play. Besides being one of the best models to generate animated audio driven avatars, InfiniteTalk internally uses more than one motion frame. It is quite good at maintaining the motion between two shots. I have tweaked InfiniteTalk so that **its motion engine can be used even if no audio is provided**.
So here is how to use InfiniteTalk: enable *Sliding Windows/Text Prompts Will be used for a new Sliding Window of the same Video Generation*, and if you continue an existing Video, *Misc/Override Frames per Second* should be set to *Source Video*. Each Reference Frame input will play the same role as the End Frame, except it won't be exactly an End Frame (it will correspond more to a middle frame; the actual End Frame will differ but will be close)
With this new version you won't have any excuse if there is no sound in your video.
*Continue Video* now works with any video that has already some sound (hint: Multitalk ).
You will find below a 33s movie I created using these two methods. Quality could be much better as I haven't tuned the settings at all (I couldn't be bothered, I used 10-step generation without Loras Accelerators for most of the gens).
Also, on top of MMaudio and the various sound driven models I have added the ability to use your own soundtrack.
### September 2 2025: WanGP v8.31 - At last the pain stops
As a result you can apply a different sound source on each new video segment when doing a *Continue Video*.
- This single new feature should give you the strength to face all the potential bugs of this new release:
**Images Management (multiple additions or deletions, reordering) for Start Images / End Images / Images References.**
For instance:
- first video part: use Multitalk with two people speaking
- second video part: you apply your own soundtrack which will gently follow the multitalk conversation
- third video part: you use Vace effect and its corresponding control audio will be concatenated to the rest of the audio
- Unofficial **Video to Video (Non Sparse this time) for InfiniteTalk**. Use the Strength Noise slider to decide how much motion of the original window you want to keep. I have also *greatly reduced the VRAM requirements for Multitalk / InfiniteTalk* (especially the multi-speaker version & when generating at 1080p).
To multiply the combinations I have also implemented *Continue Video* with the various image2video models.
- **Experimental Sage 3 Attention support**: you will have to earn this one; first you need a Blackwell GPU (RTX50xx) and to request access to the Sage 3 Github repo, then you will have to compile Sage 3, install it and cross your fingers ...
Also:
- End Frame support added for LTX Video models
- Loras can now be targeted specifically at the High noise or Low noise models with Wan 2.2, check the Loras and Finetune guides
- Flux Krea Dev support
### July 30 2025: WanGP v7.5: Just another release ... Wan 2.2 part 2
Here is now Wan 2.2 image2video a very good model if you want to set Start and End frames. Two Wan 2.2 models delivered, only one to go ...
*update 8.31: one shouldn't talk about bugs if one doesn't want to attract bugs*
Please note that although it is an image2video model it is structurally very close to Wan 2.2 text2video (same layers with only a different initial projection). Given that Wan 2.1 image2video loras don't work too well (half of their tensors are not supported), I have decided that this model will look for its loras in the text2video loras folder instead of the image2video folder.
I have also optimized RAM management with Wan 2.2 so that loras and modules will be loaded only once in RAM and Reserved RAM, this saves up to 5 GB of RAM which can make a difference...
And this time I really removed Vace Cocktail Light which gave a blurry vision.
### July 29 2025: WanGP v7.4: Just another release ... Wan 2.2 Preview
Wan 2.2 is here. The good news is that WanGP won't require a single byte of extra VRAM to run it and it will be as fast as Wan 2.1. The bad news is that you will need much more RAM if you want to fully leverage this new model since it has twice as many parameters.
So here is a preview version of Wan 2.2 that is without the 5B model and Wan 2.2 image to video for the moment.
However as I felt bad delivering only half of the wares, I gave you instead ..... **Wan 2.2 Vace Experimental Cocktail**!
Very good surprise indeed, the loras and Vace partially work with Wan 2.2. We will need to wait for the official Vace 2.2 release since some Vace features are broken like identity preservation
Bonus zone: Flux multi images conditions has been added, or maybe not if I broke everything as I have been distracted by Wan...
7.4 update: I forgot to update the version number. I also removed Vace Cocktail light which didn't work well.
### July 27 2025: WanGP v7.3 : Interlude
While waiting for Wan 2.2, you will appreciate the model selection hierarchy which is very useful to collect even more models. You will also appreciate that WanGP remembers which model you used last in each model family.
### July 26 2025: WanGP v7.2 : Ode to Vace
I am really convinced that Vace can do everything the other models can do and in a better way especially as Vace can be combined with Multitalk.
Here are some new Vace improvements:
- I have provided a default finetune named *Vace Cocktail*, which is a model created on the fly using the Wan text 2 video model and the Loras used to build FusioniX. The weight of the *Detail Enhancer* Lora has been reduced to improve identity preservation. Copy the model definition from *defaults/vace_14B_cocktail.json* into the *finetunes/* folder to change the Cocktail composition (see the sketch after this list). Cocktail already contains some Lora accelerators so there is no need to add an AccVid, CausVid or FusioniX Lora on top. The whole point of Cocktail is to be able to build your own FusioniX (which originally is a combination of 4 loras) but without the inconveniences of FusioniX.
- Talking about identity preservation, it tends to go away when one generates a single Frame instead of a Video, which is a shame for our Vace photoshop. But there is a solution: I have added an Advanced Quality option that tells WanGP to generate a little more than a frame (it will still keep only the first frame). It will be a little slower but you will be amazed how Vace Cocktail combined with this option will preserve identities (bye bye *Phantom*).
- As in practice I have observed that one switches frequently between *Vace text2video* and *Vace text2image*, I have put them in the same place: they are now just one tab away, no need to reload the model. Likewise *Wan text2video* and *Wan text2image* have been merged.
- Color fixing when using Sliding Windows. A new postprocessing *Color Correction*, applied automatically by default (you can disable it in the *Advanced tab Sliding Window*), will try to match the colors of the new window with those of the previous window. It doesn't fix all the unwanted artifacts of the new window but at least this makes the transition smoother. Thanks to the multitalk team for the original code.
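To make the "build your own cocktail" point above concrete, here is a minimal sketch of a custom finetune definition to drop into *finetunes/*. It reuses the field layout of *defaults/vace_fun_14B_cocktail_2_2.json* shown later in this diff; the `"URLs": "vace_14B"` alias, the model name and the multiplier values are illustrative assumptions, not a tested recipe.

```json
{
  "model": {
    "name": "Vace Cocktail Custom 14B",
    "architecture": "vace_14B",
    "description": "Illustrative only: Vace 14B plus accelerator Loras, with the Detail Enhancer weight lowered to preserve identity.",
    "URLs": "vace_14B",
    "loras": [
      "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
      "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors",
      "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors"
    ],
    "loras_multipliers": [1, 0.2, 0.5]
  },
  "num_inference_steps": 10,
  "guidance_scale": 1
}
```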
Also you will enjoy our new real time statistics (CPU / GPU usage, RAM / VRAM used, ... ). Many thanks to **Redtash1** for providing the framework for this new feature ! You need to go in the Config tab to enable real time stats.
### July 21 2025: WanGP v7.12
- Flux Family Reunion : *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate that, Loras support for the Flux *diffusers* format has also been added.
- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video !) in one go without a sliding window. With the distilled model it will take only 5 minutes with an RTX 4090 (you will need 22 GB of VRAM though). I have added options to select a higher number of frames if you want to experiment (go to Configuration Tab / General / Increase the Max Number of Frames, change the value and restart the App)
- LTX Video ControlNet : it is a Control Net that allows you for instance to transfer a Human motion or Depth from a control video. It is not as powerful as Vace but can produce interesting things, especially as you can now quickly generate a 1 min video. Under the hood, IC-Loras (see below) for Pose, Depth and Canny are automatically loaded for you, no need to add them.
- LTX IC-Lora support: these are special Loras that consume a conditional image or video
Besides the pose, depth and canny IC-Loras transparently loaded, there is the *detailer* (https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8) which is basically an upsampler. Add the *detailer* as a Lora and use LTX Raw Format as the control net choice to use it.
- Matanyone is now also for the GPU Poor as its VRAM requirements have been divided by 2! (7.12 shadow update)
- Easier way to select video resolution
### July 15 2025: WanGP v7.0 is an AI Powered Photoshop
This release turns the Wan models into Image Generators. This goes way beyond just allowing you to generate a video made of a single frame :
- Multiple Images generated at the same time so that you can choose the one you like best. It is Highly VRAM optimized so that you can generate for instance 4 720p Images at the same time with less than 10 GB
- With the *image2image* the original text2video WanGP becomes an image upsampler / restorer
- *Vace image2image* comes out of the box with image outpainting, person / object replacement, ...
- You can use in one click a newly Image generated as Start Image or Reference Image for a Video generation
And to complete the full suite of AI Image Generators, Ladies and Gentlemen please welcome for the first time in WanGP : **Flux Kontext**.\
As a reminder Flux Kontext is an image editor : give it an image and a prompt and it will do the change for you.\
This highly optimized version of Flux Kontext will make you feel that you have been cheated all this time as WanGP Flux Kontext requires only 8 GB of VRAM to generate 4 images at the same time with no need for quantization.
WanGP v7 comes with *Image2image* vanilla and *Vace FusioniX*. However you can build your own finetune where you will combine a text2video or Vace model with any combination of Loras.
Also in the news:
- You can now enter the *Bbox* for each speaker in *Multitalk* to precisely locate who is speaking. And to save some headaches the *Image Mask generator* will give you the *Bbox* coordinates of an area you have selected.
- *Film Grain* post processing to add a vintage look to your video
- *First Last Frame to Video* model should work much better now as I recently discovered its implementation was not complete
- More power for the finetuners, you can now embed Loras directly in the finetune definition. You can also override the default models (titles, visibility, ...) with your own finetunes. Check the doc that has been updated.
### July 10 2025: WanGP v6.7, is NAG a game changer ? you tell me
Maybe you knew that already but most *Loras accelerators* we use today (Causvid, FusioniX) don't use *Guidance* at all (that is, *CFG* is set to 1). This helps to get much faster generations but the downside is that *Negative Prompts* are completely ignored (including the default ones set by the models). **NAG** (https://github.com/ChenDarYen/Normalized-Attention-Guidance) aims to solve that by injecting the *Negative Prompt* during the *attention* processing phase.
So WanGP 6.7 gives you NAG, but not just any NAG, a *Low VRAM* implementation; the default one ends up being VRAM greedy. You will find NAG in the *General* advanced tab for most Wan models.
Use NAG especially when Guidance is set to 1. To turn it on set the **NAG scale** to something around 10. There are other NAG parameters, **NAG tau** and **NAG alpha**, which I recommend changing only if you don't get good results by just playing with the NAG scale. Don't hesitate to share on this discord server the best combinations for these 3 parameters.
The authors of NAG claim that NAG can also be used when using a Guidance (CFG > 1) and to improve the prompt adherence.
### July 8 2025: WanGP v6.6, WanGP offers you **Vace Multitalk Dual Voices Fusionix Infinite** :
**Vace** our beloved super Control Net has been combined with **Multitalk**, the new king in town that can animate up to two people speaking (**Dual Voices**). It is accelerated by the **Fusionix** model and thanks to *Sliding Windows* support and *Adaptive Projected Guidance* (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for a very long time (which is an **Infinite** amount of time in the field of video generation).
Of course you will get as well *Multitalk* vanilla and also *Multitalk 720p* as a bonus.
And since I am mister nice guy I have enclosed as an exclusivity an *Audio Separator* that will save you time to isolate each voice when using Multitalk with two people.
As I feel like resting a bit I haven't yet produced a nice sample Video to illustrate all these new capabilities. But here is the thing, I am sure you will publish your *Master Pieces* in the *Share Your Best Video* channel. The best ones will be added to the *Announcements Channel* and will bring eternal fame to their authors.
But wait, there is more:
- Sliding Windows support has been added everywhere for Wan models, so imagine: with text2video recently upgraded in 6.5 into a video2video, you can now upsample very long videos regardless of your VRAM. The good old image2video model can now reuse the last image to produce new videos (as requested by many of you)
- I have also added the capability to transfer the audio of the original control video (Misc. advanced tab) and an option to preserve the fps in the generated video, so from now on you will be able to upsample / restore your old family videos and keep the audio at its original pace. Be aware that the duration will be limited to 1000 frames as I still need to add streaming support for unlimited video sizes.
Also, of interest too:
- Extract video info from Videos that have not been generated by WanGP, even better you can also apply post processing (Upsampling / MMAudio) on non WanGP videos
- Force the generated video fps to your liking, works very well with Vace when using a Control Video
- Ability to chain URLs of Finetune models (for instance put the URLs of a model in your main finetune and reference this finetune in other finetune models to save time)
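To illustrate the URL chaining mentioned in the last bullet, here is a minimal sketch of a finetune that points at other finetunes instead of repeating their download links, following the pattern visible in *defaults/flux_srpo_uso.json* later in this diff (`"URLs": "flux_srpo"`, `"loras": "flux_dev_uso"`); the model name and prompt are placeholders.

```json
{
  "model": {
    "name": "Flux SRPO Style Mix (example)",
    "architecture": "flux",
    "description": "Illustrative only: inherits the model weights from the flux_srpo finetune, and the modules and Loras from flux_dev_uso, by referencing their keys.",
    "modules": ["flux_dev_uso"],
    "URLs": "flux_srpo",
    "loras": "flux_dev_uso",
    "flux-model": "flux-dev-uso"
  },
  "prompt": "the man is wearing a hat",
  "embedded_guidance_scale": 4,
  "resolution": "1024x1024",
  "batch_size": 1
}
```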
### July 2 2025: WanGP v6.5.1, WanGP takes care of you: lots of quality of life features:
- View directly inside WanGP the properties (seed, resolutions, length, most settings...) of the past generations
- In one click use the newly generated video as a Control Video or Source Video to be continued
- Manage multiple settings for the same model and switch between them using a dropdown box
- WanGP will keep the last generated videos in the Gallery and will remember the last model you used if you restart the app but kept the Web page open
- Custom resolutions : add a file in the WanGP folder with the list of resolutions you want to see in WanGP (look at the instruction readme in this folder)
Taking care of your life is not enough, you want new stuff to play with ?
- MMAudio directly inside WanGP : add an audio soundtrack that matches the content of your video. By the way it is a low VRAM MMAudio and 6 GB of VRAM should be sufficient. You will need to go in the *Extensions* tab of the WanGP *Configuration* to enable MMAudio
- Forgot to upsample your video during the generation ? Want to try another MMAudio variation ? Fear not, you can also apply upsampling or add an MMAudio track once the video generation is done. Even better, you can ask WanGP for multiple variations of MMAudio to pick the one you like best
- MagCache support: a new step skipping approach, supposed to be better than TeaCache. Makes a difference if you usually generate with a high number of steps
- SageAttention2++ support : not just the compatibility but also a slightly reduced VRAM usage
- Video2Video in Wan Text2Video : this is the paradox, a text2video can become a video2video if you start the denoising process later on an existing video
- FusioniX upsampler: this is an illustration of Video2Video in Text2Video. Use the FusioniX text2video model with an output resolution of 1080p and a denoising strength of 0.25 and you will get one of the best upsamplers (in only 2/3 steps, you will need lots of VRAM though). Increase the denoising strength and you will get one of the best Video Restorer
- Choice of Wan Samplers / Schedulers
- More Lora formats support
**If you had upgraded to v6.5 please upgrade again to 6.5.1 as this will fix a bug that ignored Loras beyond the first one**
See full changelog: **[Changelog](docs/CHANGELOG.md)**

configs/animate.json (new file, 15 lines)

@@ -0,0 +1,15 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.30.0",
"dim": 5120,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_dim": 36,
"model_type": "i2v",
"num_heads": 40,
"num_layers": 40,
"out_dim": 16,
"text_len": 512,
"motion_encoder_dim": 512
}

configs/lucy_edit.json (new file, 14 lines)

@@ -0,0 +1,14 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.33.0",
"dim": 3072,
"eps": 1e-06,
"ffn_dim": 14336,
"freq_dim": 256,
"in_dim": 96,
"model_type": "ti2v2_2",
"num_heads": 24,
"num_layers": 30,
"out_dim": 48,
"text_len": 512
}

defaults/animate.json (new file, 17 lines)

@@ -0,0 +1,17 @@
{
"model": {
"name": "Wan2.2 Animate",
"architecture": "animate",
"description": "Wan-Animate takes a video and a character image as input, and generates a video in either 'Animation' or 'Replacement' mode. Sliding Window of 81 frames at least are recommeded to obtain the best Style continuity.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_fp16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_bf16_int8.safetensors"
],
"preload_URLs" :
[
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_relighting_lora.safetensors"
],
"group": "wan2_2"
}
}

@@ -7,8 +7,6 @@
"https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors"
],
"image_outputs": true,
"reference_image": true,
"flux-model": "flux-dev-kontext"
},
"prompt": "add a hat",

@@ -0,0 +1,24 @@
{
"model": {
"name": "Flux 1 Dev UMO 12B",
"architecture": "flux",
"description": "FLUX.1 Dev UMO is a model that can Edit Images with a specialization in combining multiple image references (resized internally at 512x512 max) to produce an Image output. Best Image preservation at 768x768 Resolution Output.",
"URLs": "flux",
"flux-model": "flux-dev-umo",
"loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-UMO_dit_lora_bf16.safetensors"],
"resolutions": [ ["1024x1024 (1:1)", "1024x1024"],
["768x1024 (3:4)", "768x1024"],
["1024x768 (4:3)", "1024x768"],
["512x1024 (1:2)", "512x1024"],
["1024x512 (2:1)", "1024x512"],
["768x768 (1:1)", "768x768"],
["768x512 (3:2)", "768x512"],
["512x768 (2:3)", "512x768"]]
},
"prompt": "the man is wearing a hat",
"embedded_guidance_scale": 4,
"resolution": "768x768",
"batch_size": 1
}

@@ -2,15 +2,13 @@
"model": {
"name": "Flux 1 Dev USO 12B",
"architecture": "flux",
"description": "FLUX.1 Dev USO is a model specialized to Edit Images with a specialization in Style Transfers (up to two).",
"description": "FLUX.1 Dev USO is a model that can Edit Images with a specialization in Style Transfers (up to two).",
"modules": [ ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_projector_bf16.safetensors"]],
"URLs": "flux",
"loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_dit_lora_bf16.safetensors"],
"image_outputs": true,
"reference_image": true,
"flux-model": "flux-dev-uso"
},
"prompt": "add a hat",
"prompt": "the man is wearing a hat",
"embedded_guidance_scale": 4,
"resolution": "1024x1024",
"batch_size": 1

defaults/flux_srpo.json (new file, 15 lines)

@@ -0,0 +1,15 @@
{
"model": {
"name": "Flux 1 SRPO Dev 12B",
"architecture": "flux",
"description": "By fine-tuning the FLUX.1.dev model with optimized denoising and online reward adjustment, SRPO improves its human-evaluated realism and aesthetic quality by over 3x.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-srpo-dev_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-srpo-dev_quanto_bf16_int8.safetensors"
],
"flux-model": "flux-dev"
},
"prompt": "draw a hat",
"resolution": "1024x1024",
"batch_size": 1
}

@@ -0,0 +1,17 @@
{
"model": {
"name": "Flux 1 SRPO USO 12B",
"architecture": "flux",
"description": "FLUX.1 SRPO USO is a model that can Edit Images with a specialization in Style Transfers (up to two). It leverages the improved Image quality brought by the SRPO process",
"modules": [ "flux_dev_uso"],
"URLs": "flux_srpo",
"loras": "flux_dev_uso",
"flux-model": "flux-dev-uso"
},
"prompt": "the man is wearing a hat",
"embedded_guidance_scale": 4,
"resolution": "1024x1024",
"batch_size": 1
}

defaults/lucy_edit.json (new file, 19 lines)

@@ -0,0 +1,19 @@
{
"model": {
"name": "Wan2.2 Lucy Edit 5B",
"architecture": "lucy_edit",
"description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mfp16_int8.safetensors"
],
"group": "wan2_2"
},
"prompt": "change the clothes to red",
"video_length": 81,
"guidance_scale": 5,
"flow_shift": 5,
"num_inference_steps": 30,
"resolution": "1280x720"
}

@@ -0,0 +1,16 @@
{
"model": {
"name": "Wan2.2 FastWan Lucy Edit 5B",
"architecture": "lucy_edit",
"description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. This is the FastWan version for faster generation.",
"URLs": "lucy_edit",
"group": "wan2_2",
"loras": "ti2v_2_2_fastwan"
},
"prompt": "change the clothes to red",
"video_length": 81,
"guidance_scale": 1,
"flow_shift": 3,
"num_inference_steps": 5,
"resolution": "1280x720"
}

@@ -7,11 +7,10 @@
"https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_20B_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_20B_quanto_bf16_int8.safetensors"
],
"preload_URLs": ["https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_inpainting.safetensors"],
"attention": {
"<89": "sdpa"
},
"reference_image": true,
"image_outputs": true
}
},
"prompt": "add a hat",
"resolution": "1280x720",

@@ -0,0 +1,17 @@
{
"model": {
"name": "Qwen Image Edit Plus 20B",
"architecture": "qwen_image_edit_plus_20B",
"description": "Qwen Image Edit Plus is a generative model that can generate very high quality images with long texts in it. Best results will be at 720p. This model is optimized to combine multiple Subjects & Objects.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_plus_20B_quanto_bf16_int8.safetensors"
],
"preload_URLs": "qwen_image_edit_20B",
"attention": {
"<89": "sdpa"
}
},
"prompt": "add a hat",
"resolution": "1024x1024",
"batch_size": 1
}

@@ -4,7 +4,7 @@
"name": "Wan2.1 Standin 14B",
"modules": [ ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Stand-In_wan2.1_T2V_14B_ver1.0_bf16.safetensors"]],
"architecture" : "standin",
"description": "The original Wan Text 2 Video model combined with the StandIn module to improve Identity Preservation. You need to provide a Reference Image with white background which is a close up of person face to transfer this person in the Video.",
"description": "The original Wan Text 2 Video model combined with the StandIn module to improve Identity Preservation. You need to provide a Reference Image with white background which is a close up of a person face to transfer this person in the Video.",
"URLs": "t2v"
}
}

@@ -7,6 +7,7 @@
"loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"],
"group": "wan2_2"
},
"prompt" : "Put the person into a clown outfit.",
"video_length": 121,
"guidance_scale": 1,
"flow_shift": 3,

@@ -0,0 +1,24 @@
{
"model": {
"name": "Wan2.2 Vace Fun 14B",
"architecture": "vace_14B",
"description": "This is the Fun Vace 2.2 version, that is not the official Vace 2.2",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_mbf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_quanto_mbf16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_quanto_mfp16_int8.safetensors"
],
"URLs2": [
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_mbf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_quanto_mbf16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_quanto_mfp16_int8.safetensors"
],
"group": "wan2_2"
},
"guidance_phases": 2,
"num_inference_steps": 30,
"guidance_scale": 1,
"guidance2_scale": 1,
"flow_shift": 2,
"switch_threshold": 875
}

@@ -0,0 +1,28 @@
{
"model": {
"name": "Wan2.2 Vace Fun Cocktail 14B",
"architecture": "vace_14B",
"description": "This model has been created on the fly using the Wan text 2.2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. This is the Fun Vace 2.2, that is not the official Vace 2.2",
"URLs": "vace_fun_14B_2_2",
"URLs2": "vace_fun_14B_2_2",
"loras": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors"
],
"loras_multipliers": [
1,
0.2,
0.5,
0.5
],
"group": "wan2_2"
},
"guidance_phases": 2,
"num_inference_steps": 10,
"guidance_scale": 1,
"guidance2_scale": 1,
"flow_shift": 2,
"switch_threshold": 875
}

docs/AMD-INSTALLATION.md (new file, 146 lines)

@@ -0,0 +1,146 @@
# Installation Guide
This guide covers installation for specific RDNA3 and RDNA3.5 AMD CPUs (APUs) and GPUs
running under Windows.
tl;dr: Radeon RX 7900 GOOD, RX 9700 BAD, RX 6800 BAD. (I know, life isn't fair).
Currently supported (but not necessarily tested):
**gfx110x**:
* Radeon RX 7600
* Radeon RX 7700 XT
* Radeon RX 7800 XT
* Radeon RX 7900 GRE
* Radeon RX 7900 XT
* Radeon RX 7900 XTX
**gfx1151**:
* Ryzen 7000 series APUs (Phoenix)
* Ryzen Z1 (e.g., handheld devices like the ROG Ally)
**gfx1201**:
* Ryzen 8000 series APUs (Strix Point)
* A [frame.work](https://frame.work/au/en/desktop) desktop/laptop
## Requirements
- Python 3.11 (3.12 might work, 3.10 definitely will not!)
## Installation Environment
This installation uses PyTorch 2.7.0 because that's what is currently available in terms of pre-compiled wheels.
### Installing Python
Download Python 3.11 from [python.org/downloads/windows](https://www.python.org/downloads/windows/). Hit Ctrl+F and search for "3.11". Don't use this direct link: [https://www.python.org/ftp/python/3.11.9/python-3.11.9-amd64.exe](https://www.python.org/ftp/python/3.11.9/python-3.11.9-amd64.exe) -- that was an IQ test.
After installing, make sure `python --version` works in your terminal and returns 3.11.x
If not, you probably need to fix your PATH. Go to:
* Windows + Pause/Break
* Advanced System Settings
* Environment Variables
* Edit your `Path` under User Variables
Example correct entries:
```cmd
C:\Users\YOURNAME\AppData\Local\Programs\Python\Launcher\
C:\Users\YOURNAME\AppData\Local\Programs\Python\Python311\Scripts\
C:\Users\YOURNAME\AppData\Local\Programs\Python\Python311\
```
If that doesn't work, scream into a bucket.
### Installing Git
Get Git from [git-scm.com/downloads/win](https://git-scm.com/downloads/win). Default install is fine.
## Install (Windows, using `venv`)
### Step 1: Download and Set Up Environment
```cmd
:: Navigate to your desired install directory
cd \your-path-to-wan2gp
:: Clone the repository
git clone https://github.com/deepbeepmeep/Wan2GP.git
cd Wan2GP
:: Create virtual environment using Python 3.11
python -m venv wan2gp-env
:: Activate the virtual environment
wan2gp-env\Scripts\activate
```
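Optional sanity check before moving on (a minimal sketch using standard Windows and Python commands, nothing WanGP-specific): the interpreter reported inside the venv should be 3.11.x.

```cmd
:: Confirm the activated venv resolves to its own Python 3.11 interpreter
where python
python --version
```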
### Step 2: Install PyTorch
The pre-compiled wheels you need are hosted at [scottt's rocm-TheRock releases](https://github.com/scottt/rocm-TheRock/releases). Find the heading that says:
**Pytorch wheels for gfx110x, gfx1151, and gfx1201**
Don't click this link: [https://github.com/scottt/rocm-TheRock/releases/tag/v6.5.0rc-pytorch-gfx110x](https://github.com/scottt/rocm-TheRock/releases/tag/v6.5.0rc-pytorch-gfx110x). It's just here to check if you're skimming.
Copy the links of the closest binaries to the ones in the example below (adjust if you're not running Python 3.11), then hit enter.
```cmd
pip install ^
https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torch-2.7.0a0+rocm_git3f903c3-cp311-cp311-win_amd64.whl ^
https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchaudio-2.7.0a0+52638ef-cp311-cp311-win_amd64.whl ^
https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchvision-0.22.0+9eb57cd-cp311-cp311-win_amd64.whl
```
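Optionally verify that the wheels load. This is a sketch based on the general behaviour of PyTorch ROCm builds, which expose the GPU through the regular `torch.cuda` API; exact version strings will differ on your machine.

```cmd
:: Quick import check: prints the torch version and whether a ROCm device is visible
python -c "import torch; print(torch.__version__); print(torch.cuda.is_available(), torch.cuda.device_count())"
```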
### Step 3: Install Dependencies
```cmd
:: Install core dependencies
pip install -r requirements.txt
```
## Attention Modes
WanGP supports several attention implementations, only one of which will work for you:
- **SDPA** (default): Available by default with PyTorch. This uses the built-in aotriton accel library, so is actually pretty fast.
## Performance Profiles
Choose a profile based on your hardware:
- **Profile 3 (LowRAM_HighVRAM)**: Loads entire model in VRAM, requires 24GB VRAM for 8-bit quantized 14B model
- **Profile 4 (LowRAM_LowVRAM)**: Default, loads model parts as needed, slower but lower VRAM requirement
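If you prefer to pin a profile at launch instead of changing it in the Configuration tab, here is a minimal sketch assuming the `--profile` switch from WanGP's command-line options (check `python wgp.py --help` if unsure):

```cmd
:: Example only: start WanGP with the low RAM / low VRAM profile (assumes the --profile flag)
python wgp.py --profile 4
```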
## Running Wan2GP
In future, you will have to do this:
```cmd
cd \path-to\wan2gp
wan2gp-env\Scripts\activate.bat
python wgp.py
```
For now, you should just be able to type `python wgp.py` (because you're already in the virtual environment)
## Troubleshooting
- If you use a HIGH VRAM mode, don't be a fool. Make sure you use VAE Tiled Decoding.
### Memory Issues
- Use lower resolution or shorter videos
- Enable quantization (default)
- Use Profile 4 for lower VRAM usage
- Consider using 1.3B models instead of 14B models
For more troubleshooting, see [TROUBLESHOOTING.md](TROUBLESHOOTING.md)

View File

@ -1,20 +1,154 @@
# Changelog
## 🔥 Latest News
### July 21 2025: WanGP v7.1
### August 29 2025: WanGP v8.21 - Here Goes Your Weekend
- **InfiniteTalk Video to Video**: this feature can be used for Video Dubbing. Keep in mind that it is a *Sparse Video to Video*, that is, internally only one image is used per Sliding Window. However, thanks to the new *Smooth Transition* mode, each new clip is connected to the previous one and all the camera work is done by InfiniteTalk. If you don't get any transition, increase the number of frames of a Sliding Window (81 frames recommended)
- **StandIn**: a very light model specialized in Identity Transfer. I have provided two versions of StandIn: a basic one derived from the text 2 video model and another based on Vace. If used with Vace, the last reference frame given to Vace will also be used for StandIn
- **Flux USO**: a new Flux-derived *Image Editing tool*, but this one is specialized both in *Identity Transfer* and *Style Transfer*. Style has to be understood in its widest sense: give a reference picture of a person and another one of sushi and you will turn this person into sushi
### August 24 2025: WanGP v8.1 - the RAM Liberator
- **Reserved RAM entirely freed when switching models**: you should get far fewer RAM-related out-of-memory errors. I have also added a button in *Configuration / Performance* that will release most of the RAM used by WanGP if you want to use another application without quitting WanGP
- **InfiniteTalk** support: an improved version of Multitalk that supposedly supports very long video generations based on an audio track. It exists in two flavors (*Single Speaker* and *Multi Speakers*) but doesn't seem to be compatible with Vace. One key new feature compared to Multitalk is that you can have different visual shots associated with the same audio: each Reference Frame you provide will be associated with a new Sliding Window. If only one Reference Frame is provided, it will be used for all windows. When Continuing a video, you can either continue the current shot (no Reference Frame) or add new shots (one or more Reference Frames).\
If you are not into audio, you can still use this model to generate infinitely long image2video, just select "no speaker". Last but not least, InfiniteTalk works with all the Loras accelerators.
- **Flux Chroma 1 HD** support: an uncensored Flux-based model, lighter than Flux (8.9B versus 12B), that can fit entirely in VRAM with only 16 GB of VRAM. Unfortunately it is not distilled and you will need CFG and a minimum of 20 steps
### August 21 2025: WanGP v8.01 - the killer of seven
- **Qwen Image Edit** : a Flux Kontext challenger (prompt-driven image editing). Best results (including Identity preservation) will be obtained at 720p. Beyond that you may get image outpainting and / or lose identity preservation. Below 720p prompt adherence will be worse. Qwen Image Edit works with Qwen Lora Lightning 4 steps. I have also unlocked all the resolutions for Qwen models. Bonus Zone: support for multiple image compositions, but identity preservation won't be as good.
- **On demand Prompt Enhancer** (needs to be enabled in Configuration Tab) that you can use to Enhance a Text Prompt before starting a Generation. You can refine the Enhanced Prompt or change the original Prompt.
- Choice of a **Non censored Prompt Enhancer**. Beware, this one is VRAM hungry and will require 12 GB of VRAM to work
- **Memory Profile customizable per model** : useful for instance to set Profile 3 (preload the model entirely in VRAM) with only Image Generation models, if you have 24 GB of VRAM. In that case Generation will be much faster because with Image generators (contrary to Video generators) a lot of time is otherwise wasted in offloading
- **Expert Guidance Mode**: change the Guidance during the generation up to 2 times. Very useful with Wan 2.2 Lightning to reduce the slow motion effect. The idea is to insert a CFG phase before the 2 accelerated phases that follow and have no Guidance. I have added the finetune *Wan2.2 Vace Lightning 3 Phases 14B* with a prebuilt configuration. Please note that it is an 8-step process although the Lightning lora is 4 steps. This expert guidance mode is also available with Wan 2.1.
*WanGP 8.01 update, improved Qwen Image Edit Identity Preservation*
### August 12 2025: WanGP v7.7777 - Lucky Day(s)
This is your lucky day ! Thanks to new configuration options that let you store generated Videos and Images in lossless compressed formats, you will find that they in fact look twice as good without doing anything !
Just kidding, they will be only marginally better, but at least this opens the way to professional editing.
Support:
- Video: x264, x264 lossless, x265
- Images: jpeg, png, webp, webp lossless
Generation Settings are stored in each of the above regardless of the format (that was the hard part).
Also you can now choose different output directories for images and videos.
Unexpected luck: fixed Lightning 8 steps for Qwen, and Lightning 4 steps for Wan 2.2; now you just need a 1x multiplier, no weird numbers.
*update 7.777 : oops, got a crash with FastWan ? Luck comes and goes, try a new update, maybe you will have a better chance this time*
*update 7.7777 : Sometimes good luck seems to last forever. For instance what if Qwen Lightning 4 steps could also work with WanGP ?*
- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors (Qwen Lightning 4 steps)
- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors (new improved version of Qwen Lightning 8 steps)
### August 10 2025: WanGP v7.76 - Faster than the VAE ...
We have a funny one here today: FastWan 2.2 5B, the Fastest Video Generator, only 20s to generate 121 frames at 720p. The snag is that VAE is twice as slow...
Thanks to Kijai for extracting the Lora that is used to build the corresponding finetune.
*WanGP 7.76: fixed the mess I made of the i2v models (the loras path was wrong for Wan2.2 and Clip was broken)*
### August 9 2025: WanGP v7.74 - Qwen Rebirth part 2
Added support for the Qwen Lightning lora for an 8-step generation (https://huggingface.co/lightx2v/Qwen-Image-Lightning/blob/main/Qwen-Image-Lightning-8steps-V1.0.safetensors). The Lora is not normalized and you can use a multiplier around 0.1.
Mag Cache support for all the Wan2.2 models. Don't forget to set guidance to 1 and 8 denoising steps; your gen will be 7x faster !
### August 8 2025: WanGP v7.73 - Qwen Rebirth
Ever wondered what impact not using Guidance has on a model that expects it ? Just look at Qwen Image in WanGP 7.71, whose outputs were erratic. Somehow I had convinced myself that Qwen was a distilled model. In fact Qwen was dying for a negative prompt. And in WanGP 7.72 there is at last one for him.
As Qwen is not so picky after all, I have also added a quantized text encoder which reduces the RAM requirements of Qwen by 10 GB (the quantized text encoder produced garbage before)
Unfortunately the Sage bug is still there for older GPU architectures. I have added an Sdpa fallback for these architectures.
*7.73 update: there is still a Sage / Sage2 bug for GPUs before RTX40xx. I have added a detection mechanism that forces Sdpa attention if that's the case*
### August 6 2025: WanGP v7.71 - Picky, picky
This release comes with two new models:
- Qwen Image: a commercial grade Image generator capable of injecting full sentences into the generated Image while still offering incredible visuals
- Wan 2.2 TextImage to Video 5B: the last Wan 2.2 model needed if you want to complete your Wan 2.2 collection (loras for this model can be stored in "\loras\5B")
There is a catch though: they are very picky if you want to get good generations. First, they both need lots of steps (50?) to show what they have to offer. Then, for Qwen Image I had to hardcode the supported resolutions, because if you try anything else you will get garbage. Likewise, Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p.
*7.71 update: Added VAE Tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, for low VRAM usage during the whole generation.*
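VAE Tiling trades a little speed for a much lower VRAM peak by decoding the latent in tiles and blending the seams. The sketch below is a generic illustration of the idea only, not WanGP's implementation; the `vae_decode` callable, the tile sizes and the 8x upscale factor are assumptions.

```python
import torch

def decode_in_tiles(vae_decode, latents, tile=32, overlap=8):
    """Decode a latent tensor [B, C, H, W] tile by tile to cap peak VRAM.
    vae_decode maps a latent tile to pixels with an 8x spatial upscale
    (typical of SD-style VAEs); overlapping borders are simply averaged."""
    b, c, h, w = latents.shape
    scale = 8
    out = torch.zeros(b, 3, h * scale, w * scale)
    weight = torch.zeros_like(out)
    step = tile - overlap
    for y in range(0, h, step):
        for x in range(0, w, step):
            y2, x2 = min(y + tile, h), min(x + tile, w)
            out[:, :, y*scale:y2*scale, x*scale:x2*scale] += vae_decode(latents[:, :, y:y2, x:x2])
            weight[:, :, y*scale:y2*scale, x*scale:x2*scale] += 1
    return out / weight.clamp(min=1)
```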
### August 4 2025: WanGP v7.6 - Remuxed
With this new version you won't have any excuse if there is no sound in your video.
*Continue Video* now works with any video that already has some sound (hint: Multitalk).
Also, on top of MMaudio and the various sound driven models I have added the ability to use your own soundtrack.
As a result you can apply a different sound source on each new video segment when doing a *Continue Video*.
For instance:
- first video part: use Multitalk with two people speaking
- second video part: you apply your own soundtrack which will gently follow the multitalk conversation
- third video part: you use a Vace effect and its corresponding control audio will be concatenated to the rest of the audio
To multiply the combinations I have also implemented *Continue Video* with the various image2video models.
Also:
- End Frame support added for LTX Video models
- Loras can now be targeted specifically at the High noise or Low noise models with Wan 2.2, check the Loras and Finetune guides
- Flux Krea Dev support
### July 30 2025: WanGP v7.5: Just another release ... Wan 2.2 part 2
Here is now Wan 2.2 image2video, a very good model if you want to set Start and End frames. Two Wan 2.2 models delivered, only one to go...
Please note that although it is an image2video model it is structurally very close to Wan 2.2 text2video (same layers with only a different initial projection). Given that Wan 2.1 image2video loras don't work too well (half of their tensors are not supported), I have decided that this model will look for its loras in the text2video loras folder instead of the image2video folder.
I have also optimized RAM management with Wan 2.2 so that loras and modules will be loaded only once in RAM and Reserved RAM, this saves up to 5 GB of RAM which can make a difference...
And this time I really removed Vace Cocktail Light, which gave blurry results.
### July 29 2025: WanGP v7.4: Just another release ... Wan 2.2 Preview
Wan 2.2 is here. The good news is that WanGP won't require a single byte of extra VRAM to run it and it will be as fast as Wan 2.1. The bad news is that you will need much more RAM if you want to leverage this new model entirely, since it has twice as many parameters.
So here is a preview version of Wan 2.2, for the moment without the 5B model and Wan 2.2 image to video.
However, as I felt bad delivering only half of the wares, I gave you instead..... **Wan 2.2 Vace Experimental Cocktail**!
Very good surprise indeed: the loras and Vace partially work with Wan 2.2. We will need to wait for the official Vace 2.2 release since some Vace features are broken, like identity preservation.
Bonus zone: Flux multi image conditions have been added, or maybe not if I broke everything as I have been distracted by Wan...
7.4 update: I forgot to update the version number. I also removed Vace Cocktail Light, which didn't work well.
### July 27 2025: WanGP v7.3 : Interlude
While waiting for Wan 2.2, you will appreciate the model selection hierarchy, which is very useful for collecting even more models. You will also appreciate that WanGP remembers which model you used last in each model family.
### July 26 2025: WanGP v7.2 : Ode to Vace
I am really convinced that Vace can do everything the other models can do, and in a better way, especially as Vace can be combined with Multitalk.
Here are some new Vace improvements:
- I have provided a default finetune named *Vace Cocktail*, which is a model created on the fly using the Wan text 2 video model and the Loras used to build FusioniX. The weight of the *Detail Enhancer* Lora has been reduced to improve identity preservation. Copy the model definition in *defaults/vace_14B_cocktail.json* into the *finetunes/* folder to change the Cocktail composition. Cocktail already contains some Lora accelerators, so there is no need to add an Accvid, Causvid or Fusionix Lora on top. The whole point of Cocktail is to be able to build your own FusioniX (which originally is a combination of 4 loras) but without the inconveniences of FusioniX.
- Talking about identity preservation, it tends to go away when one generates a single Frame instead of a Video, which is a shame for our Vace photoshop. But there is a solution: I have added an Advanced Quality option that tells WanGP to generate a little more than one frame (it will still keep only the first frame). It will be a little slower, but you will be amazed how well Vace Cocktail combined with this option preserves identities (bye bye *Phantom*).
- As in practice one switches frequently between *Vace text2video* and *Vace text2image*, I have put them in the same place: they are now just one tab away, no need to reload the model. Likewise *Wan text2video* and *Wan text2image* have been merged.
- Color fixing when using Sliding Windows: a new *Color Correction* postprocessing step, applied automatically by default (you can disable it in the *Advanced tab Sliding Window*), will try to match the colors of the new window with those of the previous window. It doesn't fix all the unwanted artifacts of the new window, but at least it makes the transition smoother. Thanks to the Multitalk team for the original code. A minimal sketch of the color-matching idea is shown right after this list.
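For the record, this kind of color correction boils down to matching simple per-channel statistics between windows. The sketch below is a generic mean/std matching illustration, not the actual Multitalk-derived code used by WanGP.

```python
import torch

def match_colors(new_frames, ref_frames, eps=1e-6):
    """Shift new_frames so its per-channel mean/std match ref_frames.
    Both tensors are [T, C, H, W] in any consistent value range; this is a
    simple stand-in for the sliding-window color correction idea."""
    ref_mean = ref_frames.mean(dim=(0, 2, 3), keepdim=True)
    ref_std = ref_frames.std(dim=(0, 2, 3), keepdim=True)
    new_mean = new_frames.mean(dim=(0, 2, 3), keepdim=True)
    new_std = new_frames.std(dim=(0, 2, 3), keepdim=True)
    return (new_frames - new_mean) / (new_std + eps) * ref_std + ref_mean
```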
Also you will enjoy our new real time statistics (CPU / GPU usage, RAM / VRAM used, ... ). Many thanks to **Redtash1** for providing the framework for this new feature ! You need to go in the Config tab to enable real time stats.
### July 21 2025: WanGP v7.12
- Flux Family Reunion: *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate, Lora support for the Flux *diffusers* format has also been added.
- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video!) in one go without a sliding window. With the distilled model it will take only 5 minutes with an RTX 4090 (you will need 22 GB of VRAM though). I have added options to select a higher number of frames if you want to experiment (go to Configuration Tab / General / Increase the Max Number of Frames, change the value and restart the App)
- LTX Video ControlNet: a ControlNet that allows you, for instance, to transfer human motion or depth from a control video. It is not as powerful as Vace but can produce interesting things, especially as you can now quickly generate a 1 min video. Under the hood, IC-Loras (see below) for Pose, Depth and Canny are automatically loaded for you, no need to add them.
- LTX IC-Lora support: these are special Loras that consume a conditional image or video.
Besides the pose, depth and canny IC-Loras loaded transparently, there is the *detailer* (https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8) which is basically an upsampler. Add the *detailer* as a Lora and choose LTX Raw Format as the control net choice to use it.
And Also:
- easier way to select video resolution
- started to optimize Matanyone to reduce VRAM requirements
- Matanyone is now also for the GPU Poor as its VRAM requirements have been divided by 2! (7.12 shadow update)
### July 15 2025: WanGP v7.0 is an AI Powered Photoshop
This release turns the Wan models into Image Generators. It goes way beyond just allowing you to generate a video made of a single frame:

View File

@ -27,7 +27,7 @@ conda activate wan2gp
### Step 2: Install PyTorch
```shell
# Install PyTorch 2.7.0 with CUDA 12.4
# Install PyTorch 2.7.0 with CUDA 12.8
pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
```

View File

@ -13,27 +13,52 @@ class family_handler():
flux_schnell = flux_model == "flux-schnell"
flux_chroma = flux_model == "flux-chroma"
flux_uso = flux_model == "flux-dev-uso"
model_def_output = {
flux_umo = flux_model == "flux-dev-umo"
flux_kontext = flux_model == "flux-dev-kontext"
extra_model_def = {
"image_outputs" : True,
"no_negative_prompt" : not flux_chroma,
}
if flux_chroma:
model_def_output["guidance_max_phases"] = 1
extra_model_def["guidance_max_phases"] = 1
elif not flux_schnell:
model_def_output["embedded_guidance"] = True
extra_model_def["embedded_guidance"] = True
if flux_uso :
model_def_output["any_image_refs_relative_size"] = True
model_def_output["no_background_removal"] = True
model_def_output["image_ref_choices"] = {
"choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "I"),
("Up to two Images are Style Images", "IJ")],
"default": "I",
"letters_filter": "IJ",
extra_model_def["any_image_refs_relative_size"] = True
extra_model_def["no_background_removal"] = True
extra_model_def["image_ref_choices"] = {
"choices":[("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"),
("Up to two Images are Style Images", "KIJ")],
"default": "KI",
"letters_filter": "KIJ",
"label": "Reference Images / Style Images"
}
if flux_kontext:
extra_model_def["inpaint_support"] = True
extra_model_def["image_ref_choices"] = {
"choices": [
("None", ""),
("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"),
("Conditional Images are People / Objects", "I"),
],
"letters_filter": "KI",
}
extra_model_def["background_removal_label"]= "Remove Backgrounds only behind People / Objects except main Subject / Landscape"
elif flux_umo:
extra_model_def["image_ref_choices"] = {
"choices": [
("Conditional Images are People / Objects", "I"),
],
"letters_filter": "I",
"visible": False
}
return model_def_output
extra_model_def["fit_into_canvas_image_refs"] = 0
return extra_model_def
@staticmethod
def query_supported_types():
@ -82,7 +107,7 @@ class family_handler():
]
@staticmethod
def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None):
from .flux_main import model_factory
flux_model = model_factory(
@ -107,15 +132,38 @@ class family_handler():
pipe["feature_embedder"] = flux_model.feature_embedder
return flux_model, pipe
@staticmethod
def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
flux_model = model_def.get("flux-model", "flux-dev")
flux_uso = flux_model == "flux-dev-uso"
if flux_uso and settings_version < 2.29:
video_prompt_type = ui_defaults.get("video_prompt_type", "")
if "I" in video_prompt_type:
video_prompt_type = video_prompt_type.replace("I", "KI")
ui_defaults["video_prompt_type"] = video_prompt_type
if settings_version < 2.34:
ui_defaults["denoising_strength"] = 1.
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
flux_model = model_def.get("flux-model", "flux-dev")
flux_uso = flux_model == "flux-dev-uso"
flux_umo = flux_model == "flux-dev-umo"
flux_kontext = flux_model == "flux-dev-kontext"
ui_defaults.update({
"embedded_guidance": 2.5,
})
if model_def.get("reference_image", False):
ui_defaults.update({
"video_prompt_type": "I" if flux_uso else "KI",
})
})
if flux_kontext or flux_uso:
ui_defaults.update({
"video_prompt_type": "KI",
"denoising_strength": 1.,
})
elif flux_umo:
ui_defaults.update({
"video_prompt_type": "I",
"remove_background_images_ref": 0,
})

View File

@ -9,6 +9,9 @@ from shared.utils.utils import calculate_new_dimensions
from .sampling import denoise, get_schedule, prepare_kontext, prepare_prompt, prepare_multi_ip, unpack
from .modules.layers import get_linear_split_map
from transformers import SiglipVisionModel, SiglipImageProcessor
import torchvision.transforms.functional as TVF
import math
from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image
from .util import (
aspect_ratio_to_height_width,
@ -20,6 +23,35 @@ from .util import (
)
from PIL import Image
def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
    # get the original image's width and height
image_w, image_h = raw_image.size
    # compute the long and short sides
if image_w >= image_h:
new_w = long_size
new_h = int((long_size / image_w) * image_h)
else:
new_h = long_size
new_w = int((long_size / image_h) * image_w)
    # resize proportionally to the new width and height
raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
target_w = new_w // 16 * 16
target_h = new_h // 16 * 16
    # compute the crop origin for a center crop
left = (new_w - target_w) // 2
top = (new_h - target_h) // 2
right = left + target_w
bottom = top + target_h
    # perform the center crop
raw_image = raw_image.crop((left, top, right, bottom))
    # convert to RGB mode
raw_image = raw_image.convert("RGB")
return raw_image
def stitch_images(img1, img2):
# Resize img2 to match img1's height
@ -64,7 +96,7 @@ class model_factory:
# self.name= "flux-schnell"
source = model_def.get("source", None)
self.model = load_flow_model(self.name, model_filename[0] if source is None else source, torch_device)
self.model_def = model_def
self.vae = load_ae(self.name, device=torch_device)
siglip_processor = siglip_model = feature_embedder = None
@ -106,10 +138,12 @@ class model_factory:
def generate(
self,
seed: int | None = None,
input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
n_prompt: str = None,
sampling_steps: int = 20,
input_ref_images = None,
input_frames= None,
input_masks= None,
width= 832,
height=480,
embedded_guidance_scale: float = 2.5,
@ -120,7 +154,8 @@ class model_factory:
batch_size = 1,
video_prompt_type = "",
joint_pass = False,
image_refs_relative_size = 100,
image_refs_relative_size = 100,
denoising_strength = 1.,
**bbargs
):
if self._interrupt:
@ -129,11 +164,17 @@ class model_factory:
if n_prompt is None or len(n_prompt) == 0: n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
device="cuda"
flux_dev_uso = self.name in ['flux-dev-uso']
image_stiching = not self.name in ['flux-dev-uso']
flux_dev_umo = self.name in ['flux-dev-umo']
latent_stiching = self.name in ['flux-dev-uso', 'flux-dev-umo']
lock_dimensions= False
input_ref_images = [] if input_ref_images is None else input_ref_images[:]
if flux_dev_umo:
ref_long_side = 512 if len(input_ref_images) <= 1 else 320
input_ref_images = [preprocess_ref(img, ref_long_side) for img in input_ref_images]
lock_dimensions = True
ref_style_imgs = []
if "I" in video_prompt_type and len(input_ref_images) > 0:
if flux_dev_uso :
if "J" in video_prompt_type:
@ -142,33 +183,28 @@ class model_factory:
elif len(input_ref_images) > 1 :
ref_style_imgs = input_ref_images[-1:]
input_ref_images = input_ref_images[:-1]
if image_stiching:
if latent_stiching:
# latents stiching with resize
if not lock_dimensions :
for i in range(len(input_ref_images)):
w, h = input_ref_images[i].size
image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, 0)
input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
else:
# image stiching method
stiched = input_ref_images[0]
if "K" in video_prompt_type :
w, h = input_ref_images[0].size
height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
for new_img in input_ref_images[1:]:
stiched = stitch_images(stiched, new_img)
input_ref_images = [stiched]
else:
first_ref = 0
if "K" in video_prompt_type:
# image latents tiling method
w, h = input_ref_images[0].size
height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS)
first_ref = 1
for i in range(first_ref,len(input_ref_images)):
w, h = input_ref_images[i].size
image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
input_ref_images[0] = input_ref_images[0].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
elif input_frames is not None:
input_ref_images = [convert_tensor_to_image(input_frames) ]
else:
input_ref_images = None
image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True)
if flux_dev_uso :
if self.name in ['flux-dev-uso', 'flux-dev-umo'] :
inp, height, width = prepare_multi_ip(
ae=self.vae,
img_cond_list=input_ref_images,
@ -187,6 +223,7 @@ class model_factory:
bs=batch_size,
seed=seed,
device=device,
img_mask=image_mask,
)
inp.update(prepare_prompt(self.t5, self.clip, batch_size, input_prompt))
@ -208,13 +245,19 @@ class model_factory:
return unpack(x.float(), height, width)
# denoise initial noise
x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass)
x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass, denoising_strength = denoising_strength)
if x==None: return None
# decode latents to pixel space
x = unpack_latent(x)
with torch.autocast(device_type=device, dtype=torch.bfloat16):
x = self.vae.decode(x)
if image_mask is not None:
from shared.utils.utils import convert_image_to_tensor
img_msk_rebuilt = inp["img_msk_rebuilt"]
img= input_frames.squeeze(1).unsqueeze(0) # convert_image_to_tensor(image_guide)
x = img * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt
x = x.clamp(-1, 1)
x = x.transpose(0, 1)
return x

View File

@ -190,6 +190,21 @@ class Flux(nn.Module):
v = swap_scale_shift(v)
k = k.replace("norm_out.linear", "final_layer.adaLN_modulation.1")
new_sd[k] = v
# elif not first_key.startswith("diffusion_model.") and not first_key.startswith("transformer."):
# for k,v in sd.items():
# if "double" in k:
# k = k.replace(".processor.proj_lora1.", ".img_attn.proj.lora_")
# k = k.replace(".processor.proj_lora2.", ".txt_attn.proj.lora_")
# k = k.replace(".processor.qkv_lora1.", ".img_attn.qkv.lora_")
# k = k.replace(".processor.qkv_lora2.", ".txt_attn.qkv.lora_")
# else:
# k = k.replace(".processor.qkv_lora.", ".linear1_qkv.lora_")
# k = k.replace(".processor.proj_lora.", ".linear2.lora_")
# k = "diffusion_model." + k
# new_sd[k] = v
# from mmgp import safetensors2
# safetensors2.torch_write_file(new_sd, "fff.safetensors")
else:
new_sd = sd
return new_sd

View File

@ -138,10 +138,12 @@ def prepare_kontext(
target_width: int | None = None,
target_height: int | None = None,
bs: int = 1,
img_mask = None,
) -> tuple[dict[str, Tensor], int, int]:
# load and encode the conditioning image
res_match_output = img_mask is not None
img_cond_seq = None
img_cond_seq_ids = None
if img_cond_list == None: img_cond_list = []
@ -150,10 +152,11 @@ def prepare_kontext(
for cond_no, img_cond in enumerate(img_cond_list):
width, height = img_cond.size
aspect_ratio = width / height
# Kontext is trained on specific resolutions, using one of them is recommended
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
if res_match_output:
width, height = target_width, target_height
else:
# Kontext is trained on specific resolutions, using one of them is recommended
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
width = 2 * int(width / 16)
height = 2 * int(height / 16)
@ -194,6 +197,19 @@ def prepare_kontext(
"img_cond_seq": img_cond_seq,
"img_cond_seq_ids": img_cond_seq_ids,
}
if img_mask is not None:
from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image
# image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
image_mask_latents = convert_image_to_tensor(img_mask.resize((target_width // 16, target_height // 16), resample=Image.Resampling.LANCZOS))
image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0)
# convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device)
return_dict.update({
"img_msk_latents": image_mask_latents,
"img_msk_rebuilt": image_mask_rebuilt,
})
img = get_noise(
bs,
target_height,
@ -265,6 +281,9 @@ def denoise(
loras_slists=None,
unpack_latent = None,
joint_pass= False,
img_msk_latents = None,
img_msk_rebuilt = None,
denoising_strength = 1,
):
kwargs = {'pipeline': pipeline, 'callback': callback, "img_len" : img.shape[1], "siglip_embedding": siglip_embedding, "siglip_embedding_ids": siglip_embedding_ids}
@ -272,19 +291,38 @@ def denoise(
if callback != None:
callback(-1, None, True)
original_image_latents = None if img_cond_seq is None else img_cond_seq.clone()
original_timesteps = timesteps
morph, first_step = False, 0
if img_msk_latents is not None:
randn = torch.randn_like(original_image_latents)
if denoising_strength < 1.:
first_step = int(len(timesteps) * (1. - denoising_strength))
if not morph:
latent_noise_factor = timesteps[first_step]
latents = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor
img = latents.to(img)
latents = None
timesteps = timesteps[first_step:]
updated_num_steps= len(timesteps) -1
if callback != None:
from shared.utils.loras_mutipliers import update_loras_slists
update_loras_slists(model, loras_slists, updated_num_steps)
update_loras_slists(model, loras_slists, len(original_timesteps))
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
from mmgp import offload
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
offload.set_step_no_for_lora(model, i)
offload.set_step_no_for_lora(model, first_step + i)
if pipeline._interrupt:
return None
if img_msk_latents is not None and denoising_strength <1. and i == first_step and morph:
latent_noise_factor = t_curr/1000
img = original_image_latents * (1.0 - latent_noise_factor) + img * latent_noise_factor
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
img_input = img
img_input_ids = img_ids
@ -334,6 +372,14 @@ def denoise(
pred = neg_pred + real_guidance_scale * (pred - neg_pred)
img += (t_prev - t_curr) * pred
if img_msk_latents is not None:
latent_noise_factor = t_prev
# noisy_image = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
noisy_image = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor
img = noisy_image * (1-img_msk_latents) + img_msk_latents * img
noisy_image = None
if callback is not None:
preview = unpack_latent(img).transpose(0,1)
callback(i, preview, False)

View File

@ -640,6 +640,38 @@ configs = {
shift_factor=0.1159,
),
),
"flux-dev-umo": ModelSpec(
repo_id="",
repo_flow="",
repo_ae="ckpts/flux_vae.safetensors",
params=FluxParams(
in_channels=64,
out_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
eso= True,
),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
}

View File

@ -861,16 +861,11 @@ class HunyuanVideoSampler(Inference):
freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_RIFLEx)
else:
if self.avatar:
w, h = input_ref_images.size
target_height, target_width = calculate_new_dimensions(target_height, target_width, h, w, fit_into_canvas)
if target_width != w or target_height != h:
input_ref_images = input_ref_images.resize((target_width,target_height), resample=Image.Resampling.LANCZOS)
concat_dict = {'mode': 'timecat', 'bias': -1}
freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
else:
if input_frames != None:
target_height, target_width = input_frames.shape[-3:-1]
target_height, target_width = input_frames.shape[-2:]
elif input_video != None:
target_height, target_width = input_video.shape[-2:]
@ -899,9 +894,10 @@ class HunyuanVideoSampler(Inference):
pixel_value_bg = input_video.unsqueeze(0)
pixel_value_mask = torch.zeros_like(input_video).unsqueeze(0)
if input_frames != None:
pixel_value_video_bg = input_frames.permute(-1,0,1,2).unsqueeze(0).float()
pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float()
pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
pixel_value_video_bg = input_frames.unsqueeze(0) #.permute(-1,0,1,2).unsqueeze(0).float()
# pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
# pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float()
pixel_value_video_mask = input_masks.repeat(3,1,1,1).unsqueeze(0)
if input_video != None:
pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2)
pixel_value_mask = torch.cat([ pixel_value_mask, pixel_value_video_mask], dim=2)
@ -913,10 +909,11 @@ class HunyuanVideoSampler(Inference):
if pixel_value_bg.shape[2] < frame_num:
padding_shape = list(pixel_value_bg.shape[0:2]) + [frame_num-pixel_value_bg.shape[2]] + list(pixel_value_bg.shape[3:])
pixel_value_bg = torch.cat([pixel_value_bg, torch.full(padding_shape, -1, dtype=pixel_value_bg.dtype, device= pixel_value_bg.device ) ], dim=2)
pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
# pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 1, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
bg_latents = self.vae.encode(pixel_value_bg).latent_dist.sample()
pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.)
pixel_value_mask = pixel_value_mask.mul_(2).add_(-1.) # unmasked pixels is -1 (no 0 as usual) and masked is 1
mask_latents = self.vae.encode(pixel_value_mask).latent_dist.sample()
bg_latents = torch.cat([bg_latents, mask_latents], dim=1)
bg_latents.mul_(self.vae.config.scaling_factor)

View File

@ -51,11 +51,38 @@ class family_handler():
extra_model_def["tea_cache"] = True
extra_model_def["mag_cache"] = True
if base_model_type in ["hunyuan_avatar"]: extra_model_def["no_background_removal"] = True
if base_model_type in ["hunyuan_custom_edit"]:
extra_model_def["guide_preprocessing"] = {
"selection": ["MV", "PV"],
}
if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_audio", "hunyuan_avatar"]:
extra_model_def["mask_preprocessing"] = {
"selection": ["A", "NA"],
"default" : "NA"
}
if base_model_type in ["hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_custom"]:
extra_model_def["image_ref_choices"] = {
"choices": [("Reference Image", "I")],
"letters_filter":"I",
"visible": False,
}
if base_model_type in ["hunyuan_avatar"]:
extra_model_def["image_ref_choices"] = {
"choices": [("Start Image", "KI")],
"letters_filter":"KI",
"visible": False,
}
extra_model_def["no_background_removal"] = True
if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]:
extra_model_def["one_image_ref_needed"] = True
if base_model_type in ["hunyuan_i2v"]:
extra_model_def["image_prompt_types_allowed"] = "S"
return extra_model_def
@staticmethod
@ -102,7 +129,7 @@ class family_handler():
}
@staticmethod
def load_model(model_filename, model_type = None, base_model_type = None, model_def = None, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
def load_model(model_filename, model_type = None, base_model_type = None, model_def = None, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None):
from .hunyuan import HunyuanVideoSampler
from mmgp import offload
@ -137,6 +164,24 @@ class family_handler():
return hunyuan_model, pipe
@staticmethod
def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
if settings_version<2.33:
if base_model_type in ["hunyuan_custom_edit"]:
video_prompt_type= ui_defaults["video_prompt_type"]
if "P" in video_prompt_type and "M" in video_prompt_type:
video_prompt_type = video_prompt_type.replace("M","")
ui_defaults["video_prompt_type"] = video_prompt_type
if settings_version < 2.36:
if base_model_type in ["hunyuan_avatar", "hunyuan_custom_audio"]:
audio_prompt_type= ui_defaults["audio_prompt_type"]
if "A" not in audio_prompt_type:
audio_prompt_type += "A"
ui_defaults["audio_prompt_type"] = audio_prompt_type
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
ui_defaults["embedded_guidance_scale"]= 6.0
@ -158,6 +203,7 @@ class family_handler():
"guidance_scale": 7.5,
"flow_shift": 13,
"video_prompt_type": "I",
"audio_prompt_type": "A",
})
elif base_model_type in ["hunyuan_custom_edit"]:
ui_defaults.update({
@ -174,4 +220,5 @@ class family_handler():
"skip_steps_start_step_perc": 25,
"video_length": 129,
"video_prompt_type": "KI",
"audio_prompt_type": "A",
})

View File

@ -14,7 +14,7 @@ from torch.nn.attention.flex_attention import (
)
@lru_cache
# @lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", _compile=False):
block_mask = create_block_mask(score_mod, B, H, M, N, device=device, _compile=_compile)
return block_mask

View File

@ -300,9 +300,6 @@ class LTXV:
prefix_size, height, width = input_video.shape[-3:]
else:
if image_start != None:
frame_width, frame_height = image_start.size
if fit_into_canvas != None:
height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas, 32)
conditioning_media_paths.append(image_start.unsqueeze(1))
conditioning_start_frames.append(0)
conditioning_control_frames.append(False)
@ -479,14 +476,14 @@ class LTXV:
images = images.sub_(0.5).mul_(2).squeeze(0)
return images
def get_loras_transformer(self, get_model_recursive_prop, video_prompt_type, **kwargs):
def get_loras_transformer(self, get_model_recursive_prop, model_type, video_prompt_type, **kwargs):
map = {
"P" : "pose",
"D" : "depth",
"E" : "canny",
}
loras = []
preloadURLs = get_model_recursive_prop(self.model_type, "preload_URLs")
preloadURLs = get_model_recursive_prop(model_type, "preload_URLs")
lora_file_name = ""
for letter, signature in map.items():
if letter in video_prompt_type:

View File

@ -24,7 +24,19 @@ class family_handler():
extra_model_def["frames_minimum"] = 17
extra_model_def["frames_steps"] = 8
extra_model_def["sliding_window"] = True
extra_model_def["image_prompt_types_allowed"] = "TSEV"
extra_model_def["guide_preprocessing"] = {
"selection": ["", "PV", "DV", "EV", "V"],
"labels" : { "V": "Use LTXV raw format"}
}
extra_model_def["mask_preprocessing"] = {
"selection": ["", "A", "NA", "XA", "XNA"],
}
extra_model_def["extra_control_frames"] = 1
extra_model_def["dont_cat_preguide"]= True
return extra_model_def
@staticmethod
@ -64,7 +76,7 @@ class family_handler():
@staticmethod
def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None):
from .ltxv import LTXV
ltxv_model = LTXV(

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mmgp import offload
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
@ -28,7 +27,7 @@ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Aut
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from diffusers import FlowMatchEulerDiscreteScheduler
from PIL import Image
from shared.utils.utils import calculate_new_dimensions
from shared.utils.utils import calculate_new_dimensions, convert_image_to_tensor, convert_tensor_to_image
XLA_AVAILABLE = False
@ -201,7 +200,8 @@ class QwenImagePipeline(): #DiffusionPipeline
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = 1024
if processor is not None:
self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
# self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.prompt_template_encode_start_idx = 64
else:
self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
@ -233,6 +233,21 @@ class QwenImagePipeline(): #DiffusionPipeline
txt = [template.format(e) for e in prompt]
if self.processor is not None and image is not None:
img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
if isinstance(image, list):
base_img_prompt = ""
for i, img in enumerate(image):
base_img_prompt += img_prompt_template.format(i + 1)
elif image is not None:
base_img_prompt = img_prompt_template.format(1)
else:
base_img_prompt = ""
template = self.prompt_template_encode
drop_idx = self.prompt_template_encode_start_idx
txt = [template.format(base_img_prompt + e) for e in prompt]
model_inputs = self.processor(
text=txt,
images=image,
@ -387,7 +402,8 @@ class QwenImagePipeline(): #DiffusionPipeline
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
def _pack_latents(latents):
batch_size, num_channels_latents, _, height, width = latents.shape
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
@ -464,7 +480,7 @@ class QwenImagePipeline(): #DiffusionPipeline
def prepare_latents(
self,
image,
images,
batch_size,
num_channels_latents,
height,
@ -479,30 +495,33 @@ class QwenImagePipeline(): #DiffusionPipeline
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, 1, num_channels_latents, height, width)
shape = (batch_size, num_channels_latents, 1, height, width)
image_latents = None
if image is not None:
image = image.to(device=device, dtype=dtype)
if image.shape[1] != self.latent_channels:
image_latents = self._encode_vae_image(image=image, generator=generator)
else:
image_latents = image
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
else:
image_latents = torch.cat([image_latents], dim=0)
if images is not None and len(images ) > 0:
if not isinstance(images, list):
images = [images]
all_image_latents = []
for image in images:
image = image.to(device=device, dtype=dtype)
if image.shape[1] != self.latent_channels:
image_latents = self._encode_vae_image(image=image, generator=generator)
else:
image_latents = image
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
else:
image_latents = torch.cat([image_latents], dim=0)
image_latent_height, image_latent_width = image_latents.shape[3:]
image_latents = self._pack_latents(
image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
)
image_latents = self._pack_latents(image_latents)
all_image_latents.append(image_latents)
image_latents = torch.cat(all_image_latents, dim=1)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@ -511,7 +530,7 @@ class QwenImagePipeline(): #DiffusionPipeline
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
latents = self._pack_latents(latents)
else:
latents = latents.to(device=device, dtype=dtype)
@ -563,10 +582,15 @@ class QwenImagePipeline(): #DiffusionPipeline
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
image = None,
image_mask = None,
denoising_strength = 0,
callback=None,
pipeline=None,
loras_slists=None,
joint_pass= True,
lora_inpaint = False,
outpainting_dims = None,
qwen_edit_plus = False,
):
r"""
Function invoked when calling the pipeline for generation.
@ -682,33 +706,54 @@ class QwenImagePipeline(): #DiffusionPipeline
batch_size = prompt_embeds.shape[0]
device = "cuda"
prompt_image = None
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
image = image[0] if isinstance(image, list) else image
image_height, image_width = self.image_processor.get_default_height_width(image)
aspect_ratio = image_width / image_height
if False :
_, image_width, image_height = min(
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS
)
image_width = image_width // multiple_of * multiple_of
image_height = image_height // multiple_of * multiple_of
ref_height, ref_width = 1568, 672
if height * width < ref_height * ref_width: ref_height , ref_width = height , width
if image_height * image_width > ref_height * ref_width:
image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
condition_images = []
vae_image_sizes = []
vae_images = []
image_mask_latents = None
ref_size = 1024
ref_text_encoder_size = 384 if qwen_edit_plus else 1024
if image is not None:
if not isinstance(image, list): image = [image]
if height * width < ref_size * ref_size: ref_size = round(math.sqrt(height * width))
for ref_no, img in enumerate(image):
image_width, image_height = img.size
any_mask = ref_no == 0 and image_mask is not None
if (image_height * image_width > ref_size * ref_size) and not any_mask:
vae_height, vae_width =calculate_new_dimensions(ref_size, ref_size, image_height, image_width, False, block_size=multiple_of)
else:
vae_height, vae_width = image_height, image_width
vae_width = vae_width // multiple_of * multiple_of
vae_height = vae_height // multiple_of * multiple_of
vae_image_sizes.append((vae_width, vae_height))
condition_height, condition_width =calculate_new_dimensions(ref_text_encoder_size, ref_text_encoder_size, image_height, image_width, False, block_size=multiple_of)
condition_images.append(img.resize((condition_width, condition_height), resample=Image.Resampling.LANCZOS) )
if img.size != (vae_width, vae_height):
img = img.resize((vae_width, vae_height), resample=Image.Resampling.LANCZOS)
if any_mask :
if lora_inpaint:
image_mask_rebuilt = torch.where(convert_image_to_tensor(image_mask)>-0.5, 1., 0. )[0:1]
img = convert_image_to_tensor(img)
green = torch.tensor([-1.0, 1.0, -1.0]).to(img)
green_image = green[:, None, None] .expand_as(img)
img = torch.where(image_mask_rebuilt > 0, green_image, img)
img = convert_tensor_to_image(img)
else:
image_mask_latents = convert_image_to_tensor(image_mask.resize((vae_width // 8, vae_height // 8), resample=Image.Resampling.LANCZOS))
image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0)
# convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
image_mask_latents = image_mask_latents.to(device).unsqueeze(0).unsqueeze(0).repeat(1,16,1,1,1)
image_mask_latents = self._pack_latents(image_mask_latents)
# img.save("nnn.png")
vae_images.append( convert_image_to_tensor(img).unsqueeze(0).unsqueeze(2) )
image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS)
prompt_image = image
image = self.image_processor.preprocess(image, image_height, image_width)
image = image.unsqueeze(2)
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
image=prompt_image,
image=condition_images,
prompt=prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
@ -718,7 +763,7 @@ class QwenImagePipeline(): #DiffusionPipeline
)
if do_true_cfg:
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
image=prompt_image,
image=condition_images,
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
@ -734,7 +779,7 @@ class QwenImagePipeline(): #DiffusionPipeline
# 4. Prepare latent variables
num_channels_latents = self.transformer.in_channels // 4
latents, image_latents = self.prepare_latents(
image,
vae_images,
batch_size * num_images_per_prompt,
num_channels_latents,
height,
@ -744,11 +789,18 @@ class QwenImagePipeline(): #DiffusionPipeline
generator,
latents,
)
original_image_latents = None if image_latents is None else image_latents.clone()
if image is not None:
img_shapes = [
[
(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
(1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2),
# (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2),
*[
(1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
for vae_width, vae_height in vae_image_sizes
],
]
] * batch_size
else:
@ -773,7 +825,7 @@ class QwenImagePipeline(): #DiffusionPipeline
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
original_timesteps = timesteps
# handle guidance
if self.transformer.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
@ -788,56 +840,80 @@ class QwenImagePipeline(): #DiffusionPipeline
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
morph, first_step = False, 0
lanpaint_proc = None
if image_mask_latents is not None:
randn = torch.randn_like(original_image_latents)
if denoising_strength < 1.:
first_step = int(len(timesteps) * (1. - denoising_strength))
if not morph:
latent_noise_factor = timesteps[first_step]/1000
# latents = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
latents = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor
timesteps = timesteps[first_step:]
self.scheduler.timesteps = timesteps
self.scheduler.sigmas= self.scheduler.sigmas[first_step:]
# from shared.inpainting.lanpaint import LanPaint
# lanpaint_proc = LanPaint()
# 6. Denoising loop
self.scheduler.set_begin_index(0)
updated_num_steps= len(timesteps)
if callback != None:
from shared.utils.loras_mutipliers import update_loras_slists
update_loras_slists(self.transformer, loras_slists, updated_num_steps)
update_loras_slists(self.transformer, loras_slists, len(original_timesteps))
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
for i, t in enumerate(timesteps):
offload.set_step_no_for_lora(self.transformer, first_step + i)
if self.interrupt:
continue
self._current_timestep = t
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
latent_model_input = latents
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1)
if image_mask_latents is not None and denoising_strength <1. and i == first_step and morph:
latent_noise_factor = t/1000
latents = original_image_latents * (1.0 - latent_noise_factor) + latents * latent_noise_factor
if do_true_cfg and joint_pass:
noise_pred, neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask],
encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds],
img_shapes=img_shapes,
txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens],
attention_kwargs=self.attention_kwargs,
**kwargs
)
if noise_pred == None: return None
noise_pred = noise_pred[:, : latents.size(1)]
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
else:
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states_mask_list=[prompt_embeds_mask],
encoder_hidden_states_list=[prompt_embeds],
img_shapes=img_shapes,
txt_seq_lens_list=[txt_seq_lens],
attention_kwargs=self.attention_kwargs,
**kwargs
)[0]
if noise_pred == None: return None
noise_pred = noise_pred[:, : latents.size(1)]
latents_dtype = latents.dtype
# latent_model_input = latents
def denoise(latent_model_input, true_cfg_scale):
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1)
do_true_cfg = true_cfg_scale > 1
if do_true_cfg and joint_pass:
noise_pred, neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance, #!!!!
encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask],
encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds],
img_shapes=img_shapes,
txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens],
attention_kwargs=self.attention_kwargs,
**kwargs
)
if noise_pred == None: return None, None
noise_pred = noise_pred[:, : latents.size(1)]
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
else:
neg_noise_pred = None
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states_mask_list=[prompt_embeds_mask],
encoder_hidden_states_list=[prompt_embeds],
img_shapes=img_shapes,
txt_seq_lens_list=[txt_seq_lens],
attention_kwargs=self.attention_kwargs,
**kwargs
)[0]
if noise_pred == None: return None, None
noise_pred = noise_pred[:, : latents.size(1)]
if do_true_cfg:
neg_noise_pred = self.transformer(
@ -851,20 +927,43 @@ class QwenImagePipeline(): #DiffusionPipeline
attention_kwargs=self.attention_kwargs,
**kwargs
)[0]
if neg_noise_pred == None: return None
if neg_noise_pred == None: return None, None
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
return noise_pred, neg_noise_pred
def cfg_predictions( noise_pred, neg_noise_pred, guidance, t):
if do_true_cfg:
comb_pred = neg_noise_pred + guidance * (noise_pred - neg_noise_pred)
if comb_pred == None: return None
if do_true_cfg:
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
if comb_pred == None: return None
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
noise_pred = comb_pred * (cond_norm / noise_norm)
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
noise_pred = comb_pred * (cond_norm / noise_norm)
neg_noise_pred = None
return noise_pred
if lanpaint_proc is not None and i<=3:
latents = lanpaint_proc(denoise, cfg_predictions, true_cfg_scale, 1., latents, original_image_latents, randn, t/1000, image_mask_latents, height=height , width= width, vae_scale_factor= 8)
if latents is None: return None
noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
if noise_pred == None: return None
noise_pred = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
neg_noise_pred = None
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
noise_pred = None
if image_mask_latents is not None:
if lanpaint_proc is not None:
latents = original_image_latents * (1-image_mask_latents) + image_mask_latents * latents
else:
next_t = timesteps[i+1] if i<len(timesteps)-1 else 0
latent_noise_factor = next_t / 1000
# noisy_image = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
noisy_image = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor
latents = noisy_image * (1-image_mask_latents) + image_mask_latents * latents
noisy_image = None
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
@ -872,13 +971,14 @@ class QwenImagePipeline(): #DiffusionPipeline
latents = latents.to(latents_dtype)
if callback is not None:
# preview = unpack_latent(img).transpose(0,1)
callback(i, None, False)
preview = self._unpack_latents(latents, height, width, self.vae_scale_factor)
preview = preview.squeeze(0)
callback(i, preview, False)
self._current_timestep = None
if output_type == "latent":
image = latents
output_image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = latents.to(self.vae.dtype)
@ -891,7 +991,9 @@ class QwenImagePipeline(): #DiffusionPipeline
latents.device, latents.dtype
)
latents = latents / latents_std + latents_mean
image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
output_image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
if image_mask is not None and not lora_inpaint : #not (lora_inpaint and outpainting_dims is not None):
output_image = vae_images[0].squeeze(2) * (1 - image_mask_rebuilt) + output_image.to(vae_images[0] ) * image_mask_rebuilt
return image
return output_image

View File

@ -1,4 +1,6 @@
import torch
import gradio as gr
def get_qwen_text_encoder_filename(text_encoder_quantization):
text_encoder_filename = "ckpts/Qwen2.5-VL-7B-Instruct/Qwen2.5-VL-7B-Instruct_bf16.safetensors"
@ -9,20 +11,51 @@ def get_qwen_text_encoder_filename(text_encoder_quantization):
class family_handler():
@staticmethod
def query_model_def(base_model_type, model_def):
model_def_output = {
extra_model_def = {
"image_outputs" : True,
"sample_solvers":[
("Default", "default"),
("Lightning", "lightning")],
"guidance_max_phases" : 1,
"fit_into_canvas_image_refs": 0,
}
if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]:
extra_model_def["inpaint_support"] = True
extra_model_def["image_ref_choices"] = {
"choices": [
("None", ""),
("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"),
("Conditional Images are People / Objects", "I"),
],
"letters_filter": "KI",
}
extra_model_def["background_removal_label"]= "Remove Backgrounds only behind People / Objects except main Subject / Landscape"
extra_model_def["video_guide_outpainting"] = [2]
extra_model_def["model_modes"] = {
"choices": [
("Lora Inpainting: Inpainted area completely unrelated to occulted content", 1),
("Masked Denoising : Inpainted area may reuse some content that has been occulted", 0),
],
"default": 1,
"label" : "Inpainting Method",
"image_modes" : [2],
}
return model_def_output
if base_model_type in ["qwen_image_edit_plus_20B"]:
extra_model_def["guide_preprocessing"] = {
"selection": ["", "PV", "SV", "CV"],
}
extra_model_def["mask_preprocessing"] = {
"selection": ["", "A"],
"visible": False,
}
return extra_model_def
@staticmethod
def query_supported_types():
return ["qwen_image_20B", "qwen_image_edit_20B"]
return ["qwen_image_20B", "qwen_image_edit_20B", "qwen_image_edit_plus_20B"]
@staticmethod
def query_family_maps():
@ -46,7 +79,7 @@ class family_handler():
}
@staticmethod
def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None):
from .qwen_main import model_factory
from mmgp import offload
@ -74,14 +107,44 @@ class family_handler():
if ui_defaults.get("sample_solver", "") == "":
ui_defaults["sample_solver"] = "default"
if settings_version < 2.32:
ui_defaults["denoising_strength"] = 1.
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
ui_defaults.update({
"guidance_scale": 4,
"sample_solver": "default",
})
if model_def.get("reference_image", False):
if base_model_type in ["qwen_image_edit_20B"]:
ui_defaults.update({
"video_prompt_type": "KI",
"denoising_strength" : 1.,
"model_mode" : 0,
})
elif base_model_type in ["qwen_image_edit_plus_20B"]:
ui_defaults.update({
"video_prompt_type": "I",
"denoising_strength" : 1.,
"model_mode" : 0,
})
@staticmethod
def validate_generative_settings(base_model_type, model_def, inputs):
if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]:
model_mode = inputs["model_mode"]
denoising_strength= inputs["denoising_strength"]
video_guide_outpainting= inputs["video_guide_outpainting"]
from wgp import get_outpainting_dims
outpainting_dims = get_outpainting_dims(video_guide_outpainting)
if denoising_strength < 1 and model_mode == 1:
gr.Info("Denoising Strength will be ignored while using Lora Inpainting")
if outpainting_dims is not None and model_mode == 0 :
return "Outpainting is not supported with Masked Denoising "
@staticmethod
def get_rgb_factors(base_model_type ):
from shared.RGB_factors import get_rgb_factors
latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("qwen")
return latent_rgb_factors, latent_rgb_factors_bias

View File

@ -17,7 +17,7 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from diffusers import FlowMatchEulerDiscreteScheduler
from .pipeline_qwenimage import QwenImagePipeline
from PIL import Image
from shared.utils.utils import calculate_new_dimensions
from shared.utils.utils import calculate_new_dimensions, convert_tensor_to_image
def stitch_images(img1, img2):
# Resize img2 to match img1's height
@ -44,17 +44,17 @@ class model_factory():
save_quantized = False,
dtype = torch.bfloat16,
VAE_dtype = torch.float32,
mixed_precision_transformer = False
mixed_precision_transformer = False,
):
transformer_filename = model_filename[0]
processor = None
tokenizer = None
if base_model_type == "qwen_image_edit_20B":
if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]:
processor = Qwen2VLProcessor.from_pretrained(os.path.join(checkpoint_dir,"Qwen2.5-VL-7B-Instruct"))
tokenizer = AutoTokenizer.from_pretrained(os.path.join(checkpoint_dir,"Qwen2.5-VL-7B-Instruct"))
self.base_model_type = base_model_type
base_config_file = "configs/qwen_image_20B.json"
with open(base_config_file, 'r', encoding='utf-8') as f:
@ -103,6 +103,8 @@ class model_factory():
n_prompt = None,
sampling_steps: int = 20,
input_ref_images = None,
input_frames= None,
input_masks= None,
width= 832,
height=480,
guide_scale: float = 4,
@ -114,6 +116,9 @@ class model_factory():
VAE_tile_size = None,
joint_pass = True,
sample_solver='default',
denoising_strength = 1.,
model_mode = 0,
outpainting_dims = None,
**bbargs
):
# Generate with different aspect ratios
@ -168,13 +173,17 @@ class model_factory():
self.vae.tile_latent_min_height = VAE_tile_size[1]
self.vae.tile_latent_min_width = VAE_tile_size[1]
qwen_edit_plus = self.base_model_type in ["qwen_image_edit_plus_20B"]
self.vae.enable_slicing()
# width, height = aspect_ratios["16:9"]
if n_prompt is None or len(n_prompt) == 0:
n_prompt= "text, watermark, copyright, blurry, low resolution"
image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True)
if input_frames is not None:
input_ref_images = [convert_tensor_to_image(input_frames) ] + ([] if input_ref_images is None else input_ref_images )
if input_ref_images is not None:
# image stitching method
stiched = input_ref_images[0]
@ -182,14 +191,16 @@ class model_factory():
w, h = input_ref_images[0].size
height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
for new_img in input_ref_images[1:]:
stiched = stitch_images(stiched, new_img)
input_ref_images = [stiched]
if not qwen_edit_plus:
for new_img in input_ref_images[1:]:
stiched = stitch_images(stiched, new_img)
input_ref_images = [stiched]
image = self.pipeline(
prompt=input_prompt,
negative_prompt=n_prompt,
image = input_ref_images,
image_mask = image_mask,
width=width,
height=height,
num_inference_steps=sampling_steps,
@ -199,8 +210,19 @@ class model_factory():
pipeline=self,
loras_slists=loras_slists,
joint_pass = joint_pass,
generator=torch.Generator(device="cuda").manual_seed(seed)
)
denoising_strength=denoising_strength,
generator=torch.Generator(device="cuda").manual_seed(seed),
lora_inpaint = image_mask is not None and model_mode == 1,
outpainting_dims = outpainting_dims,
qwen_edit_plus = qwen_edit_plus,
)
if image is None: return None
return image.transpose(0, 1)
def get_loras_transformer(self, get_model_recursive_prop, model_type, model_mode, **kwargs):
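# Lora Inpainting mode (model_mode != 0): return the first configured preload URL as an extra LoRA with multiplier 1;
# Masked Denoising (model_mode == 0) requires no extra LoRA.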
if model_mode == 0: return [], []
preloadURLs = get_model_recursive_prop(model_type, "preload_URLs")
if len(preloadURLs) == 0: return [], []
return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))] , [1]

View File

@ -204,7 +204,7 @@ class QwenEmbedRope(nn.Module):
frame, height, width = fhw
rope_key = f"{idx}_{height}_{width}"
if not torch.compiler.is_compiling():
if not torch.compiler.is_compiling() and False:
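# 'and False' makes this branch unreachable, bypassing the manual rope cache;
# _compute_video_freqs below remains memoized via functools.lru_cache.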
if rope_key not in self.rope_cache:
self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
video_freq = self.rope_cache[rope_key]
@ -224,7 +224,6 @@ class QwenEmbedRope(nn.Module):
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)

View File

@ -0,0 +1,143 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import numbers
from peft import LoraConfig
def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"):
target_modules = []
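# Target every nn.Linear inside the transformer blocks, skipping face-adapter and modulation layers.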
for name, module in transformer.named_modules():
if "blocks" in name and "face" not in name and "modulation" not in name and isinstance(module, torch.nn.Linear):
target_modules.append(name)
transformer_lora_config = LoraConfig(
r=rank,
lora_alpha=alpha,
init_lora_weights=init_lora_weights,
target_modules=target_modules,
)
return transformer_lora_config
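# Hypothetical usage sketch (assumption, not part of this file): the returned LoraConfig can be
# attached with peft's get_peft_model to expose trainable LoRA weights, e.g.:
#   from peft import get_peft_model
#   lora_model = get_peft_model(transformer, get_loraconfig(transformer, rank=64, alpha=64))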
class TensorList(object):
def __init__(self, tensors):
"""
tensors: a list of torch.Tensor objects. No need to have uniform shape.
"""
assert isinstance(tensors, (list, tuple))
assert all(isinstance(u, torch.Tensor) for u in tensors)
assert len(set([u.ndim for u in tensors])) == 1
assert len(set([u.dtype for u in tensors])) == 1
assert len(set([u.device for u in tensors])) == 1
self.tensors = tensors
def to(self, *args, **kwargs):
return TensorList([u.to(*args, **kwargs) for u in self.tensors])
def size(self, dim):
assert dim == 0, 'only the size of dim 0 is supported'
return len(self.tensors)
def pow(self, *args, **kwargs):
return TensorList([u.pow(*args, **kwargs) for u in self.tensors])
def squeeze(self, dim):
assert dim != 0
if dim > 0:
dim -= 1
return TensorList([u.squeeze(dim) for u in self.tensors])
def type(self, *args, **kwargs):
return TensorList([u.type(*args, **kwargs) for u in self.tensors])
def type_as(self, other):
assert isinstance(other, (torch.Tensor, TensorList))
if isinstance(other, torch.Tensor):
return TensorList([u.type_as(other) for u in self.tensors])
else:
return TensorList([u.type(other.dtype) for u in self.tensors])
@property
def dtype(self):
return self.tensors[0].dtype
@property
def device(self):
return self.tensors[0].device
@property
def ndim(self):
return 1 + self.tensors[0].ndim
def __getitem__(self, index):
return self.tensors[index]
def __len__(self):
return len(self.tensors)
def __add__(self, other):
return self._apply(other, lambda u, v: u + v)
def __radd__(self, other):
return self._apply(other, lambda u, v: v + u)
def __sub__(self, other):
return self._apply(other, lambda u, v: u - v)
def __rsub__(self, other):
return self._apply(other, lambda u, v: v - u)
def __mul__(self, other):
return self._apply(other, lambda u, v: u * v)
def __rmul__(self, other):
return self._apply(other, lambda u, v: v * u)
def __floordiv__(self, other):
return self._apply(other, lambda u, v: u // v)
def __truediv__(self, other):
return self._apply(other, lambda u, v: u / v)
def __rfloordiv__(self, other):
return self._apply(other, lambda u, v: v // u)
def __rtruediv__(self, other):
return self._apply(other, lambda u, v: v / u)
def __pow__(self, other):
return self._apply(other, lambda u, v: u ** v)
def __rpow__(self, other):
return self._apply(other, lambda u, v: v ** u)
def __neg__(self):
return TensorList([-u for u in self.tensors])
def __iter__(self):
for tensor in self.tensors:
yield tensor
def __repr__(self):
return 'TensorList: \n' + repr(self.tensors)
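# Dispatch an elementwise op: sequence-like operands are zipped pairwise with self.tensors,
# scalar-like operands are broadcast to every tensor.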
def _apply(self, other, op):
if isinstance(other, (list, tuple, TensorList)) or (
isinstance(other, torch.Tensor) and (
other.numel() > 1 or other.ndim > 1
)
):
assert len(other) == len(self.tensors)
return TensorList([op(u, v) for u, v in zip(self.tensors, other)])
elif isinstance(other, numbers.Number) or (
isinstance(other, torch.Tensor) and (
other.numel() == 1 and other.ndim <= 1
)
):
return TensorList([op(u, other) for u in self.tensors])
else:
raise TypeError(
f'unsupported operand type(s) for TensorList arithmetic: "TensorList" and "{type(other)}"'
)

View File

@ -0,0 +1,382 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from torch import nn
import torch
from typing import Tuple, Optional
from einops import rearrange
import torch.nn.functional as F
import math
from shared.attention import pay_attention
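# Maps each attention mode to a (pre-attention, post-attention) pair of layout transforms
# applied to q/k/v before the attention call and to its output afterwards.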
MEMORY_LAYOUT = {
"flash": (
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
lambda x: x,
),
"torch": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
"vanilla": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
}
def attention(
q,
k,
v,
mode="torch",
drop_rate=0,
attn_mask=None,
causal=False,
max_seqlen_q=None,
batch_size=1,
):
"""
Perform QKV self attention.
Args:
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
k (torch.Tensor): Key tensor with shape [b, s1, a, d]
v (torch.Tensor): Value tensor with shape [b, s1, a, d]
mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
drop_rate (float): Dropout rate in attention map. (default: 0)
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
(default: None)
causal (bool): Whether to use causal attention. (default: False)
max_seqlen_q (int): The maximum sequence length in the batch of q, used in 'flash' mode to restore the batch dimension.
batch_size (int): Batch size used to reshape the 'flash' mode output. (default: 1)
Returns:
torch.Tensor: Output tensor after self attention with shape [b, s, a*d]
"""
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
if mode == "torch":
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
elif mode == "flash":
x = flash_attn_func(
q,
k,
v,
)
x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
elif mode == "vanilla":
scale_factor = 1 / math.sqrt(q.size(-1))
b, a, s, _ = q.shape
s1 = k.size(2)
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
if causal:
# Only applied to self attention
assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias = attn_bias.to(q.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
attn = (q @ k.transpose(-2, -1)) * scale_factor
attn += attn_bias
attn = attn.softmax(dim=-1)
attn = torch.dropout(attn, p=drop_rate, train=True)
x = attn @ v
else:
raise NotImplementedError(f"Unsupported attention mode: {mode}")
x = post_attn_layout(x)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out
class CausalConv1d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
super().__init__()
self.pad_mode = pad_mode
padding = (kernel_size - 1, 0) # T
self.time_causal_padding = padding
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
class FaceEncoder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.num_heads = num_heads
self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.act = nn.SiLU()
self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)
self.out_proj = nn.Linear(1024, hidden_dim)
self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
def forward(self, x):
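# x: [batch, time, channels] motion features. A causal conv expands channels per head,
# two stride-2 causal convs downsample the temporal axis, then out_proj maps to hidden_dim
# and a learned padding token is appended along the token axis.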
x = rearrange(x, "b t c -> b c t")
b, c, t = x.shape
x = self.conv1_local(x)
x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, "b t c -> b c t")
x = self.conv2(x)
x = rearrange(x, "b c t -> b t c")
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, "b t c -> b c t")
x = self.conv3(x)
x = rearrange(x, "b c t -> b t c")
x = self.norm3(x)
x = self.act(x)
x = self.out_proj(x)
x = rearrange(x, "(b n) t c -> b t n c", b=b)
padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
x = torch.cat([x, padding], dim=-2)
x_local = x.clone()
return x_local
class RMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
def get_norm_layer(norm_layer):
"""
Get the normalization layer.
Args:
norm_layer (str): The type of normalization layer.
Returns:
norm_layer (nn.Module): The normalization layer.
"""
if norm_layer == "layer":
return nn.LayerNorm
elif norm_layer == "rms":
return RMSNorm
else:
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
class FaceAdapter(nn.Module):
def __init__(
self,
hidden_dim: int,
heads_num: int,
qk_norm: bool = True,
qk_norm_type: str = "rms",
num_adapter_layers: int = 1,
dtype=None,
device=None,
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.hidden_size = hidden_dim
self.heads_num = heads_num
self.fuser_blocks = nn.ModuleList(
[
FaceBlock(
self.hidden_size,
self.heads_num,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
**factory_kwargs,
)
for _ in range(num_adapter_layers)
]
)
def forward(
self,
x: torch.Tensor,
motion_embed: torch.Tensor,
idx: int,
freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
) -> torch.Tensor:
return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
class FaceBlock(nn.Module):
def __init__(
self,
hidden_size: int,
heads_num: int,
qk_norm: bool = True,
qk_norm_type: str = "rms",
qk_scale: float = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
self.scale = qk_scale or head_dim**-0.5
self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
def forward(
self,
x: torch.Tensor,
motion_vec: torch.Tensor,
motion_mask: Optional[torch.Tensor] = None,
use_context_parallel=False,
) -> torch.Tensor:
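# Cross-attention between the video hidden states (queries) and per-frame motion tokens
# (keys/values); the projected output is optionally gated by motion_mask.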
B, T, N, C = motion_vec.shape
T_comp = T
x_motion = self.pre_norm_motion(motion_vec)
x_feat = self.pre_norm_feat(x)
kv = self.linear1_kv(x_motion)
q = self.linear1_q(x_feat)
k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
# Apply QK-Norm if needed.
q = self.q_norm(q).to(v)
k = self.k_norm(k).to(v)
k = rearrange(k, "B L N H D -> (B L) N H D")
v = rearrange(v, "B L N H D -> (B L) N H D")
if use_context_parallel:
q = gather_forward(q, dim=1)
q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp)
# Compute attention.
# Size([batches, tokens, heads, head_features])
qkv_list = [q, k, v]
del q,k,v
attn = pay_attention(qkv_list)
# attn = attention(
# q,
# k,
# v,
# max_seqlen_q=q.shape[1],
# batch_size=q.shape[0],
# )
attn = attn.reshape(*attn.shape[:2], -1)
attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
# if use_context_parallel:
# attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()]
output = self.linear2(attn)
if motion_mask is not None:
output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
return output

View File

@ -0,0 +1,31 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import types
from copy import deepcopy
from einops import rearrange
from typing import List
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
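# Add pose latents onto the patch embeddings (skipping the first latent frame, which holds the reference)
# and encode the face frames with the motion encoder in chunks of 8 frames to bound peak memory.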
pose_latents = self.pose_patch_embedding(pose_latents)
x[:, :, 1:] += pose_latents
b,c,T,h,w = face_pixel_values.shape
face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
encode_bs = 8
face_pixel_values_tmp = []
for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):
face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))
motion_vec = torch.cat(face_pixel_values_tmp)
motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
motion_vec = self.face_encoder(motion_vec)
B, L, H, C = motion_vec.shape
pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
motion_vec = torch.cat([pad_face, motion_vec], dim=1)
return x, motion_vec

View File

@ -0,0 +1,308 @@
# Modified from ``https://github.com/wyhsirius/LIA``
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
def custom_qr(input_tensor):
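# torch.linalg.qr does not support half precision, so bf16/fp16 inputs are upcast to fp32
# and the resulting factors are cast back to the original dtype.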
original_dtype = input_tensor.dtype
if original_dtype in [torch.bfloat16, torch.float16]:
q, r = torch.linalg.qr(input_tensor.to(torch.float32))
return q.to(original_dtype), r.to(original_dtype)
return torch.linalg.qr(input_tensor)
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
return F.leaky_relu(input + bias, negative_slope) * scale
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
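# Pure-PyTorch upfirdn: upsample by zero-insertion, pad, convolve with the flipped FIR kernel,
# then downsample (as in the StyleGAN2 reference implementation).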
_, minor, in_h, in_w = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, minor, in_h, 1, in_w, 1)
out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
out = out.view(-1, minor, in_h * up_y, in_w * up_x)
out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
return out[:, :, ::down_y, ::down_x]
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
def make_kernel(k):
k = torch.tensor(k, dtype=torch.float32)
if k.ndim == 1:
k = k[None, :] * k[:, None]
k /= k.sum()
return k
class FusedLeakyReLU(nn.Module):
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
super().__init__()
self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
self.negative_slope = negative_slope
self.scale = scale
def forward(self, input):
out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
return out
class Blur(nn.Module):
def __init__(self, kernel, pad, upsample_factor=1):
super().__init__()
kernel = make_kernel(kernel)
if upsample_factor > 1:
kernel = kernel * (upsample_factor ** 2)
self.register_buffer('kernel', kernel)
self.pad = pad
def forward(self, input):
return upfirdn2d(input, self.kernel, pad=self.pad)
class ScaledLeakyReLU(nn.Module):
def __init__(self, negative_slope=0.2):
super().__init__()
self.negative_slope = negative_slope
def forward(self, input):
return F.leaky_relu(input, negative_slope=self.negative_slope)
class EqualConv2d(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
self.stride = stride
self.padding = padding
if bias:
self.bias = nn.Parameter(torch.zeros(out_channel))
else:
self.bias = None
def forward(self, input):
return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
def __repr__(self):
return (
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
)
class EqualLinear(nn.Module):
def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
if bias:
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
else:
self.bias = None
self.activation = activation
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul
def forward(self, input):
if self.activation:
out = F.linear(input, self.weight * self.scale)
out = fused_leaky_relu(out, self.bias * self.lr_mul)
else:
out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
return out
def __repr__(self):
return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
class ConvLayer(nn.Sequential):
def __init__(
self,
in_channel,
out_channel,
kernel_size,
downsample=False,
blur_kernel=[1, 3, 3, 1],
bias=True,
activate=True,
):
layers = []
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
stride = 2
self.padding = 0
else:
stride = 1
self.padding = kernel_size // 2
layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
bias=bias and not activate))
if activate:
if bias:
layers.append(FusedLeakyReLU(out_channel))
else:
layers.append(ScaledLeakyReLU(0.2))
super().__init__(*layers)
class ResBlock(nn.Module):
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
super().__init__()
self.conv1 = ConvLayer(in_channel, in_channel, 3)
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
def forward(self, input):
out = self.conv1(input)
out = self.conv2(out)
skip = self.skip(input)
out = (out + skip) / math.sqrt(2)
return out
class EncoderApp(nn.Module):
def __init__(self, size, w_dim=512):
super(EncoderApp, self).__init__()
channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256,
128: 128,
256: 64,
512: 32,
1024: 16
}
self.w_dim = w_dim
log_size = int(math.log(size, 2))
self.convs = nn.ModuleList()
self.convs.append(ConvLayer(3, channels[size], 1))
in_channel = channels[size]
for i in range(log_size, 2, -1):
out_channel = channels[2 ** (i - 1)]
self.convs.append(ResBlock(in_channel, out_channel))
in_channel = out_channel
self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))
def forward(self, x):
res = []
h = x
for conv in self.convs:
h = conv(h)
res.append(h)
return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]
class Encoder(nn.Module):
def __init__(self, size, dim=512, dim_motion=20):
super(Encoder, self).__init__()
# appearance network
self.net_app = EncoderApp(size, dim)
# motion network
fc = [EqualLinear(dim, dim)]
for i in range(3):
fc.append(EqualLinear(dim, dim))
fc.append(EqualLinear(dim, dim_motion))
self.fc = nn.Sequential(*fc)
def enc_app(self, x):
h_source = self.net_app(x)
return h_source
def enc_motion(self, x):
h, _ = self.net_app(x)
h_motion = self.fc(h)
return h_motion
class Direction(nn.Module):
def __init__(self, motion_dim):
super(Direction, self).__init__()
self.weight = nn.Parameter(torch.randn(512, motion_dim))
def forward(self, input):
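# QR-orthonormalize the learned motion dictionary; with no input, return the orthonormal basis Q,
# otherwise combine the basis directions weighted by the motion coefficients (alpha).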
weight = self.weight + 1e-8
Q, R = custom_qr(weight)
if input is None:
return Q
else:
input_diag = torch.diag_embed(input) # alpha, diagonal matrix
out = torch.matmul(input_diag, Q.T)
out = torch.sum(out, dim=1)
return out
class Synthesis(nn.Module):
def __init__(self, motion_dim):
super(Synthesis, self).__init__()
self.direction = Direction(motion_dim)
class Generator(nn.Module):
def __init__(self, size, style_dim=512, motion_dim=20):
super().__init__()
self.enc = Encoder(size, style_dim, motion_dim)
self.dec = Synthesis(motion_dim)
def get_motion(self, img):
#motion_feat = self.enc.enc_motion(img)
# motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
with torch.cuda.amp.autocast(dtype=torch.float32):
motion_feat = self.enc.enc_motion(img)
motion = self.dec.direction(motion_feat)
return motion

View File

@ -19,7 +19,8 @@ from PIL import Image
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from .distributed.fsdp import shard_model
from .modules.model import WanModel, clear_caches
from .modules.model import WanModel
from mmgp.offload import get_cache, clear_caches
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .modules.vae2_2 import Wan2_2_VAE
@ -31,9 +32,11 @@ from shared.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from .modules.posemb_layers import get_rotary_pos_embed, get_nd_rotary_pos_embed
from shared.utils.vace_preprocessor import VaceVideoProcessor
from shared.utils.basic_flowmatch import FlowMatchScheduler
from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor
from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor, fit_image_into_canvas
from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance, match_and_blend_colors, match_and_blend_colors_with_mask
from shared.utils.audio_video import save_video
from mmgp import safetensors2
from shared.utils.audio_video import save_video
def optimized_scale(positive_flat, negative_flat):
@ -63,6 +66,7 @@ class WanAny2V:
config,
checkpoint_dir,
model_filename = None,
submodel_no_list = None,
model_type = None,
model_def = None,
base_model_type = None,
@ -91,7 +95,7 @@ class WanAny2V:
shard_fn= None)
# base_model_type = "i2v2_2"
if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"]:
if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"] or base_model_type in ["animate"]:
self.clip = CLIPModel(
dtype=config.clip_dtype,
device=self.device,
@ -100,7 +104,7 @@ class WanAny2V:
tokenizer_path=os.path.join(checkpoint_dir , "xlm-roberta-large"))
if base_model_type in ["ti2v_2_2"]:
if base_model_type in ["ti2v_2_2", "lucy_edit"]:
self.vae_stride = (4, 16, 16)
vae_checkpoint = "Wan2.2_VAE.safetensors"
vae = Wan2_2_VAE
@ -125,75 +129,82 @@ class WanAny2V:
forcedConfigPath = base_config_file if len(model_filename) > 1 else None
# forcedConfigPath = base_config_file = f"configs/flf2v_720p.json"
# model_filename[1] = xmodel_filename
self.model = self.model2 = None
source = model_def.get("source", None)
source2 = model_def.get("source2", None)
module_source = model_def.get("module_source", None)
module_source2 = model_def.get("module_source2", None)
if module_source is not None:
model_filename = [] + model_filename
model_filename[1] = module_source
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
elif source is not None:
self.model = offload.fast_load_transformers_model(model_filename[:1] + [module_source], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
if module_source2 is not None:
self.model2 = offload.fast_load_transformers_model(model_filename[1:2] + [module_source2], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
if source is not None:
self.model = offload.fast_load_transformers_model(source, modelClass=WanModel, writable_tensors= False, forcedConfigPath= base_config_file)
elif self.transformer_switch:
shared_modules= {}
self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, return_shared_modules= shared_modules)
self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
shared_modules = None
else:
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
# self.model = offload.load_model_data(self.model, xmodel_filename )
# offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
if source2 is not None:
self.model2 = offload.fast_load_transformers_model(source2, modelClass=WanModel, writable_tensors= False, forcedConfigPath= base_config_file)
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
offload.change_dtype(self.model, dtype, True)
if self.model is not None or self.model2 is not None:
from wgp import save_model
from mmgp.safetensors2 import torch_load_file
else:
if self.transformer_switch:
if 0 in submodel_no_list[2:] and 1 in submodel_no_list[2:]:
raise Exception("Shared and non shared modules at the same time across multipe models is not supported")
if 0 in submodel_no_list[2:]:
shared_modules= {}
self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, return_shared_modules= shared_modules)
self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
shared_modules = None
else:
modules_for_1 =[ file_name for file_name, submodel_no in zip(model_filename[2:],submodel_no_list[2:] ) if submodel_no ==1 ]
modules_for_2 =[ file_name for file_name, submodel_no in zip(model_filename[2:],submodel_no_list[2:] ) if submodel_no ==2 ]
self.model = offload.fast_load_transformers_model(model_filename[:1], modules = modules_for_1, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = modules_for_2, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
else:
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
if self.model is not None:
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
offload.change_dtype(self.model, dtype, True)
self.model.eval().requires_grad_(False)
if self.model2 is not None:
self.model2.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
offload.change_dtype(self.model2, dtype, True)
# offload.save_model(self.model, "wan2.1_text2video_1.3B_mbf16.safetensors", do_quantize= False, config_file_path=base_config_file, filter_sd=sd)
# offload.save_model(self.model, "wan2.2_image2video_14B_low_mbf16.safetensors", config_file_path=base_config_file)
# offload.save_model(self.model, "wan2.2_image2video_14B_low_quanto_mbf16_int8.safetensors", do_quantize=True, config_file_path=base_config_file)
self.model.eval().requires_grad_(False)
if self.model2 is not None:
self.model2.eval().requires_grad_(False)
if module_source is not None:
from wgp import save_model
from mmgp.safetensors2 import torch_load_file
filter = list(torch_load_file(module_source))
save_model(self.model, model_type, dtype, None, is_module=True, filter=filter)
elif not source is None:
from wgp import save_model
save_model(self.model, model_type, dtype, None)
save_model(self.model, model_type, dtype, None, is_module=True, filter=list(torch_load_file(module_source)), module_source_no=1)
if module_source2 is not None:
save_model(self.model2, model_type, dtype, None, is_module=True, filter=list(torch_load_file(module_source2)), module_source_no=2)
if not source is None:
save_model(self.model, model_type, dtype, None, submodel_no= 1)
if not source2 is None:
save_model(self.model2, model_type, dtype, None, submodel_no= 2)
if save_quantized:
from wgp import save_quantized_model
save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file)
if self.model is not None:
save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file)
if self.model2 is not None:
save_quantized_model(self.model2, model_type, model_filename[1], dtype, base_config_file, submodel_no=2)
self.sample_neg_prompt = config.sample_neg_prompt
if self.model.config.get("vace_in_dim", None) != None:
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
min_area=480*832,
max_area=480*832,
min_fps=config.sample_fps,
max_fps=config.sample_fps,
zero_start=True,
seq_len=32760,
keep_last=True)
if hasattr(self.model, "vace_blocks"):
self.adapt_vace_model(self.model)
if self.model2 is not None: self.adapt_vace_model(self.model2)
if hasattr(self.model, "face_adapter"):
self.adapt_animate_model(self.model)
if self.model2 is not None: self.adapt_animate_model(self.model2)
self.num_timesteps = 1000
self.use_timestep_transform = True
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = None):
if ref_images is None:
ref_images = [None] * len(frames)
else:
assert len(frames) == len(ref_images)
ref_images = [ref_images] * len(frames)
if masks is None:
latents = self.vae.encode(frames, tile_size = tile_size)
@ -225,11 +236,7 @@ class WanAny2V:
return cat_latents
def vace_encode_masks(self, masks, ref_images=None):
if ref_images is None:
ref_images = [None] * len(masks)
else:
assert len(masks) == len(ref_images)
ref_images = [ref_images] * len(masks)
result_masks = []
for mask, refs in zip(masks, ref_images):
c, depth, height, width = mask.shape
@ -257,119 +264,6 @@ class WanAny2V:
result_masks.append(mask)
return result_masks
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None, return_mask = False):
from shared.utils.utils import save_image
ref_width, ref_height = ref_img.size
if (ref_height, ref_width) == image_size and outpainting_dims == None:
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
canvas = torch.zeros_like(ref_img) if return_mask else None
else:
if outpainting_dims != None:
final_height, final_width = image_size
canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 8)
else:
canvas_height, canvas_width = image_size
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
if fill_max and (canvas_height - new_height) < 16:
new_height = canvas_height
if fill_max and (canvas_width - new_width) < 16:
new_width = canvas_width
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
if outpainting_dims != None:
canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
else:
canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
ref_img = canvas
canvas = None
if return_mask:
if outpainting_dims != None:
canvas = torch.ones((3, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0
else:
canvas = torch.ones((3, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, top:top + new_height, left:left + new_width] = 0
canvas = canvas.to(device)
return ref_img.to(device), canvas
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
image_sizes = []
trim_video_guide = len(keep_video_guide_frames)
def conv_tensor(t, device):
return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device)
for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
num_frames = total_frames - prepend_count
num_frames = min(num_frames, trim_video_guide) if trim_video_guide > 0 and sub_src_video != None else num_frames
if sub_src_mask is not None and sub_src_video is not None:
src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
src_mask[i] = conv_tensor(sub_src_mask[:num_frames], device)
# src_video is [-1, 1] (at this function output), 0 = inpainting area (in fact 127 in [0, 255])
# src_mask is [-1, 1] (at this function output), 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
src_mask[i] = torch.cat( [torch.full_like(sub_pre_src_video, -1.0), src_mask[i]] ,1)
src_video_shape = src_video[i].shape
if src_video_shape[1] != total_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.clamp((src_mask[i][:, :, :, :] + 1) / 2, min=0, max=1)
image_sizes.append(src_video[i].shape[2:])
elif sub_src_video is None:
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1)
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1)
else:
src_video[i] = torch.zeros((3, total_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(image_size)
else:
src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
src_video_shape = src_video[i].shape
if src_video_shape[1] != total_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
image_sizes.append(src_video[i].shape[2:])
for k, keep in enumerate(keep_video_guide_frames):
if not keep:
pos = prepend_count + k
src_video[i][:, pos:pos+1] = 0
src_mask[i][:, pos:pos+1] = 1
for k, frame in enumerate(inject_frames):
if frame != None:
pos = prepend_count + k
src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)
self.background_mask = None
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
image_size = image_sizes[i]
for j, ref_img in enumerate(ref_images):
if ref_img is not None and not torch.is_tensor(ref_img):
if j==0 and any_background_ref:
if self.background_mask == None: self.background_mask = [None] * len(src_ref_images)
src_ref_images[i][j], self.background_mask[i] = self.fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True)
else:
src_ref_images[i][j], _ = self.fit_image_into_canvas(ref_img, image_size, 1, device)
if self.background_mask != None:
self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # duplicate the background mask when a double control net is used, since the first controlnet's image ref was modified by the ref
return src_video, src_mask, src_ref_images
def get_vae_latents(self, ref_images, device, tile_size= 0):
ref_vae_latents = []
@ -380,12 +274,28 @@ class WanAny2V:
return torch.cat(ref_vae_latents, dim=1)
def get_i2v_mask(self, lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=None, lat_t =0, device="cuda"):
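# Build the image-to-video conditioning mask in latent space: either a zero mask of lat_t latent frames
# or a nearest-neighbour downsampled pixel mask; the first frame is repeated 4x to match the temporal
# VAE compression before reshaping to [4, T_lat, lat_h, lat_w].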
if mask_pixel_values is None:
msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
else:
msk = F.interpolate(mask_pixel_values.to(device), size=(lat_h, lat_w), mode='nearest')
if nb_frames_unchanged >0:
msk[:, :nb_frames_unchanged] = 1
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1,2)[0]
return msk
def generate(self,
input_prompt,
input_frames= None,
input_frames2= None,
input_masks = None,
input_ref_images = None,
input_masks2 = None,
input_ref_images = None,
input_ref_masks = None,
input_faces = None,
input_video = None,
image_start = None,
image_end = None,
@ -453,7 +363,8 @@ class WanAny2V:
timesteps.append(0.)
timesteps = [torch.tensor([t], device=self.device) for t in timesteps]
if self.use_timestep_transform:
timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps][:-1]
timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps][:-1]
timesteps = torch.tensor(timesteps)
sample_scheduler = None
elif sample_solver == 'causvid':
sample_scheduler = FlowMatchScheduler(num_inference_steps=sampling_steps, shift=shift, sigma_min=0, extra_one_step=True)
@ -496,6 +407,8 @@ class WanAny2V:
text_len = self.model.text_len
context = torch.cat([context, context.new_zeros(text_len -context.size(0), context.size(1)) ]).unsqueeze(0)
context_null = torch.cat([context_null, context_null.new_zeros(text_len -context_null.size(0), context_null.size(1)) ]).unsqueeze(0)
if input_video is not None: height, width = input_video.shape[-2:]
# NAG_prompt = "static, low resolution, blurry"
# context_NAG = self.text_encoder([NAG_prompt], self.device)[0]
# context_NAG = context_NAG.to(self.dtype)
@ -516,109 +429,76 @@ class WanAny2V:
infinitetalk = model_type in ["infinitetalk"]
standin = model_type in ["standin", "vace_standin_14B"]
recam = model_type in ["recam_1.3B"]
ti2v = model_type in ["ti2v_2_2"]
ti2v = model_type in ["ti2v_2_2", "lucy_edit"]
lucy_edit= model_type in ["lucy_edit"]
animate= model_type in ["animate"]
start_step_no = 0
ref_images_count = 0
trim_frames = 0
extended_overlapped_latents = None
extended_overlapped_latents = clip_image_start = clip_image_end = image_mask_latents = None
no_noise_latents_injection = infinitetalk
timestep_injection = False
lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1
extended_input_dim = 0
ref_images_before = False
# image2video
if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "flf2v_720p"]:
any_end_frame = False
if image_start is None:
if infinitetalk:
if input_frames is not None:
image_ref = input_frames[:, -1]
if input_video is None: input_video = input_frames[:, -1:]
new_shot = "Q" in video_prompt_type
else:
if pre_video_frame is None:
new_shot = True
else:
if input_ref_images is None:
input_ref_images, new_shot = [pre_video_frame], False
else:
input_ref_images, new_shot = [img.resize(pre_video_frame.size, resample=Image.Resampling.LANCZOS) for img in input_ref_images], "Q" in video_prompt_type
if input_ref_images is None: raise Exception("Missing Reference Image")
new_shot = new_shot and window_no <= len(input_ref_images)
image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ])
if new_shot:
input_video = image_ref.unsqueeze(1)
else:
color_correction_strength = 0 # disable color correction as transition frames between shots may have a completely different color level than the new shot
_ , preframes_count, height, width = input_video.shape
input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype)
if infinitetalk:
image_for_clip = image_ref.to(input_video)
control_pre_frames_count = 1
control_video = image_for_clip.unsqueeze(1)
if infinitetalk:
new_shot = "Q" in video_prompt_type
if input_frames is not None:
image_ref = input_frames[:, 0]
else:
image_for_clip = input_video[:, -1]
control_pre_frames_count = preframes_count
control_video = input_video
lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
if hasattr(self, "clip"):
clip_image_size = self.clip.model.image_size
clip_image = resize_lanczos(image_for_clip, clip_image_size, clip_image_size)[:, None, :, :]
clip_context = self.clip.visual([clip_image]) if model_type != "flf2v_720p" else self.clip.visual([clip_image , clip_image ])
clip_image = None
if input_ref_images is None:
if pre_video_frame is None: raise Exception("Missing Reference Image")
input_ref_images, new_shot = [pre_video_frame], False
new_shot = new_shot and window_no <= len(input_ref_images)
image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ])
if new_shot or input_video is None:
input_video = image_ref.unsqueeze(1)
else:
clip_context = None
enc = torch.concat( [control_video, torch.zeros( (3, frame_num-control_pre_frames_count, height, width),
device=self.device, dtype= self.VAE_dtype)],
dim = 1).to(self.device)
color_reference_frame = image_for_clip.unsqueeze(1).clone()
color_correction_strength = 0 # disable color correction as transition frames between shots may have a completely different color level than the new shot
_ , preframes_count, height, width = input_video.shape
input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype)
if infinitetalk:
image_start = image_ref.to(input_video)
control_pre_frames_count = 1
control_video = image_start.unsqueeze(1)
else:
preframes_count = control_pre_frames_count = 1
any_end_frame = image_end is not None
add_frames_for_end_image = any_end_frame and model_type == "i2v"
if any_end_frame:
if add_frames_for_end_image:
frame_num +=1
lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
trim_frames = 1
height, width = image_start.shape[1:]
image_start = input_video[:, -1]
control_pre_frames_count = preframes_count
control_video = input_video
lat_h = round(
height // self.vae_stride[1] //
self.patch_size[1] * self.patch_size[1])
lat_w = round(
width // self.vae_stride[2] //
self.patch_size[2] * self.patch_size[2])
height = lat_h * self.vae_stride[1]
width = lat_w * self.vae_stride[2]
image_start_frame = image_start.unsqueeze(1).to(self.device)
color_reference_frame = image_start_frame.clone()
if image_end is not None:
img_end_frame = image_end.unsqueeze(1).to(self.device)
color_reference_frame = image_start.unsqueeze(1).clone()
if hasattr(self, "clip"):
clip_image_size = self.clip.model.image_size
image_start = resize_lanczos(image_start, clip_image_size, clip_image_size)
if image_end is not None: image_end = resize_lanczos(image_end, clip_image_size, clip_image_size)
if model_type == "flf2v_720p":
clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end is not None else image_start[:, None, :, :]])
else:
clip_context = self.clip.visual([image_start[:, None, :, :]])
else:
clip_context = None
any_end_frame = image_end is not None
add_frames_for_end_image = any_end_frame and model_type == "i2v"
if any_end_frame:
color_correction_strength = 0 # disable color correction as transition frames between shots may have a completely different color level than the new shot
if add_frames_for_end_image:
frame_num +=1
lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
trim_frames = 1
if any_end_frame:
enc= torch.concat([
image_start_frame,
torch.zeros( (3, frame_num-2, height, width), device=self.device, dtype= self.VAE_dtype),
img_end_frame,
], dim=1).to(self.device)
else:
enc= torch.concat([
image_start_frame,
torch.zeros( (3, frame_num-1, height, width), device=self.device, dtype= self.VAE_dtype)
], dim=1).to(self.device)
lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
image_start = image_end = image_start_frame = img_end_frame = image_for_clip = image_ref = None
if image_end is not None:
img_end_frame = image_end.unsqueeze(1).to(self.device)
clip_image_start, clip_image_end = image_start, image_end
if any_end_frame:
enc= torch.concat([
control_video,
torch.zeros( (3, frame_num-control_pre_frames_count-1, height, width), device=self.device, dtype= self.VAE_dtype),
img_end_frame,
], dim=1).to(self.device)
else:
enc= torch.concat([
control_video,
torch.zeros( (3, frame_num-control_pre_frames_count, height, width), device=self.device, dtype= self.VAE_dtype)
], dim=1).to(self.device)
image_start = image_end = img_end_frame = image_ref = control_video = None
msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device)
if any_end_frame:
@ -643,42 +523,85 @@ class WanAny2V:
if infinitetalk:
lat_y = self.vae.encode([input_video], VAE_tile_size)[0]
extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone().unsqueeze(0)
# if control_pre_frames_count != pre_frames_count:
lat_y = input_video = None
kwargs.update({ 'y': y})
if not clip_context is None:
kwargs.update({'clip_fea': clip_context})
# Recam Master
# Animate
if animate:
pose_pixels = input_frames * input_masks
input_masks = 1. - input_masks
pose_pixels -= input_masks
pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0)
input_frames = input_frames * input_masks
if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames
if prefix_frames_count > 0:
input_frames[:, :prefix_frames_count] = input_video
input_masks[:, :prefix_frames_count] = 1
# save_video(pose_pixels, "pose.mp4")
# save_video(input_frames, "input_frames.mp4")
# save_video(input_masks, "input_masks.mp4", value_range=(0,1))
lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device)
msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device)
msk = torch.concat([msk_ref, msk_control], dim=1)
image_ref = input_ref_images[0].to(self.device)
clip_image_start = image_ref.squeeze(1)
lat_y = torch.concat(self.vae.encode([image_ref, input_frames.to(self.device)], VAE_tile_size), dim=1)
y = torch.concat([msk, lat_y])
kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)})
lat_y = msk = msk_control = msk_ref = pose_pixels = None
ref_images_before = True
ref_images_count = 1
lat_frames = int((input_frames.shape[1] - 1) // self.vae_stride[0]) + 1
# Clip image
if hasattr(self, "clip") and clip_image_start is not None:
clip_image_size = self.clip.model.image_size
clip_image_start = resize_lanczos(clip_image_start, clip_image_size, clip_image_size)
clip_image_end = resize_lanczos(clip_image_end, clip_image_size, clip_image_size) if clip_image_end is not None else clip_image_start
if model_type == "flf2v_720p":
clip_context = self.clip.visual([clip_image_start[:, None, :, :], clip_image_end[:, None, :, :] if clip_image_end is not None else clip_image_start[:, None, :, :]])
else:
clip_context = self.clip.visual([clip_image_start[:, None, :, :]])
clip_image_start = clip_image_end = None
kwargs.update({'clip_fea': clip_context})
# Recam Master & Lucy Edit
if recam or lucy_edit:
frame_num, height,width = input_frames.shape[-3:]
lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1
frame_num = (lat_frames -1) * self.vae_stride[0] + 1
input_frames = input_frames[:, :frame_num].to(dtype=self.dtype , device=self.device)
extended_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device)
extended_input_dim = 2 if recam else 1
del input_frames
if recam:
# should in fact be in input_frames since it is a control video, not a video to be extended
target_camera = model_mode
height,width = input_video.shape[-2:]
input_video = input_video.to(dtype=self.dtype , device=self.device)
source_latents = self.vae.encode([input_video])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device)
del input_video
# Process target camera (recammaster)
target_camera = model_mode
from shared.utils.cammmaster_tools import get_camera_embedding
cam_emb = get_camera_embedding(target_camera)
cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
kwargs['cam_emb'] = cam_emb
# Video 2 Video
if denoising_strength < 1. and input_frames != None:
if "G" in video_prompt_type and input_frames != None:
height, width = input_frames.shape[-2:]
source_latents = self.vae.encode([input_frames])[0].unsqueeze(0)
injection_denoising_step = 0
inject_from_start = False
if input_frames != None and denoising_strength < 1 :
color_reference_frame = input_frames[:, -1:].clone()
if overlapped_latents != None:
overlapped_latents_frames_num = overlapped_latents.shape[2]
overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1
if prefix_frames_count > 0:
overlapped_frames_num = prefix_frames_count
overlapped_latents_frames_num = (overlapped_frames_num - 1) // 4 + 1
# overlapped_latents_frames_num = overlapped_latents.shape[2]
# overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1
else:
overlapped_latents_frames_num = overlapped_frames_num = 0
if len(keep_frames_parsed) == 0 or image_outputs or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = []
injection_denoising_step = int(sampling_steps * (1. - denoising_strength) )
injection_denoising_step = int( round(sampling_steps * (1. - denoising_strength),4) )
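# e.g. denoising_strength = 0.6 with 30 sampling steps gives round(30 * 0.4) = 12: the source latents are re-injected (suitably noised) up to that step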
latent_keep_frames = []
if source_latents.shape[2] < lat_frames or len(keep_frames_parsed) > 0:
inject_from_start = True
@ -694,14 +617,21 @@ class WanAny2V:
if hasattr(sample_scheduler, "sigmas"): sample_scheduler.sigmas= sample_scheduler.sigmas[injection_denoising_step:]
injection_denoising_step = 0
if input_masks is not None and not "U" in video_prompt_type:
image_mask_latents = torch.nn.functional.interpolate(input_masks, size= source_latents.shape[-2:], mode="nearest").unsqueeze(0)
if image_mask_latents.shape[2] !=1:
image_mask_latents = torch.cat([ image_mask_latents[:,:, :1], torch.nn.functional.interpolate(image_mask_latents, size= (source_latents.shape[-3]-1, *source_latents.shape[-2:]), mode="nearest") ], dim=2)
image_mask_latents = torch.where(image_mask_latents>=0.5, 1., 0. )[:1].to(self.device)
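# the pixel mask is resized to the latent grid and binarized; it is reused after each scheduler step to keep unmasked regions pinned to the (re-noised) source latents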
# save_video(image_mask_latents.squeeze(0), "mama.mp4", value_range=(0,1) )
# image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0)
# Phantom
if phantom:
input_ref_images_neg = None
if input_ref_images != None: # Phantom Ref images
input_ref_images = self.get_vae_latents(input_ref_images, self.device)
input_ref_images_neg = torch.zeros_like(input_ref_images)
ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0
trim_frames = input_ref_images.shape[1]
lat_input_ref_images_neg = None
if input_ref_images is not None: # Phantom Ref images
lat_input_ref_images = self.get_vae_latents(input_ref_images, self.device)
lat_input_ref_images_neg = torch.zeros_like(lat_input_ref_images)
ref_images_count = trim_frames = lat_input_ref_images.shape[1]
if ti2v:
if input_video is None:
@ -710,28 +640,29 @@ class WanAny2V:
height, width = input_video.shape[-2:]
source_latents = self.vae.encode([input_video], tile_size = VAE_tile_size)[0].unsqueeze(0)
timestep_injection = True
if extended_input_dim > 0:
extended_latents[:, :, :source_latents.shape[2]] = source_latents
# Vace
if vace :
# vace context encode
input_frames = [u.to(self.device) for u in input_frames]
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
input_masks = [u.to(self.device) for u in input_masks]
if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask]
input_frames = [input_frames.to(self.device)] +([] if input_frames2 is None else [input_frames2.to(self.device)])
input_masks = [input_masks.to(self.device)] + ([] if input_masks2 is None else [input_masks2.to(self.device)])
input_ref_images = None if input_ref_images is None else [ u.to(self.device) for u in input_ref_images]
input_ref_masks = None if input_ref_masks is None else [ None if u is None else u.to(self.device) for u in input_ref_masks]
ref_images_before = True
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
m0 = self.vace_encode_masks(input_masks, input_ref_images)
if self.background_mask != None:
color_reference_frame = input_ref_images[0][0].clone()
zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size )
mbg = self.vace_encode_masks(self.background_mask, None)
if input_ref_masks is not None and len(input_ref_masks) > 0 and input_ref_masks[0] is not None:
color_reference_frame = input_ref_images[0].clone()
zbg = self.vace_encode_frames( input_ref_images[:1] * len(input_frames), None, masks=input_ref_masks[0], tile_size = VAE_tile_size )
mbg = self.vace_encode_masks(input_ref_masks[:1] * len(input_frames), None)
for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg):
zz0[:, 0:1] = zzbg
mm0[:, 0:1] = mmbg
self.background_mask = zz0 = mm0 = zzbg = mmbg = None
z = self.vace_latent(z0, m0)
ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0
zz0 = mm0 = zzbg = mmbg = None
z = [torch.cat([zz, mm], dim=0) for zz, mm in zip(z0, m0)]
ref_images_count = len(input_ref_images) if input_ref_images is not None else 0
context_scale = context_scale if context_scale != None else [1.0] * len(z)
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count })
if overlapped_latents != None :
@ -739,17 +670,12 @@ class WanAny2V:
extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0)
if prefix_frames_count > 0:
color_reference_frame = input_frames[0][:, prefix_frames_count -1:prefix_frames_count].clone()
lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, lat_h, lat_w)
target_shape = list(z0[0].shape)
target_shape[0] = int(target_shape[0] / 2)
lat_h, lat_w = target_shape[-2:]
height = self.vae_stride[1] * lat_h
width = self.vae_stride[2] * lat_w
else:
target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2])
if multitalk and audio_proj != None:
if multitalk:
if audio_proj is None:
audio_proj = [ torch.zeros( (1, 1, 5, 12, 768 ), dtype=self.dtype, device=self.device), torch.zeros( (1, (frame_num - 1) // 4, 8, 12, 768 ), dtype=self.dtype, device=self.device) ]
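# when no audio embedding is provided, zero tensors with the expected shapes are substituted so the multitalk path can still run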
from .multitalk.multitalk import get_target_masks
audio_proj = [audio.to(self.dtype) for audio in audio_proj]
human_no = len(audio_proj[0])
@ -764,9 +690,9 @@ class WanAny2V:
expand_shape = [batch_size] + [-1] * len(target_shape)
# Ropes
if target_camera != None:
if extended_input_dim>=2:
shape = list(target_shape[1:])
shape[0] *= 2
shape[extended_input_dim-2] *= 2
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
else:
freqs = get_rotary_pos_embed(target_shape[1:], enable_RIFLEx= enable_RIFLEx)
@ -777,19 +703,20 @@ class WanAny2V:
if standin:
from preprocessing.face_preprocessor import FaceProcessor
standin_ref_pos = 1 if "K" in video_prompt_type else 0
if len(original_input_ref_images) < standin_ref_pos + 1: raise Exception("Missing Standin ref image")
standin_ref_pos = -1
image_ref = original_input_ref_images[standin_ref_pos]
image_ref.save("si.png")
# face_processor = FaceProcessor(antelopv2_path="ckpts/antelopev2")
face_processor = FaceProcessor()
standin_ref = face_processor.process(image_ref, remove_bg = model_type in ["vace_standin_14B"])
face_processor = None
gc.collect()
torch.cuda.empty_cache()
standin_freqs = get_nd_rotary_pos_embed((-1, int(target_shape[-2]/2), int(target_shape[-1]/2) ), (-1, int(target_shape[-2]/2 + standin_ref.height/16), int(target_shape[-1]/2 + standin_ref.width/16) ))
standin_ref = self.vae.encode([ convert_image_to_tensor(standin_ref).unsqueeze(1) ], VAE_tile_size)[0].unsqueeze(0)
kwargs.update({ "standin_freqs": standin_freqs, "standin_ref": standin_ref, })
if len(original_input_ref_images) < standin_ref_pos + 1:
if "I" in video_prompt_type and model_type in ["vace_standin_14B"]:
print("Warning: Missing Standin ref image, make sure 'Inject only People / Objets' is selected or if there is 'Landscape and then People or Objects' there are at least two ref images.")
else:
standin_ref_pos = -1
image_ref = original_input_ref_images[standin_ref_pos]
face_processor = FaceProcessor()
standin_ref = face_processor.process(image_ref, remove_bg = model_type in ["vace_standin_14B"])
face_processor = None
gc.collect()
torch.cuda.empty_cache()
standin_freqs = get_nd_rotary_pos_embed((-1, int(target_shape[-2]/2), int(target_shape[-1]/2) ), (-1, int(target_shape[-2]/2 + standin_ref.height/16), int(target_shape[-1]/2 + standin_ref.width/16) ))
standin_ref = self.vae.encode([ convert_image_to_tensor(standin_ref).unsqueeze(1) ], VAE_tile_size)[0].unsqueeze(0)
kwargs.update({ "standin_freqs": standin_freqs, "standin_ref": standin_ref, })
# Steps Skipping
@ -819,7 +746,7 @@ class WanAny2V:
denoising_extra = ""
from shared.utils.loras_mutipliers import update_loras_slists, get_model_switch_steps
phase_switch_step, phase_switch_step2, phases_description = get_model_switch_steps(timesteps, updated_num_steps, guide_phases, 0 if self.model2 is None else model_switch_phase, switch_threshold, switch2_threshold )
phase_switch_step, phase_switch_step2, phases_description = get_model_switch_steps(original_timesteps,guide_phases, 0 if self.model2 is None else model_switch_phase, switch_threshold, switch2_threshold )
if len(phases_description) > 0: set_header_text(phases_description)
guidance_switch_done = guidance_switch2_done = False
if guide_phases > 1: denoising_extra = f"Phase 1/{guide_phases} High Noise" if self.model2 is not None else f"Phase 1/{guide_phases}"
@ -830,8 +757,8 @@ class WanAny2V:
denoising_extra = f"Phase {phase_no}/{guide_phases} {'Low Noise' if trans == self.model2 else 'High Noise'}" if self.model2 is not None else f"Phase {phase_no}/{guide_phases}"
callback(step_no-1, denoising_extra = denoising_extra)
return guide_scale, guidance_switch_done, trans, denoising_extra
update_loras_slists(self.model, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2)
if self.model2 is not None: update_loras_slists(self.model2, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2)
update_loras_slists(self.model, loras_slists, len(original_timesteps), phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2)
if self.model2 is not None: update_loras_slists(self.model2, loras_slists, len(original_timesteps), phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2)
callback(-1, None, True, override_num_inference_steps = updated_num_steps, denoising_extra = denoising_extra)
def clear():
@ -844,19 +771,22 @@ class WanAny2V:
scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g}
# b, c, lat_f, lat_h, lat_w
latents = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
if "G" in video_prompt_type: randn = latents
if apg_switch != 0:
apg_momentum = -0.75
apg_norm_threshold = 55
text_momentumbuffer = MomentumBuffer(apg_momentum)
audio_momentumbuffer = MomentumBuffer(apg_momentum)
input_frames = input_frames2 = input_masks =input_masks2 = input_video = input_ref_images = input_ref_masks = pre_video_frame = None
gc.collect()
torch.cuda.empty_cache()
# denoising
trans = self.model
for i, t in enumerate(tqdm(timesteps)):
guide_scale, guidance_switch_done, trans, denoising_extra = update_guidance(i, t, guide_scale, guide2_scale, guidance_switch_done, switch_threshold, trans, 2, denoising_extra)
guide_scale, guidance_switch2_done, trans, denoising_extra = update_guidance(i, t, guide_scale, guide3_scale, guidance_switch2_done, switch2_threshold, trans, 3, denoising_extra)
offload.set_step_no_for_lora(trans, i)
offload.set_step_no_for_lora(trans, start_step_no + i)
timestep = torch.stack([t])
if timestep_injection:
@ -864,44 +794,43 @@ class WanAny2V:
timestep = torch.full((target_shape[-3],), t, dtype=torch.int64, device=latents.device)
timestep[:source_latents.shape[2]] = 0
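# per-frame timesteps: latent frames coming from the source video are pinned to t = 0 so they are treated as already denoised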
kwargs.update({"t": timestep, "current_step": start_step_no + i})
kwargs.update({"t": timestep, "current_step_no": i, "real_step_no": start_step_no + i })
kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step:
if denoising_strength < 1 and i <= injection_denoising_step:
sigma = t / 1000
noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
if inject_from_start:
new_latents = latents.clone()
new_latents[:,:, :source_latents.shape[2] ] = noise[:, :, :source_latents.shape[2] ] * sigma + (1 - sigma) * source_latents
noisy_image = latents.clone()
noisy_image[:,:, :source_latents.shape[2] ] = randn[:, :, :source_latents.shape[2] ] * sigma + (1 - sigma) * source_latents
for latent_no, keep_latent in enumerate(latent_keep_frames):
if not keep_latent:
new_latents[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1]
latents = new_latents
new_latents = None
noisy_image[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1]
latents = noisy_image
noisy_image = None
else:
latents = noise * sigma + (1 - sigma) * source_latents
noise = None
latents = randn * sigma + (1 - sigma) * source_latents
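# re-noise the source latents to the current sigma (linear blend of the fixed noise and the source) so denoising continues from the source content instead of pure noise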
if extended_overlapped_latents != None:
if no_noise_latents_injection:
latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents
else:
latent_noise_factor = t / 1000
latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor
if vace:
overlap_noise_factor = overlap_noise / 1000
for zz in z:
zz[0:16, ref_images_count:extended_overlapped_latents.shape[2] ] = extended_overlapped_latents[0, :, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[0, :, ref_images_count:] ) * overlap_noise_factor
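# the overlapped-window latents inside the Vace context are also blended with a small amount of noise (overlap_noise / 1000), presumably to smooth transitions between sliding windows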
if target_camera != None:
latent_model_input = torch.cat([latents, source_latents.expand(*expand_shape)], dim=2)
if extended_input_dim > 0:
latent_model_input = torch.cat([latents, extended_latents.expand(*expand_shape)], dim=extended_input_dim)
else:
latent_model_input = latents
any_guidance = guide_scale != 1
if phantom:
gen_args = {
"x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 +
[ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]),
"x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 +
[ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]),
"context": [context, context_null, context_null] ,
}
elif fantasy:
@ -1010,8 +939,8 @@ class WanAny2V:
if sample_solver == "euler":
dt = timesteps[i] if i == len(timesteps)-1 else (timesteps[i] - timesteps[i + 1])
dt = dt / self.num_timesteps
latents = latents - noise_pred * dt[:, None, None, None, None]
dt = dt.item() / self.num_timesteps
latents = latents - noise_pred * dt
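# plain Euler update of the flow ODE: x <- x - v * dt, with dt taken from the gap between consecutive timesteps and normalized by num_timesteps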
else:
latents = sample_scheduler.step(
noise_pred[:, :, :target_shape[1]],
@ -1019,9 +948,16 @@ class WanAny2V:
latents,
**scheduler_kwargs)[0]
if image_mask_latents is not None:
sigma = 0 if i == len(timesteps)-1 else timesteps[i+1]/1000
noisy_image = randn * sigma + (1 - sigma) * source_latents
latents = noisy_image * (1-image_mask_latents) + image_mask_latents * latents
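# inpainting-style recombination: outside the mask the latents are reset to the source re-noised at the next step's sigma, inside the mask the denoised latents are kept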
if callback is not None:
latents_preview = latents
if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
if ref_images_before and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames]
if image_outputs: latents_preview= latents_preview[:, :,:1]
if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2)
@ -1032,7 +968,7 @@ class WanAny2V:
if timestep_injection:
latents[:, :, :source_latents.shape[2]] = source_latents
if vace and ref_images_count > 0: latents = latents[:, :, ref_images_count:]
if ref_images_before and ref_images_count > 0: latents = latents[:, :, ref_images_count:]
if trim_frames > 0: latents= latents[:, :,:-trim_frames]
if return_latent_slice != None:
latent_slice = latents[:, :, return_latent_slice].clone()
@ -1069,4 +1005,20 @@ class WanAny2V:
delattr(model, "vace_blocks")
def adapt_animate_model(self, model):
modules_dict= { k: m for k, m in model.named_modules()}
for animate_layer in range(8):
module = modules_dict[f"face_adapter.fuser_blocks.{animate_layer}"]
model_layer = animate_layer * 5
target = modules_dict[f"blocks.{model_layer}"]
setattr(target, "face_adapter_fuser_blocks", module )
delattr(model, "face_adapter")
def get_loras_transformer(self, get_model_recursive_prop, base_model_type, model_type, video_prompt_type, model_mode, **kwargs):
if base_model_type == "animate":
if "1" in video_prompt_type:
preloadURLs = get_model_recursive_prop(model_type, "preload_URLs")
if len(preloadURLs) > 0:
return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))] , [1]
return [], []

View File

@ -21,11 +21,24 @@ class family_handler():
extra_model_def["fps"] =fps
extra_model_def["frames_minimum"] = 17
extra_model_def["frames_steps"] = 20
extra_model_def["latent_size"] = 4
extra_model_def["sliding_window"] = True
extra_model_def["skip_layer_guidance"] = True
extra_model_def["tea_cache"] = True
extra_model_def["guidance_max_phases"] = 1
extra_model_def["model_modes"] = {
"choices": [
("Synchronous", 0),
("Asynchronous (better quality but around 50% extra steps added)", 5),
],
"default": 0,
"label" : "Generation Type"
}
extra_model_def["image_prompt_types_allowed"] = "TSV"
return extra_model_def
@staticmethod
@ -54,7 +67,11 @@ class family_handler():
def query_family_infos():
return {}
@staticmethod
def get_rgb_factors(base_model_type ):
from shared.RGB_factors import get_rgb_factors
latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type)
return latent_rgb_factors, latent_rgb_factors_bias
@staticmethod
def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization):
@ -62,7 +79,7 @@ class family_handler():
return family_handler.query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization)
@staticmethod
def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False):
def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False, submodel_no_list = None):
from .configs import WAN_CONFIGS
from .wan_handler import family_handler
cfg = WAN_CONFIGS['t2v-14B']

View File

@ -1,479 +0,0 @@
import math
import os
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import logging
import numpy as np
import torch
from diffusers.image_processor import PipelineImageInput
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from tqdm import tqdm
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from wan.modules.posemb_layers import get_rotary_pos_embed
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
class DTT2V:
def __init__(
self,
config,
checkpoint_dir,
rank=0,
model_filename = None,
text_encoder_filename = None,
quantizeTransformer = False,
dtype = torch.bfloat16,
):
self.device = torch.device(f"cuda")
self.config = config
self.rank = rank
self.dtype = dtype
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=text_encoder_filename,
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn= None)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
logging.info(f"Creating WanModel from {model_filename}")
from mmgp import offload
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False, forcedConfigPath="config.json")
# offload.load_model_data(self.model, "recam.ckpt")
# self.model.cpu()
# offload.save_model(self.model, "recam.safetensors")
if self.dtype == torch.float16 and not "fp16" in model_filename:
self.model.to(self.dtype)
# offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True)
if self.dtype == torch.float16:
self.vae.model.to(self.dtype)
self.model.eval().requires_grad_(False)
self.scheduler = FlowUniPCMultistepScheduler()
@property
def do_classifier_free_guidance(self) -> bool:
return self._guidance_scale > 1
def encode_image(
self, image: PipelineImageInput, height: int, width: int, num_frames: int, tile_size = 0, causal_block_size = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# prefix_video
prefix_video = np.array(image.resize((width, height))).transpose(2, 0, 1)
prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1)
if prefix_video.dtype == torch.uint8:
prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
prefix_video = prefix_video.to(self.device)
prefix_video = [self.vae.encode(prefix_video.unsqueeze(0), tile_size = tile_size)[0]] # [(c, f, h, w)]
if prefix_video[0].shape[1] % causal_block_size != 0:
truncate_len = prefix_video[0].shape[1] % causal_block_size
print("the length of prefix video is truncated for the casual block size alignment.")
prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
predix_video_latent_length = prefix_video[0].shape[1]
return prefix_video, predix_video_latent_length
def prepare_latents(
self,
shape: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
) -> torch.Tensor:
return randn_tensor(shape, generator, device=device, dtype=dtype)
def generate_timestep_matrix(
self,
num_frames,
step_template,
base_num_frames,
ar_step=5,
num_pre_ready=0,
casual_block_size=1,
shrink_interval_with_mask=False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
step_matrix, step_index = [], []
update_mask, valid_interval = [], []
num_iterations = len(step_template) + 1
num_frames_block = num_frames // casual_block_size
base_num_frames_block = base_num_frames // casual_block_size
if base_num_frames_block < num_frames_block:
infer_step_num = len(step_template)
gen_block = base_num_frames_block
min_ar_step = infer_step_num / gen_block
assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
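# e.g. 30 inference steps spread over 20 latent blocks requires ar_step >= ceil(30 / 20) = 2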
# print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block)
step_template = torch.cat(
[
torch.tensor([999], dtype=torch.int64, device=step_template.device),
step_template.long(),
torch.tensor([0], dtype=torch.int64, device=step_template.device),
]
) # pad with sentinel entries so rows can be indexed from 1 (index 0 maps to t=999, the last index to t=0)
pre_row = torch.zeros(num_frames_block, dtype=torch.long)
if num_pre_ready > 0:
pre_row[: num_pre_ready // casual_block_size] = num_iterations
while torch.all(pre_row >= (num_iterations - 1)) == False:
new_row = torch.zeros(num_frames_block, dtype=torch.long)
for i in range(num_frames_block):
if i == 0 or pre_row[i - 1] >= (
num_iterations - 1
): # the first block, or one whose previous block is completely denoised
new_row[i] = pre_row[i] + 1
else:
new_row[i] = new_row[i - 1] - ar_step
new_row = new_row.clamp(0, num_iterations)
update_mask.append(
(new_row != pre_row) & (new_row != num_iterations)
) # False: no need to update; True: needs update
step_index.append(new_row)
step_matrix.append(step_template[new_row])
pre_row = new_row
# for long video we split into several sequences, base_num_frames is set to the model max length (for training)
terminal_flag = base_num_frames_block
if shrink_interval_with_mask:
idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
update_mask = update_mask[0]
update_mask_idx = idx_sequence[update_mask]
last_update_idx = update_mask_idx[-1].item()
terminal_flag = last_update_idx + 1
# for i in range(0, len(update_mask)):
for curr_mask in update_mask:
if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
terminal_flag += 1
valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))
step_update_mask = torch.stack(update_mask, dim=0)
step_index = torch.stack(step_index, dim=0)
step_matrix = torch.stack(step_matrix, dim=0)
if casual_block_size > 1:
step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]
return step_matrix, step_index, step_update_mask, valid_interval
@torch.no_grad()
def generate(
self,
prompt: Union[str, List[str]],
negative_prompt: Union[str, List[str]] = "",
image: PipelineImageInput = None,
height: int = 480,
width: int = 832,
num_frames: int = 97,
num_inference_steps: int = 50,
shift: float = 1.0,
guidance_scale: float = 5.0,
seed: float = 0.0,
overlap_history: int = 17,
addnoise_condition: int = 0,
base_num_frames: int = 97,
ar_step: int = 5,
causal_block_size: int = 1,
causal_attention: bool = False,
fps: int = 24,
VAE_tile_size = 0,
joint_pass = False,
callback = None,
):
generator = torch.Generator(device=self.device)
generator.manual_seed(seed)
# if base_num_frames > base_num_frames:
# causal_block_size = 0
self._guidance_scale = guidance_scale
i2v_extra_kwrags = {}
prefix_video = None
predix_video_latent_length = 0
if image:
frame_width, frame_height = image.size
scale = min(height / frame_height, width / frame_width)
height = (int(frame_height * scale) // 16) * 16
width = (int(frame_width * scale) // 16) * 16
prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames, tile_size=VAE_tile_size, causal_block_size=causal_block_size)
latent_length = (num_frames - 1) // 4 + 1
latent_height = height // 8
latent_width = width // 8
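# e.g. 97 frames -> (97 - 1) // 4 + 1 = 25 latent frames; height and width are divided by the VAE spatial stride of 8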
prompt_embeds = self.text_encoder([prompt], self.device)
prompt_embeds = [u.to(self.dtype).to(self.device) for u in prompt_embeds]
if self.do_classifier_free_guidance:
negative_prompt_embeds = self.text_encoder([negative_prompt], self.device)
negative_prompt_embeds = [u.to(self.dtype).to(self.device) for u in negative_prompt_embeds]
self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
init_timesteps = self.scheduler.timesteps
fps_embeds = [fps] * prompt_embeds[0].shape[0]
fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
transformer_dtype = self.dtype
# with torch.cuda.amp.autocast(dtype=self.dtype), torch.no_grad():
if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames:
# short video generation
latent_shape = [16, latent_length, latent_height, latent_width]
latents = self.prepare_latents(
latent_shape, dtype=torch.float32, device=self.device, generator=generator
)
latents = [latents]
if prefix_video is not None:
latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32)
base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
latent_length, init_timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size
)
sample_schedulers = []
for _ in range(latent_length):
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
)
sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
sample_schedulers.append(sample_scheduler)
sample_schedulers_counter = [0] * latent_length
if callback != None:
callback(-1, None, True)
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False)
for i, timestep_i in enumerate(tqdm(step_matrix)):
update_mask_i = step_update_mask[i]
valid_interval_i = valid_interval[i]
valid_interval_start, valid_interval_end = valid_interval_i
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
noise_factor = 0.001 * addnoise_condition
timestep_for_noised_condition = addnoise_condition
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor)
+ torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length])
* noise_factor
)
timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
kwrags = {
"x" : torch.stack([latent_model_input[0]]),
"t" : timestep,
"freqs" :freqs,
"fps" : fps_embeds,
# "causal_block_size" : causal_block_size,
"callback" : callback,
"pipeline" : self
}
kwrags.update(i2v_extra_kwrags)
if not self.do_classifier_free_guidance:
noise_pred = self.model(
context=prompt_embeds,
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred= noise_pred.to(torch.float32)
else:
if joint_pass:
noise_pred_cond, noise_pred_uncond = self.model(
context=prompt_embeds,
context2=negative_prompt_embeds,
**kwrags,
)
if self._interrupt:
return None
else:
noise_pred_cond = self.model(
context=prompt_embeds,
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred_uncond = self.model(
context=negative_prompt_embeds,
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred_cond= noise_pred_cond.to(torch.float32)
noise_pred_uncond= noise_pred_uncond.to(torch.float32)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
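# classifier-free guidance: uncond + scale * (cond - uncond)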
del noise_pred_cond, noise_pred_uncond
for idx in range(valid_interval_start, valid_interval_end):
if update_mask_i[idx].item():
latents[0][:, idx] = sample_schedulers[idx].step(
noise_pred[:, idx - valid_interval_start],
timestep_i[idx],
latents[0][:, idx],
return_dict=False,
generator=generator,
)[0]
sample_schedulers_counter[idx] += 1
if callback is not None:
callback(i, latents[0], False)
x0 = latents[0].unsqueeze(0)
videos = self.vae.decode(x0, tile_size= VAE_tile_size)
videos = (videos / 2 + 0.5).clamp(0, 1)
videos = [video for video in videos]
videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
return videos
else:
# long video generation
base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
overlap_history_frames = (overlap_history - 1) // 4 + 1
n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1
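# sliding-window count, e.g. latent_length = 60, base_num_frames = 25, overlap_history_frames = 5 -> n_iter = 1 + (60 - 25 - 1) // (25 - 5) + 1 = 3 windows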
print(f"n_iter:{n_iter}")
output_video = None
for i in range(n_iter):
if output_video is not None: # i !=0
prefix_video = output_video[:, -overlap_history:].to(self.device)
prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
if prefix_video[0].shape[1] % causal_block_size != 0:
truncate_len = prefix_video[0].shape[1] % causal_block_size
print("the length of prefix video is truncated for the casual block size alignment.")
prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
predix_video_latent_length = prefix_video[0].shape[1]
finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
left_frame_num = latent_length - finished_frame_num
base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
else: # i == 0
base_num_frames_iter = base_num_frames
latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
latents = self.prepare_latents(
latent_shape, dtype=torch.float32, device=self.device, generator=generator
)
latents = [latents]
if prefix_video is not None:
latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32)
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
base_num_frames_iter,
init_timesteps,
base_num_frames_iter,
ar_step,
predix_video_latent_length,
causal_block_size,
)
sample_schedulers = []
for _ in range(base_num_frames_iter):
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
)
sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
sample_schedulers.append(sample_scheduler)
sample_schedulers_counter = [0] * base_num_frames_iter
if callback != None:
callback(-1, None, True)
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False)
for i, timestep_i in enumerate(tqdm(step_matrix)):
update_mask_i = step_update_mask[i]
valid_interval_i = valid_interval[i]
valid_interval_start, valid_interval_end = valid_interval_i
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
noise_factor = 0.001 * addnoise_condition
timestep_for_noised_condition = addnoise_condition
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
* (1.0 - noise_factor)
+ torch.randn_like(
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
)
* noise_factor
)
timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
kwrags = {
"x" : torch.stack([latent_model_input[0]]),
"t" : timestep,
"freqs" :freqs,
"fps" : fps_embeds,
"causal_block_size" : causal_block_size,
"causal_attention" : causal_attention,
"callback" : callback,
"pipeline" : self
}
kwrags.update(i2v_extra_kwrags)
if not self.do_classifier_free_guidance:
noise_pred = self.model(
context=prompt_embeds,
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred= noise_pred.to(torch.float32)
else:
if joint_pass:
noise_pred_cond, noise_pred_uncond = self.model(
context=prompt_embeds,
context2=negative_prompt_embeds,
**kwrags,
)
if self._interrupt:
return None
else:
noise_pred_cond = self.model(
context=prompt_embeds,
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred_uncond = self.model(
context=negative_prompt_embeds,
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred_cond= noise_pred_cond.to(torch.float32)
noise_pred_uncond= noise_pred_uncond.to(torch.float32)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
del noise_pred_cond, noise_pred_uncond
for idx in range(valid_interval_start, valid_interval_end):
if update_mask_i[idx].item():
latents[0][:, idx] = sample_schedulers[idx].step(
noise_pred[:, idx - valid_interval_start],
timestep_i[idx],
latents[0][:, idx],
return_dict=False,
generator=generator,
)[0]
sample_schedulers_counter[idx] += 1
if callback is not None:
callback(i, latents[0].squeeze(0), False)
x0 = latents[0].unsqueeze(0)
videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]]
if output_video is None:
output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
else:
output_video = torch.cat(
[output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1
) # c, f, h, w
return output_video

View File

@ -12,29 +12,17 @@ from diffusers.models.modeling_utils import ModelMixin
import numpy as np
from typing import Union,Optional
from mmgp import offload
from mmgp.offload import get_cache, clear_caches
from shared.attention import pay_attention
from torch.backends.cuda import sdp_kernel
from ..multitalk.multitalk_utils import get_attn_map_with_target
from ..animate.motion_encoder import Generator
from ..animate.face_blocks import FaceAdapter, FaceEncoder
from ..animate.model_animate import after_patch_embedding
__all__ = ['WanModel']
def get_cache(cache_name):
all_cache = offload.shared_state.get("_cache", None)
if all_cache is None:
all_cache = {}
offload.shared_state["_cache"]= all_cache
cache = offload.shared_state.get(cache_name, None)
if cache is None:
cache = {}
offload.shared_state[cache_name] = cache
return cache
def clear_caches():
all_cache = offload.shared_state.get("_cache", None)
if all_cache is not None:
all_cache.clear()
def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
@ -514,6 +502,7 @@ class WanAttentionBlock(nn.Module):
multitalk_masks=None,
ref_images_count=0,
standin_phase=-1,
motion_vec = None,
):
r"""
Args:
@ -579,19 +568,23 @@ class WanAttentionBlock(nn.Module):
y = self.norm_x(x)
y = y.to(attention_dtype)
if ref_images_count == 0:
x += self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map)
ylist= [y]
del y
x += self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map)
else:
y_shape = y.shape
y = y.reshape(y_shape[0], grid_sizes[0], -1)
y = y[:, ref_images_count:]
y = y.reshape(y_shape[0], -1, y_shape[-1])
grid_sizes_alt = [grid_sizes[0]-ref_images_count, *grid_sizes[1:]]
y = self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map)
ylist= [y]
y = None
y = self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map)
y = y.reshape(y_shape[0], grid_sizes[0]-ref_images_count, -1)
x = x.reshape(y_shape[0], grid_sizes[0], -1)
x[:, ref_images_count:] += y
x = x.reshape(y_shape[0], -1, y_shape[-1])
del y
del y
y = self.norm2(x)
@ -627,6 +620,10 @@ class WanAttentionBlock(nn.Module):
x.add_(hint)
else:
x.add_(hint, alpha= scale)
if motion_vec is not None and self.block_no % 5 == 0:
x += self.face_adapter_fuser_blocks(x.to(self.face_adapter_fuser_blocks.linear1_kv.weight.dtype), motion_vec, None, False)
return x
class AudioProjModel(ModelMixin, ConfigMixin):
@ -909,6 +906,7 @@ class WanModel(ModelMixin, ConfigMixin):
norm_input_visual=True,
norm_output_audio=True,
standin= False,
motion_encoder_dim=0,
):
super().__init__()
@ -933,14 +931,15 @@ class WanModel(ModelMixin, ConfigMixin):
self.flag_causal_attention = False
self.block_mask = None
self.inject_sample_info = inject_sample_info
self.motion_encoder_dim = motion_encoder_dim
self.norm_output_audio = norm_output_audio
self.audio_window = audio_window
self.intermediate_dim = intermediate_dim
self.vae_scale = vae_scale
multitalk = multitalk_output_dim > 0
self.multitalk = multitalk
self.multitalk = multitalk
animate = motion_encoder_dim > 0
# embeddings
self.patch_embedding = nn.Conv3d(
@ -1038,6 +1037,25 @@ class WanModel(ModelMixin, ConfigMixin):
block.self_attn.k_loras = LoRALinearLayer(dim, dim, rank=128)
block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128)
if animate:
self.pose_patch_embedding = nn.Conv3d(
16, dim, kernel_size=patch_size, stride=patch_size
)
self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
self.face_adapter = FaceAdapter(
heads_num=self.num_heads,
hidden_dim=self.dim,
num_adapter_layers=self.num_layers // 5,
)
self.face_encoder = FaceEncoder(
in_dim=motion_encoder_dim,
hidden_dim=self.dim,
num_heads=4,
)
def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32):
layer_list = [self.head, self.head.head, self.patch_embedding]
target_dype= dtype
@ -1202,7 +1220,8 @@ class WanModel(ModelMixin, ConfigMixin):
y=None,
freqs = None,
pipeline = None,
current_step = 0,
current_step_no = 0,
real_step_no = 0,
x_id= 0,
max_steps = 0,
slg_layers=None,
@ -1219,6 +1238,9 @@ class WanModel(ModelMixin, ConfigMixin):
ref_images_count = 0,
standin_freqs = None,
standin_ref = None,
pose_latents=None,
face_pixel_values=None,
):
# patch_dtype = self.patch_embedding.weight.dtype
modulation_dtype = self.time_projection[1].weight.dtype
@ -1251,9 +1273,18 @@ class WanModel(ModelMixin, ConfigMixin):
if bz > 1: y = y.expand(bz, -1, -1, -1, -1)
x = torch.cat([x, y], dim=1)
# embeddings
# x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
x = self.patch_embedding(x).to(modulation_dtype)
grid_sizes = x.shape[2:]
x_list[i] = x
y = None
motion_vec_list = []
for i, x in enumerate(x_list):
# animate embeddings
motion_vec = None
if pose_latents is not None:
x, motion_vec = after_patch_embedding(self, x, pose_latents, face_pixel_values)
motion_vec_list.append(motion_vec)
if chipmunk:
x = x.unsqueeze(-1)
x_og_shape = x.shape
@ -1261,7 +1292,7 @@ class WanModel(ModelMixin, ConfigMixin):
else:
x = x.flatten(2).transpose(1, 2)
x_list[i] = x
x, y = None, None
x = None
block_mask = None
@ -1280,9 +1311,9 @@ class WanModel(ModelMixin, ConfigMixin):
del causal_mask
offload.shared_state["embed_sizes"] = grid_sizes
offload.shared_state["step_no"] = current_step
offload.shared_state["step_no"] = real_step_no
offload.shared_state["max_steps"] = max_steps
if current_step == 0 and x_id == 0: clear_caches()
if current_step_no == 0 and x_id == 0: clear_caches()
# arguments
kwargs = dict(
@ -1306,7 +1337,7 @@ class WanModel(ModelMixin, ConfigMixin):
if standin_ref is not None:
standin_cache_enabled = False
kwargs["standin_phase"] = 2
if (current_step == 0 or not standin_cache_enabled) and x_id == 0:
if current_step_no == 0 or not standin_cache_enabled :
standin_x = self.patch_embedding(standin_ref).to(modulation_dtype).flatten(2).transpose(1, 2)
standin_e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, torch.zeros_like(t)).to(modulation_dtype) )
standin_e0 = self.time_projection(standin_e).unflatten(1, (6, self.dim)).to(e.dtype)
@ -1371,7 +1402,7 @@ class WanModel(ModelMixin, ConfigMixin):
skips_steps_cache = self.cache
if skips_steps_cache != None:
if skips_steps_cache.cache_type == "mag":
if current_step <= skips_steps_cache.start_step:
if real_step_no <= skips_steps_cache.start_step:
should_calc = True
elif skips_steps_cache.one_for_all and x_id != 0: # not joint pass, not main pas, one for all
assert len(x_list) == 1
@ -1380,7 +1411,7 @@ class WanModel(ModelMixin, ConfigMixin):
x_should_calc = []
for i in range(1 if skips_steps_cache.one_for_all else len(x_list)):
cur_x_id = i if joint_pass else x_id
cur_mag_ratio = skips_steps_cache.mag_ratios[current_step * 2 + cur_x_id] # conditional and unconditional in one list
cur_mag_ratio = skips_steps_cache.mag_ratios[real_step_no * 2 + cur_x_id] # conditional and unconditional in one list
skips_steps_cache.accumulated_ratio[cur_x_id] *= cur_mag_ratio # magnitude ratio between current step and the cached step
skips_steps_cache.accumulated_steps[cur_x_id] += 1 # skip steps plus 1
cur_skip_err = np.abs(1-skips_steps_cache.accumulated_ratio[cur_x_id]) # skip error of current steps
@ -1400,7 +1431,7 @@ class WanModel(ModelMixin, ConfigMixin):
if x_id != 0:
should_calc = skips_steps_cache.should_calc
else:
if current_step <= skips_steps_cache.start_step or current_step == skips_steps_cache.num_steps-1:
if real_step_no <= skips_steps_cache.start_step or real_step_no == skips_steps_cache.num_steps-1:
should_calc = True
skips_steps_cache.accumulated_rel_l1_distance = 0
else:
@ -1453,7 +1484,7 @@ class WanModel(ModelMixin, ConfigMixin):
return [None] * len(x_list)
if standin_x is not None:
if not standin_cache_enabled and x_id ==0 : get_cache("standin").clear()
if not standin_cache_enabled: get_cache("standin").clear()
standin_x = block(standin_x, context = None, grid_sizes = None, e= standin_e0, freqs = standin_freqs, standin_phase = 1)
if slg_layers is not None and block_idx in slg_layers:
@ -1461,9 +1492,9 @@ class WanModel(ModelMixin, ConfigMixin):
continue
x_list[0] = block(x_list[0], context = context_list[0], audio_scale= audio_scale_list[0], e= e0, **kwargs)
else:
for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc)):
for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list)):
if should_calc:
x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, **kwargs)
x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec,**kwargs)
del x
context = hints = audio_embedding = None

View File

@ -221,13 +221,16 @@ class SingleStreamAttention(nn.Module):
self.add_q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.add_k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor:
def forward(self, xlist: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor:
N_t, N_h, N_w = shape
x = xlist[0]
xlist.clear()
x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
# get q for hidden_state
B, N, C = x.shape
q = self.q_linear(x)
del x
q_shape = (B, N, self.num_heads, self.head_dim)
q = q.view(q_shape).permute((0, 2, 1, 3))
@ -247,9 +250,6 @@ class SingleStreamAttention(nn.Module):
q = rearrange(q, "B H M K -> B M H K")
encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
attn_bias = None
# x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,)
qkv_list = [q, encoder_k, encoder_v]
q = encoder_k = encoder_v = None
x = pay_attention(qkv_list)
@ -302,7 +302,7 @@ class SingleStreamMutiAttention(SingleStreamAttention):
self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)
def forward(self,
x: torch.Tensor,
xlist: torch.Tensor,
encoder_hidden_states: torch.Tensor,
shape=None,
x_ref_attn_map=None,
@ -310,14 +310,17 @@ class SingleStreamMutiAttention(SingleStreamAttention):
encoder_hidden_states = encoder_hidden_states.squeeze(0)
if x_ref_attn_map == None:
return super().forward(x, encoder_hidden_states, shape)
return super().forward(xlist, encoder_hidden_states, shape)
N_t, _, _ = shape
x = xlist[0]
xlist.clear()
x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
# get q for hidden_state
B, N, C = x.shape
q = self.q_linear(x)
del x
q_shape = (B, N, self.num_heads, self.head_dim)
q = q.view(q_shape).permute((0, 2, 1, 3))
@ -339,7 +342,9 @@ class SingleStreamMutiAttention(SingleStreamAttention):
normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N
q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
q = self.rope_1d(q, normalized_pos)
qlist = [q]
del q
q = self.rope_1d(qlist, normalized_pos, "q")
q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
_, N_a, _ = encoder_hidden_states.shape
@ -347,7 +352,7 @@ class SingleStreamMutiAttention(SingleStreamAttention):
encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
encoder_k, encoder_v = encoder_kv.unbind(0)
del encoder_kv
if self.qk_norm:
encoder_k = self.add_k_norm(encoder_k)
@ -356,13 +361,14 @@ class SingleStreamMutiAttention(SingleStreamAttention):
per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2
encoder_pos = torch.concat([per_frame]*N_t, dim=0)
encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
encoder_k = self.rope_1d(encoder_k, encoder_pos)
enclist = [encoder_k]
del encoder_k
encoder_k = self.rope_1d(enclist, encoder_pos, "encoder_k")
encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
q = rearrange(q, "B H M K -> B M H K")
encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
# x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=None, op=None,)
qkv_list = [q, encoder_k, encoder_v]
q = encoder_k = encoder_v = None
x = pay_attention(qkv_list)

View File

@ -59,7 +59,30 @@ def get_embedding(speech_array, wav2vec_feature_extractor, audio_encoder, sr=160
audio_emb = audio_emb.cpu().detach()
return audio_emb
def extract_audio_from_video(filename, sample_rate):
raw_audio_path = filename.split('/')[-1].split('.')[0]+'.wav'
ffmpeg_command = [
"ffmpeg",
"-y",
"-i",
str(filename),
"-vn",
"-acodec",
"pcm_s16le",
"-ar",
"16000",
"-ac",
"2",
str(raw_audio_path),
]
subprocess.run(ffmpeg_command, check=True)
human_speech_array, sr = librosa.load(raw_audio_path, sr=sample_rate)
human_speech_array = loudness_norm(human_speech_array, sr)
os.remove(raw_audio_path)
return human_speech_array
def audio_prepare_single(audio_path, sample_rate=16000, duration = 0):
ext = os.path.splitext(audio_path)[1].lower()
if ext in ['.mp4', '.mov', '.avi', '.mkv']:
@ -191,18 +214,20 @@ def process_tts_multi(text, save_dir, voice1, voice2):
return s1, s2, save_path_sum
def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000, padded_frames_for_embeddings = 0, min_audio_duration = 0):
def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000, padded_frames_for_embeddings = 0, min_audio_duration = 0, return_sum_only = False):
wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/chinese-wav2vec2-base")
# wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/wav2vec")
pad = int(padded_frames_for_embeddings/ fps * sr)
new_human_speech1, new_human_speech2, sum_human_speechs, duration_changed = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps, pad = pad, min_audio_duration = min_audio_duration )
audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
full_audio_embs = []
if audio_guide1 != None: full_audio_embs.append(audio_embedding_1)
# if audio_guide1 != None: full_audio_embs.append(audio_embedding_1)
if audio_guide2 != None: full_audio_embs.append(audio_embedding_2)
if audio_guide2 == None and not duration_changed: sum_human_speechs = None
if return_sum_only:
full_audio_embs = None
else:
audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
full_audio_embs = []
if audio_guide1 != None: full_audio_embs.append(audio_embedding_1)
if audio_guide2 != None: full_audio_embs.append(audio_embedding_2)
if audio_guide2 == None and not duration_changed: sum_human_speechs = None
return full_audio_embs, sum_human_speechs

View File

@ -16,7 +16,7 @@ import torchvision
import binascii
import os.path as osp
from skimage import color
from mmgp.offload import get_cache, clear_caches
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
ASPECT_RATIO_627 = {
@ -73,42 +73,70 @@ def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
# @torch.compile
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count, mode='mean', attn_bias=None):
ref_k = ref_k.to(visual_q.dtype).to(visual_q.device)
def calculate_x_ref_attn_map_per_head(visual_q, ref_k, ref_target_masks, ref_images_count, attn_bias=None):
dtype = visual_q.dtype
ref_k = ref_k.to(dtype).to(visual_q.device)
scale = 1.0 / visual_q.shape[-1] ** 0.5
visual_q = visual_q * scale
visual_q = visual_q.transpose(1, 2)
ref_k = ref_k.transpose(1, 2)
visual_q_shape = visual_q.shape
visual_q = visual_q.view(-1, visual_q_shape[-1] )
number_chunks = visual_q_shape[-2]*ref_k.shape[-2] / 53090100 * 2
chunk_size = int(visual_q_shape[-2] / number_chunks)
chunks =torch.split(visual_q, chunk_size)
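# the query rows are split into chunks so each q_chunk @ ref_k^T attention matrix holds roughly 53090100 / 2 elements, presumably to bound peak memory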
maps_lists = [ [] for _ in ref_target_masks]
for q_chunk in chunks:
attn = q_chunk @ ref_k.transpose(-2, -1)
x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens
del attn
ref_target_masks = ref_target_masks.to(dtype)
x_ref_attn_map_source = x_ref_attn_map_source.to(dtype)
for class_idx, ref_target_mask in enumerate(ref_target_masks):
ref_target_mask = ref_target_mask[None, None, None, ...]
x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
maps_lists[class_idx].append(x_ref_attnmap)
del x_ref_attn_map_source
x_ref_attn_maps = []
for class_idx, maps_list in enumerate(maps_lists):
attn_map_fuse = torch.concat(maps_list, dim= -1)
attn_map_fuse = attn_map_fuse.view(1, visual_q_shape[1], -1).squeeze(1)
x_ref_attn_maps.append( attn_map_fuse )
return torch.concat(x_ref_attn_maps, dim=0)
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count):
dtype = visual_q.dtype
ref_k = ref_k.to(dtype).to(visual_q.device)
scale = 1.0 / visual_q.shape[-1] ** 0.5
visual_q = visual_q * scale
visual_q = visual_q.transpose(1, 2)
ref_k = ref_k.transpose(1, 2)
attn = visual_q @ ref_k.transpose(-2, -1)
if attn_bias is not None: attn += attn_bias
x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens
del attn
x_ref_attn_maps = []
ref_target_masks = ref_target_masks.to(visual_q.dtype)
x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype)
ref_target_masks = ref_target_masks.to(dtype)
x_ref_attn_map_source = x_ref_attn_map_source.to(dtype)
for class_idx, ref_target_mask in enumerate(ref_target_masks):
ref_target_mask = ref_target_mask[None, None, None, ...]
x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1) # B, x_seqlens, H
if mode == 'mean':
x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens
elif mode == 'max':
x_ref_attnmap = x_ref_attnmap.max(-1) # B, x_seqlens
x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens (mean of heads)
x_ref_attn_maps.append(x_ref_attnmap)
del attn
del x_ref_attn_map_source
return torch.concat(x_ref_attn_maps, dim=0)
def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count = 0):
"""Args:
query (torch.tensor): B M H K
@ -120,6 +148,11 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli
N_t, N_h, N_w = shape
x_seqlens = N_h * N_w
if x_seqlens <= 1508:
split_num = 10 # 540p
else:
split_num = 20 if x_seqlens <= 3600 else 40 # 720p / 1080p
ref_k = ref_k[:, :x_seqlens]
if ref_images_count > 0 :
visual_q_shape = visual_q.shape
@ -133,9 +166,14 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli
split_chunk = heads // split_num
for i in range(split_num):
x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count)
x_ref_attn_maps += x_ref_attn_maps_perhead
if split_chunk == 1:
for i in range(split_num):
x_ref_attn_maps_perhead = calculate_x_ref_attn_map_per_head(visual_q[:, :, i:(i+1), :], ref_k[:, :, i:(i+1), :], ref_target_masks, ref_images_count)
x_ref_attn_maps += x_ref_attn_maps_perhead
else:
for i in range(split_num):
x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count)
x_ref_attn_maps += x_ref_attn_maps_perhead
x_ref_attn_maps /= split_num
return x_ref_attn_maps
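For orientation, the map computed above is just a masked, head-averaged cross-attention score between visual tokens and reference tokens. The sketch below restates the core of calculate_x_ref_attn_map on toy tensors; all names and shapes are illustrative assumptions, not part of this diff.
# Illustrative sketch (not part of the diff): per-class reference attention,
# averaged over reference tokens and heads.
import torch
B, H, Lq, Lr, K = 1, 2, 6, 4, 16                    # toy sizes
visual_q = torch.randn(B, H, Lq, K)                 # queries already in B, H, Lq, K layout
ref_k = torch.randn(B, H, Lr, K)
ref_target_masks = torch.tensor([[1., 1., 0., 0.],  # class 0 owns ref tokens 0-1
                                 [0., 0., 1., 1.]]) # class 1 owns ref tokens 2-3
scale = K ** -0.5
probs = ((visual_q * scale) @ ref_k.transpose(-2, -1)).softmax(-1)  # B, H, Lq, Lr
maps = []
for mask in ref_target_masks:
    m = mask[None, None, None, :]
    per_class = (probs * m).sum(-1) / m.sum()       # B, H, Lq
    maps.append(per_class.mean(1))                  # average over heads -> B, Lq
x_ref_attn_maps = torch.cat(maps, dim=0)            # num_classes, Lq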
@ -158,7 +196,6 @@ class RotaryPositionalEmbedding1D(nn.Module):
self.base = 10000
@lru_cache(maxsize=32)
def precompute_freqs_cis_1d(self, pos_indices):
freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
@ -167,7 +204,7 @@ class RotaryPositionalEmbedding1D(nn.Module):
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
return freqs
def forward(self, qlist, pos_indices, cache_entry = None):
"""1D RoPE.
Args:
@ -176,16 +213,26 @@ class RotaryPositionalEmbedding1D(nn.Module):
Returns:
query with the same shape as input.
"""
xq= qlist[0]
qlist.clear()
cache = get_cache("multitalk_rope")
freqs_cis= cache.get(cache_entry, None)
if freqs_cis is None:
freqs_cis = cache[cache_entry] = self.precompute_freqs_cis_1d(pos_indices)
cos, sin = freqs_cis.cos().unsqueeze(0).unsqueeze(0), freqs_cis.sin().unsqueeze(0).unsqueeze(0)
# cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
# real * cos - imag * sin
# imag * cos + real * sin
xq_dtype = xq.dtype
xq_out = xq.to(torch.float)
xq = None
xq_rot = rotate_half(xq_out)
xq_out *= cos
xq_rot *= sin
xq_out += xq_rot
del xq_rot
xq_out = xq_out.to(xq_dtype)
return xq_out
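As a side note, the cached forward above applies the standard rotate-half RoPE identity q' = q*cos + rotate_half(q)*sin with pairwise-repeated frequencies. A minimal self-contained sketch, assuming rotate_half negates the second element of each (even, odd) pair and swaps it with the first, is shown below; it is an illustration, not part of this diff.
# Illustrative sketch (assumed rotate_half layout and plain 1D positions).
import torch
def rotate_half_sketch(x):
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)
def rope_1d_sketch(q, pos, base=10000.0):
    head_dim = q.shape[-1]
    inv_freq = 1.0 / base ** (torch.arange(0, head_dim, 2).float() / head_dim)
    freqs = pos.float()[:, None] * inv_freq[None, :]        # n, d/2
    freqs = freqs.repeat_interleave(2, dim=-1)              # n, d (pairwise repeat)
    cos, sin = freqs.cos()[None, None], freqs.sin()[None, None]
    return q * cos + rotate_half_sketch(q) * sin
q = torch.randn(1, 8, 32, 64)                               # B, heads, tokens, head_dim
q_rot = rope_1d_sketch(q, torch.arange(32))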

View File

@ -1,698 +0,0 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
from mmgp import offload
import torch
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm
from PIL import Image
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.modules.posemb_layers import get_rotary_pos_embed
from .utils.vace_preprocessor import VaceVideoProcessor
def optimized_scale(positive_flat, negative_flat):
# Calculate dot production
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
# Squared norm of uncondition
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
st_star = dot_product / squared_norm
return st_star
class WanT2V:
def __init__(
self,
config,
checkpoint_dir,
rank=0,
model_filename = None,
text_encoder_filename = None,
quantizeTransformer = False,
dtype = torch.bfloat16
):
self.device = torch.device(f"cuda")
self.config = config
self.rank = rank
self.dtype = dtype
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=text_encoder_filename,
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn= None)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
logging.info(f"Creating WanModel from {model_filename}")
from mmgp import offload
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
# offload.load_model_data(self.model, "recam.ckpt")
# self.model.cpu()
# offload.save_model(self.model, "recam.safetensors")
if self.dtype == torch.float16 and not "fp16" in model_filename:
self.model.to(self.dtype)
# offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True)
if self.dtype == torch.float16:
self.vae.model.to(self.dtype)
self.model.eval().requires_grad_(False)
self.sample_neg_prompt = config.sample_neg_prompt
if "Vace" in model_filename:
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
min_area=480*832,
max_area=480*832,
min_fps=config.sample_fps,
max_fps=config.sample_fps,
zero_start=True,
seq_len=32760,
keep_last=True)
self.adapt_vace_model()
self.scheduler = FlowUniPCMultistepScheduler()
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
if ref_images is None:
ref_images = [None] * len(frames)
else:
assert len(frames) == len(ref_images)
if masks is None:
latents = self.vae.encode(frames, tile_size = tile_size)
else:
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
inactive = self.vae.encode(inactive, tile_size = tile_size)
reactive = self.vae.encode(reactive, tile_size = tile_size)
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
cat_latents = []
for latent, refs in zip(latents, ref_images):
if refs is not None:
if masks is None:
ref_latent = self.vae.encode(refs, tile_size = tile_size)
else:
ref_latent = self.vae.encode(refs, tile_size = tile_size)
ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
assert all([x.shape[1] == 1 for x in ref_latent])
latent = torch.cat([*ref_latent, latent], dim=1)
cat_latents.append(latent)
return cat_latents
def vace_encode_masks(self, masks, ref_images=None):
if ref_images is None:
ref_images = [None] * len(masks)
else:
assert len(masks) == len(ref_images)
result_masks = []
for mask, refs in zip(masks, ref_images):
c, depth, height, width = mask.shape
new_depth = int((depth + 3) // self.vae_stride[0])
height = 2 * (int(height) // (self.vae_stride[1] * 2))
width = 2 * (int(width) // (self.vae_stride[2] * 2))
# reshape
mask = mask[0, :, :, :]
mask = mask.view(
depth, height, self.vae_stride[1], width, self.vae_stride[1]
) # depth, height, 8, width, 8
mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
mask = mask.reshape(
self.vae_stride[1] * self.vae_stride[2], depth, height, width
) # 8*8, depth, height, width
# interpolation
mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
if refs is not None:
length = len(refs)
mask_pad = torch.zeros_like(mask[:, :length, :, :])
mask = torch.cat((mask_pad, mask), dim=1)
result_masks.append(mask)
return result_masks
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, pre_src_video = None):
image_sizes = []
trim_video = len(keep_frames)
for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
num_frames = total_frames - prepend_count
if sub_src_mask is not None and sub_src_video is not None:
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
# src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255])
# src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
src_video[i] = src_video[i].to(device)
src_mask[i] = src_mask[i].to(device)
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
src_video_shape = src_video[i].shape
if src_video_shape[1] != total_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
image_sizes.append(src_video[i].shape[2:])
elif sub_src_video is None:
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1)
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1)
else:
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(image_size)
else:
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
src_video[i] = src_video[i].to(device)
src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
src_video_shape = src_video[i].shape
if src_video_shape[1] != total_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
image_sizes.append(src_video[i].shape[2:])
for k, keep in enumerate(keep_frames):
if not keep:
src_video[i][:, k:k+1] = 0
src_mask[i][:, k:k+1] = 1
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
image_size = image_sizes[i]
for j, ref_img in enumerate(ref_images):
if ref_img is not None:
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
if ref_img.shape[-2:] != image_size:
canvas_height, canvas_width = image_size
ref_height, ref_width = ref_img.shape[-2:]
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
ref_img = white_canvas
src_ref_images[i][j] = ref_img.to(device)
return src_video, src_mask, src_ref_images
def decode_latent(self, zs, ref_images=None, tile_size= 0 ):
if ref_images is None:
ref_images = [None] * len(zs)
else:
assert len(zs) == len(ref_images)
trimed_zs = []
for z, refs in zip(zs, ref_images):
if refs is not None:
z = z[:, len(refs):, :, :]
trimed_zs.append(z)
return self.vae.decode(trimed_zs, tile_size= tile_size)
def generate_timestep_matrix(
self,
num_frames,
step_template,
base_num_frames,
ar_step=5,
num_pre_ready=0,
casual_block_size=1,
shrink_interval_with_mask=False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
step_matrix, step_index = [], []
update_mask, valid_interval = [], []
num_iterations = len(step_template) + 1
num_frames_block = num_frames // casual_block_size
base_num_frames_block = base_num_frames // casual_block_size
if base_num_frames_block < num_frames_block:
infer_step_num = len(step_template)
gen_block = base_num_frames_block
min_ar_step = infer_step_num / gen_block
assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
# print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block)
step_template = torch.cat(
[
torch.tensor([999], dtype=torch.int64, device=step_template.device),
step_template.long(),
torch.tensor([0], dtype=torch.int64, device=step_template.device),
]
) # to handle the counter in row works starting from 1
pre_row = torch.zeros(num_frames_block, dtype=torch.long)
if num_pre_ready > 0:
pre_row[: num_pre_ready // casual_block_size] = num_iterations
while torch.all(pre_row >= (num_iterations - 1)) == False:
new_row = torch.zeros(num_frames_block, dtype=torch.long)
for i in range(num_frames_block):
if i == 0 or pre_row[i - 1] >= (
num_iterations - 1
): # the first frame or the last frame is completely denoised
new_row[i] = pre_row[i] + 1
else:
new_row[i] = new_row[i - 1] - ar_step
new_row = new_row.clamp(0, num_iterations)
update_mask.append(
(new_row != pre_row) & (new_row != num_iterations)
) # False: no need to update True: need to update
step_index.append(new_row)
step_matrix.append(step_template[new_row])
pre_row = new_row
# for long video we split into several sequences, base_num_frames is set to the model max length (for training)
terminal_flag = base_num_frames_block
if shrink_interval_with_mask:
idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
update_mask = update_mask[0]
update_mask_idx = idx_sequence[update_mask]
last_update_idx = update_mask_idx[-1].item()
terminal_flag = last_update_idx + 1
# for i in range(0, len(update_mask)):
for curr_mask in update_mask:
if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
terminal_flag += 1
valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))
step_update_mask = torch.stack(update_mask, dim=0)
step_index = torch.stack(step_index, dim=0)
step_matrix = torch.stack(step_matrix, dim=0)
if casual_block_size > 1:
step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]
return step_matrix, step_index, step_update_mask, valid_interval
def generate(self,
input_prompt,
input_frames= None,
input_masks = None,
input_ref_images = None,
source_video=None,
target_camera=None,
context_scale=1.0,
size=(1280, 720),
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True,
callback = None,
enable_RIFLEx = None,
VAE_tile_size = 0,
joint_pass = False,
slg_layers = None,
slg_start = 0.0,
slg_end = 1.0,
cfg_star_switch = True,
cfg_zero_step = 5,
):
r"""
Generates video frames from text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation
size (tuple[`int`], *optional*, defaults to (1280,720)):
Controls video resolution, (width,height).
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 50):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed.
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N, H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (81)
- H: Frame height (from size)
- W: Frame width (from size)
"""
# preprocess
if n_prompt == "":
n_prompt = self.sample_neg_prompt
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
frame_num = max(17, frame_num) # must match causal_block_size for value of 5
frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 )
num_frames = frame_num
addnoise_condition = 20
causal_attention = True
fps = 16
ar_step = 5
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if target_camera != None:
size = (source_video.shape[2], source_video.shape[1])
source_video = source_video.to(dtype=self.dtype , device=self.device)
source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device)
del source_video
# Process target camera (recammaster)
from wan.utils.cammmaster_tools import get_camera_embedding
cam_emb = get_camera_embedding(target_camera)
cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
if input_frames != None:
# vace context encode
input_frames = [u.to(self.device) for u in input_frames]
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
input_masks = [u.to(self.device) for u in input_masks]
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size)
m0 = self.vace_encode_masks(input_masks, input_ref_images)
z = self.vace_latent(z0, m0)
target_shape = list(z0[0].shape)
target_shape[0] = int(target_shape[0] / 2)
else:
F = frame_num
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
size[1] // self.vae_stride[1],
size[0] // self.vae_stride[2])
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1])
context = [u.to(self.dtype) for u in context]
context_null = [u.to(self.dtype) for u in context_null]
noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ]
# evaluation mode
# if sample_solver == 'unipc':
# sample_scheduler = FlowUniPCMultistepScheduler(
# num_train_timesteps=self.num_train_timesteps,
# shift=1,
# use_dynamic_shifting=False)
# sample_scheduler.set_timesteps(
# sampling_steps, device=self.device, shift=shift)
# timesteps = sample_scheduler.timesteps
# elif sample_solver == 'dpm++':
# sample_scheduler = FlowDPMSolverMultistepScheduler(
# num_train_timesteps=self.num_train_timesteps,
# shift=1,
# use_dynamic_shifting=False)
# sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
# timesteps, _ = retrieve_timesteps(
# sample_scheduler,
# device=self.device,
# sigmas=sampling_sigmas)
# else:
# raise NotImplementedError("Unsupported solver.")
# sample videos
latents = noise
del noise
batch_size =len(latents)
if target_camera != None:
shape = list(latents[0].shape[1:])
shape[0] *= 2
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
else:
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
# arg_c = {'context': context, 'freqs': freqs, 'pipeline': self, 'callback': callback}
# arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
# arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
i2v_extra_kwrags = {}
if target_camera != None:
recam_dict = {'cam_emb': cam_emb}
i2v_extra_kwrags.update(recam_dict)
if input_frames != None:
vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale}
i2v_extra_kwrags.update(vace_dict)
latent_length = (num_frames - 1) // 4 + 1
latent_height = height // 8
latent_width = width // 8
if ar_step == 0:
causal_block_size = 1
fps_embeds = [fps] #* prompt_embeds[0].shape[0]
fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
self.scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift)
init_timesteps = self.scheduler.timesteps
base_num_frames_iter = latent_length
latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
prefix_video = None
predix_video_latent_length = 0
if prefix_video is not None:
latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32)
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
base_num_frames_iter,
init_timesteps,
base_num_frames_iter,
ar_step,
predix_video_latent_length,
causal_block_size,
)
sample_schedulers = []
for _ in range(base_num_frames_iter):
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
)
sample_scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift)
sample_schedulers.append(sample_scheduler)
sample_schedulers_counter = [0] * base_num_frames_iter
updated_num_steps= len(step_matrix)
if callback != None:
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
if self.model.enable_teacache:
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
# if callback != None:
# callback(-1, None, True)
for i, timestep_i in enumerate(tqdm(step_matrix)):
update_mask_i = step_update_mask[i]
valid_interval_i = valid_interval[i]
valid_interval_start, valid_interval_end = valid_interval_i
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
noise_factor = 0.001 * addnoise_condition
timestep_for_noised_condition = addnoise_condition
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
* (1.0 - noise_factor)
+ torch.randn_like(
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
)
* noise_factor
)
timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
kwrags = {
"x" : torch.stack([latent_model_input[0]]),
"t" : timestep,
"freqs" :freqs,
"fps" : fps_embeds,
"causal_block_size" : causal_block_size,
"causal_attention" : causal_attention,
"callback" : callback,
"pipeline" : self,
"current_step" : i,
}
kwrags.update(i2v_extra_kwrags)
if not self.do_classifier_free_guidance:
noise_pred = self.model(
context=context,
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred= noise_pred.to(torch.float32)
else:
if joint_pass:
noise_pred_cond, noise_pred_uncond = self.model(
context=context,
context2=context_null,
**kwrags,
)
if self._interrupt:
return None
else:
noise_pred_cond = self.model(
context=context,
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred_uncond = self.model(
context=context_null,
)[0]
if self._interrupt:
return None
noise_pred_cond= noise_pred_cond.to(torch.float32)
noise_pred_uncond= noise_pred_uncond.to(torch.float32)
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
del noise_pred_cond, noise_pred_uncond
for idx in range(valid_interval_start, valid_interval_end):
if update_mask_i[idx].item():
latents[0][:, idx] = sample_schedulers[idx].step(
noise_pred[:, idx - valid_interval_start],
timestep_i[idx],
latents[0][:, idx],
return_dict=False,
generator=seed_g,
)[0]
sample_schedulers_counter[idx] += 1
if callback is not None:
callback(i, latents[0].squeeze(0), False)
# for i, t in enumerate(tqdm(timesteps)):
# if target_camera != None:
# latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
# else:
# latent_model_input = latents
# slg_layers_local = None
# if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
# slg_layers_local = slg_layers
# timestep = [t]
# offload.set_step_no_for_lora(self.model, i)
# timestep = torch.stack(timestep)
# if joint_pass:
# noise_pred_cond, noise_pred_uncond = self.model(
# latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
# if self._interrupt:
# return None
# else:
# noise_pred_cond = self.model(
# latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0]
# if self._interrupt:
# return None
# noise_pred_uncond = self.model(
# latent_model_input, t=timestep,current_step=i, is_uncond = True, slg_layers=slg_layers_local, **arg_null)[0]
# if self._interrupt:
# return None
# # del latent_model_input
# # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
# noise_pred_text = noise_pred_cond
# if cfg_star_switch:
# positive_flat = noise_pred_text.view(batch_size, -1)
# negative_flat = noise_pred_uncond.view(batch_size, -1)
# alpha = optimized_scale(positive_flat,negative_flat)
# alpha = alpha.view(batch_size, 1, 1, 1)
# if (i <= cfg_zero_step):
# noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred...
# else:
# noise_pred_uncond *= alpha
# noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
# del noise_pred_uncond
# temp_x0 = sample_scheduler.step(
# noise_pred[:, :target_shape[1]].unsqueeze(0),
# t,
# latents[0].unsqueeze(0),
# return_dict=False,
# generator=seed_g)[0]
# latents = [temp_x0.squeeze(0)]
# del temp_x0
# if callback is not None:
# callback(i, latents[0], False)
x0 = latents
if input_frames == None:
videos = self.vae.decode(x0, VAE_tile_size)
else:
videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
del latents
del sample_scheduler
return videos[0] if self.rank == 0 else None
def adapt_vace_model(self):
model = self.model
modules_dict= { k: m for k, m in model.named_modules()}
for model_layer, vace_layer in model.vace_layers_mapping.items():
module = modules_dict[f"vace_blocks.{vace_layer}"]
target = modules_dict[f"blocks.{model_layer}"]
setattr(target, "vace", module )
delattr(model, "vace_blocks")

View File

@ -1,8 +1,12 @@
import torch
import numpy as np
import gradio as gr
def test_class_i2v(base_model_type):
return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk" ]
return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "animate" ]
def text_oneframe_overlap(base_model_type):
return test_class_i2v(base_model_type) and not (test_multitalk(base_model_type) or base_model_type in ["animate"]) or test_wan_5B(base_model_type)
def test_class_1_3B(base_model_type):
return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"]
@ -13,6 +17,8 @@ def test_multitalk(base_model_type):
def test_standin(base_model_type):
return base_model_type in ["standin", "vace_standin_14B"]
def test_wan_5B(base_model_type):
return base_model_type in ["ti2v_2_2", "lucy_edit"]
class family_handler():
@staticmethod
@ -32,7 +38,7 @@ class family_handler():
def_mag_ratios = [1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181]
elif base_model_type in ["i2v_2_2"]:
def_mag_ratios = [0.99191, 0.99144, 0.99356, 0.99337, 0.99326, 0.99285, 0.99251, 0.99264, 0.99393, 0.99366, 0.9943, 0.9943, 0.99276, 0.99288, 0.99389, 0.99393, 0.99274, 0.99289, 0.99316, 0.9931, 0.99379, 0.99377, 0.99268, 0.99271, 0.99222, 0.99227, 0.99175, 0.9916, 0.91076, 0.91046, 0.98931, 0.98933, 0.99087, 0.99088, 0.98852, 0.98855, 0.98895, 0.98896, 0.98806, 0.98808, 0.9871, 0.98711, 0.98613, 0.98618, 0.98434, 0.98435, 0.983, 0.98307, 0.98185, 0.98187, 0.98131, 0.98131, 0.9783, 0.97835, 0.97619, 0.9762, 0.97264, 0.9727, 0.97088, 0.97098, 0.96568, 0.9658, 0.96045, 0.96055, 0.95322, 0.95335, 0.94579, 0.94594, 0.93297, 0.93311, 0.91699, 0.9172, 0.89174, 0.89202, 0.8541, 0.85446, 0.79823, 0.79902]
elif base_model_type in ["ti2v_2_2"]:
elif test_wan_5B(base_model_type):
if inputs.get("image_start", None) is not None and inputs.get("video_source", None) is not None : # t2v
def_mag_ratios = [0.99505, 0.99389, 0.99441, 0.9957, 0.99558, 0.99551, 0.99499, 0.9945, 0.99534, 0.99548, 0.99468, 0.9946, 0.99463, 0.99458, 0.9946, 0.99453, 0.99408, 0.99404, 0.9945, 0.99441, 0.99409, 0.99398, 0.99403, 0.99397, 0.99382, 0.99377, 0.99349, 0.99343, 0.99377, 0.99378, 0.9933, 0.99328, 0.99303, 0.99301, 0.99217, 0.99216, 0.992, 0.99201, 0.99201, 0.99202, 0.99133, 0.99132, 0.99112, 0.9911, 0.99155, 0.99155, 0.98958, 0.98957, 0.98959, 0.98958, 0.98838, 0.98835, 0.98826, 0.98825, 0.9883, 0.98828, 0.98711, 0.98709, 0.98562, 0.98561, 0.98511, 0.9851, 0.98414, 0.98412, 0.98284, 0.98282, 0.98104, 0.98101, 0.97981, 0.97979, 0.97849, 0.97849, 0.97557, 0.97554, 0.97398, 0.97395, 0.97171, 0.97166, 0.96917, 0.96913, 0.96511, 0.96507, 0.96263, 0.96257, 0.95839, 0.95835, 0.95483, 0.95475, 0.94942, 0.94936, 0.9468, 0.94678, 0.94583, 0.94594, 0.94843, 0.94872, 0.96949, 0.97015]
else: # i2v
@ -79,11 +85,13 @@ class family_handler():
vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"]
extra_model_def["vace_class"] = vace_class
if base_model_type in ["animate"]:
fps = 30
elif test_multitalk(base_model_type):
fps = 25
elif base_model_type in ["fantasy"]:
fps = 23
elif base_model_type in ["ti2v_2_2"]:
elif test_wan_5B(base_model_type):
fps = 24
else:
fps = 16
@ -96,14 +104,14 @@ class family_handler():
extra_model_def.update({
"frames_minimum" : frames_minimum,
"frames_steps" : frames_steps,
"sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy"] or test_class_i2v(base_model_type) or vace_class, #"ti2v_2_2",
"sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy", "animate"] or test_class_i2v(base_model_type) or test_wan_5B(base_model_type) or vace_class, #"ti2v_2_2",
"multiple_submodels" : multiple_submodels,
"guidance_max_phases" : 3,
"skip_layer_guidance" : True,
"cfg_zero" : True,
"cfg_star" : True,
"adaptive_projected_guidance" : True,
"tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels),
"tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels),
"mag_cache" : True,
"keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
"sample_solvers":[
@ -112,9 +120,169 @@ class family_handler():
("dpm++", "dpm++"),
("flowmatch causvid", "causvid"), ]
})
if base_model_type in ["t2v"]:
extra_model_def["guide_custom_choices"] = {
"choices":[("Use Text Prompt Only", ""),
("Video to Video guided by Text Prompt", "GUV"),
("Video to Video guided by Text Prompt and Restricted to the Area of the Video Mask", "GVA")],
"default": "",
"letters_filter": "GUVA",
"label": "Video to Video"
}
extra_model_def["mask_preprocessing"] = {
"selection":[ "", "A"],
"visible": False
}
if base_model_type in ["infinitetalk"]:
extra_model_def["no_background_removal"] = True
extra_model_def["all_image_refs_are_background_ref"] = True
extra_model_def["guide_custom_choices"] = {
"choices":[
("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"),
("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"),
("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QRUV"),
("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "RUV"),
("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "GQUV"),
("Video to Video, amount of motion transferred depends on Denoising Strength - Smooth Transitions", "GUV"),
],
"default": "KI",
"letters_filter": "RGUVQKI",
"label": "Video to Video",
"show_label" : False,
}
# extra_model_def["at_least_one_image_ref_needed"] = True
if base_model_type in ["lucy_edit"]:
extra_model_def["keep_frames_video_guide_not_supported"] = True
extra_model_def["guide_preprocessing"] = {
"selection": ["UV"],
"labels" : { "UV": "Control Video"},
"visible": False,
}
if base_model_type in ["animate"]:
extra_model_def["guide_custom_choices"] = {
"choices":[
("Animate Person in Reference Image using Motion of Whole Control Video", "PVBKI"),
("Animate Person in Reference Image using Motion of Targeted Person in Control Video", "PVBXAKI"),
("Replace Person in Control Video by Person in Reference Image", "PVBAI"),
("Replace Person in Control Video by Person in Reference Image and Apply Relighting Process", "PVBAI1"),
],
"default": "PVBKI",
"letters_filter": "PVBXAKI1",
"label": "Type of Process",
"show_label" : False,
}
extra_model_def["mask_preprocessing"] = {
"selection":[ "", "A", "XA"],
"visible": False
}
extra_model_def["video_guide_outpainting"] = [0,1]
extra_model_def["keep_frames_video_guide_not_supported"] = True
extra_model_def["extract_guide_from_window_start"] = True
extra_model_def["forced_guide_mask_inputs"] = True
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)"
extra_model_def["background_ref_outpainted"] = False
extra_model_def["return_image_refs_tensor"] = True
extra_model_def["guide_inpaint_color"] = 0
if vace_class:
extra_model_def["guide_preprocessing"] = {
"selection": ["", "UV", "PV", "DV", "SV", "LV", "CV", "MV", "V", "PDV", "PSV", "PLV" , "DSV", "DLV", "SLV"],
"labels" : { "V": "Use Vace raw format"}
}
extra_model_def["mask_preprocessing"] = {
"selection": ["", "A", "NA", "XA", "XNA", "YA", "YNA", "WA", "WNA", "ZA", "ZNA"],
}
extra_model_def["image_ref_choices"] = {
"choices": [("None", ""),
("People / Objects", "I"),
("Landscape followed by People / Objects (if any)", "KI"),
("Positioned Frames followed by People / Objects (if any)", "FI"),
],
"letters_filter": "KFI",
}
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames"
extra_model_def["video_guide_outpainting"] = [0,1]
extra_model_def["pad_guide_video"] = True
extra_model_def["guide_inpaint_color"] = 127.5
extra_model_def["forced_guide_mask_inputs"] = True
extra_model_def["return_image_refs_tensor"] = True
if base_model_type in ["standin"]:
extra_model_def["fit_into_canvas_image_refs"] = 0
extra_model_def["image_ref_choices"] = {
"choices": [
("No Reference Image", ""),
("Reference Image is a Person Face", "I"),
],
"letters_filter":"I",
}
if base_model_type in ["phantom_1.3B", "phantom_14B"]:
extra_model_def["image_ref_choices"] = {
"choices": [("Reference Image", "I")],
"letters_filter":"I",
"visible": False,
}
if base_model_type in ["recam_1.3B"]:
extra_model_def["keep_frames_video_guide_not_supported"] = True
extra_model_def["model_modes"] = {
"choices": [
("Pan Right", 1),
("Pan Left", 2),
("Tilt Up", 3),
("Tilt Down", 4),
("Zoom In", 5),
("Zoom Out", 6),
("Translate Up (with rotation)", 7),
("Translate Down (with rotation)", 8),
("Arc Left (with rotation)", 9),
("Arc Right (with rotation)", 10),
],
"default": 1,
"label" : "Camera Movement Type"
}
extra_model_def["guide_preprocessing"] = {
"selection": ["UV"],
"labels" : { "UV": "Control Video"},
"visible" : False,
}
if vace_class or base_model_type in ["animate"]:
image_prompt_types_allowed = "TVL"
elif base_model_type in ["infinitetalk"]:
image_prompt_types_allowed = "TSVL"
elif base_model_type in ["ti2v_2_2"]:
image_prompt_types_allowed = "TSVL"
elif base_model_type in ["lucy_edit"]:
image_prompt_types_allowed = "TVL"
elif test_multitalk(base_model_type) or base_model_type in ["fantasy"]:
image_prompt_types_allowed = "SVL"
elif i2v:
image_prompt_types_allowed = "SEVL"
else:
image_prompt_types_allowed = ""
extra_model_def["image_prompt_types_allowed"] = image_prompt_types_allowed
if text_oneframe_overlap(base_model_type):
extra_model_def["sliding_window_defaults"] = { "overlap_min" : 1, "overlap_max" : 1, "overlap_step": 0, "overlap_default": 1}
# if base_model_type in ["phantom_1.3B", "phantom_14B"]:
# extra_model_def["one_image_ref_needed"] = True
return extra_model_def
@ -122,8 +290,8 @@ class family_handler():
def query_supported_types():
return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B",
"t2v_1.3B", "standin", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B",
"recam_1.3B",
"i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]
"recam_1.3B", "animate",
"i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "lucy_edit", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]
@staticmethod
@ -153,11 +321,12 @@ class family_handler():
@staticmethod
def get_vae_block_size(base_model_type):
return 32 if base_model_type == "ti2v_2_2" else 16
return 32 if test_wan_5B(base_model_type) else 16
@staticmethod
def get_rgb_factors(base_model_type ):
from shared.RGB_factors import get_rgb_factors
if test_wan_5B(base_model_type): base_model_type = "ti2v_2_2"
latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type)
return latent_rgb_factors, latent_rgb_factors_bias
@ -171,7 +340,7 @@ class family_handler():
"fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(model_filename) ]
}]
if base_model_type == "ti2v_2_2":
if test_wan_5B(base_model_type):
download_def += [ {
"repoId" : "DeepBeepMeep/Wan2.2",
"sourceFolderList" : [""],
@ -182,7 +351,7 @@ class family_handler():
@staticmethod
def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False, submodel_no_list = None):
from .configs import WAN_CONFIGS
if test_class_i2v(base_model_type):
@ -195,6 +364,7 @@ class family_handler():
config=cfg,
checkpoint_dir="ckpts",
model_filename=model_filename,
submodel_no_list = submodel_no_list,
model_type = model_type,
model_def = model_def,
base_model_type=base_model_type,
@ -235,15 +405,42 @@ class family_handler():
if "I" in video_prompt_type:
video_prompt_type = video_prompt_type.replace("KI", "QKI")
ui_defaults["video_prompt_type"] = video_prompt_type
if settings_version < 2.28:
if base_model_type in "infinitetalk":
video_prompt_type = ui_defaults.get("video_prompt_type", "")
if "U" in video_prompt_type:
video_prompt_type = video_prompt_type.replace("U", "RU")
ui_defaults["video_prompt_type"] = video_prompt_type
if settings_version < 2.31:
if base_model_type in "recam_1.3B":
video_prompt_type = ui_defaults.get("video_prompt_type", "")
if not "V" in video_prompt_type:
video_prompt_type += "UV"
ui_defaults["video_prompt_type"] = video_prompt_type
ui_defaults["image_prompt_type"] = ""
if text_oneframe_overlap(base_model_type):
ui_defaults["sliding_window_overlap"] = 1
if settings_version < 2.32:
image_prompt_type = ui_defaults.get("image_prompt_type", "")
if test_class_i2v(base_model_type) and len(image_prompt_type) == 0:
ui_defaults["image_prompt_type"] = "S"
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
ui_defaults.update({
"sample_solver": "unipc",
})
if test_class_i2v(base_model_type) and "S" in model_def["image_prompt_types_allowed"]:
ui_defaults["image_prompt_type"] = "S"
if base_model_type in ["fantasy"]:
ui_defaults.update({
"audio_guidance_scale": 5.0,
"sliding_window_size": 1,
"sliding_window_overlap" : 1,
})
elif base_model_type in ["multitalk"]:
@ -260,6 +457,7 @@ class family_handler():
"guidance_scale": 5.0,
"flow_shift": 7, # 11 for 720p
"sliding_window_overlap" : 9,
"sliding_window_size": 81,
"sample_solver" : "euler",
"video_prompt_type": "QKI",
"remove_background_images_ref" : 0,
@ -293,6 +491,21 @@ class family_handler():
"image_prompt_type": "T",
})
if base_model_type in ["recam_1.3B", "lucy_edit"]:
ui_defaults.update({
"video_prompt_type": "UV",
})
elif base_model_type in ["animate"]:
ui_defaults.update({
"video_prompt_type": "PVBKI",
"mask_expand": 20,
"audio_prompt_type": "R",
})
if text_oneframe_overlap(base_model_type):
ui_defaults["sliding_window_overlap"] = 1
ui_defaults["color_correction_strength"]= 0
if test_multitalk(base_model_type):
ui_defaults["audio_guidance_scale"] = 4
@ -309,3 +522,11 @@ class family_handler():
if ("V" in image_prompt_type or "L" in image_prompt_type) and image_refs is None:
video_prompt_type = video_prompt_type.replace("I", "").replace("K","")
inputs["video_prompt_type"] = video_prompt_type
if base_model_type in ["vace_standin_14B"]:
image_refs = inputs["image_refs"]
video_prompt_type = inputs["video_prompt_type"]
if image_refs is not None and len(image_refs) == 1 and "K" in video_prompt_type:
gr.Info("Warning, Ref Image for Standin Missing: if 'Landscape and then People or Objects' is selected beside the Landscape Image Ref there should be another Image Ref that contains a Face.")

View File

@ -0,0 +1,69 @@
from pathlib import Path
import os, tempfile
import numpy as np
import soundfile as sf
import librosa
import torch
import gc
from audio_separator.separator import Separator
def get_vocals(src_path: str, dst_path: str, min_seconds: float = 8) -> str:
"""
If the source audio is shorter than `min_seconds`, pad with trailing silence
in a temporary file, then run separation and save only the vocals to dst_path.
Returns the full path to the vocals file.
"""
default_device = torch.get_default_device()
torch.set_default_device('cpu')
dst = Path(dst_path)
dst.parent.mkdir(parents=True, exist_ok=True)
# Quick duration check
duration = librosa.get_duration(path=src_path)
use_path = src_path
temp_path = None
try:
if duration < min_seconds:
# Load (resample) and pad in memory
y, sr = librosa.load(src_path, sr=None, mono=False)
if y.ndim == 1: # ensure shape (channels, samples)
y = y[np.newaxis, :]
target_len = int(min_seconds * sr)
pad = max(0, target_len - y.shape[1])
if pad:
y = np.pad(y, ((0, 0), (0, pad)), mode="constant")
# Write a temp WAV for the separator
fd, temp_path = tempfile.mkstemp(suffix=".wav")
os.close(fd)
sf.write(temp_path, y.T, sr) # soundfile expects (frames, channels)
use_path = temp_path
# Run separation: emit only the vocals, with your exact filename
sep = Separator(
output_dir=str(dst.parent),
output_format=(dst.suffix.lstrip(".") or "wav"),
output_single_stem="Vocals",
model_file_dir="ckpts/roformer/" #model_bs_roformer_ep_317_sdr_12.9755.ckpt"
)
sep.load_model()
out_files = sep.separate(use_path, {"Vocals": dst.stem})
out = Path(out_files[0])
return str(out if out.is_absolute() else (dst.parent / out))
finally:
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
torch.cuda.empty_cache()
gc.collect()
torch.set_default_device(default_device)
# Example:
# final = get_vocals("in/clip.mp3", "out/vocals.wav")
# print(final)

View File

@ -7,7 +7,6 @@ import psutil
# import ffmpeg
import imageio
from PIL import Image
import cv2
import torch
import torch.nn.functional as F
@ -22,6 +21,7 @@ from .utils.get_default_model import get_matanyone_model
from .matanyone.inference.inference_core import InferenceCore
from .matanyone_wrapper import matanyone
from shared.utils.audio_video import save_video, save_image
from mmgp import offload
arg_device = "cuda"
arg_sam_model_type="vit_h"
@ -33,6 +33,8 @@ model_in_GPU = False
matanyone_in_GPU = False
bfloat16_supported = False
# SAM generator
import copy
class MaskGenerator():
def __init__(self, sam_checkpoint, device):
global args_device
@ -89,6 +91,7 @@ def get_frames_from_image(image_input, image_state):
"last_frame_numer": 0,
"fps": None
}
image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size)
set_image_encoder_patch()
select_SAM()
@ -537,7 +540,7 @@ def video_matting(video_state,video_input, end_slider, matting_type, interactive
file_name = ".".join(file_name.split(".")[:-1])
from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files
source_audio_tracks, audio_metadata = extract_audio_tracks(video_input, verbose= offload.default_verboseLevel )
output_fg_path = f"./mask_outputs/{file_name}_fg.mp4"
output_fg_temp_path = f"./mask_outputs/{file_name}_fg_tmp.mp4"
if len(source_audio_tracks) == 0:
@ -677,7 +680,6 @@ def load_unload_models(selected):
}
# os.path.join('.')
from mmgp import offload
# sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[arg_sam_model_type], ".")
sam_checkpoint = None
@ -695,7 +697,8 @@ def load_unload_models(selected):
model.samcontroler.sam_controler.model.to("cpu").to(torch.bfloat16).to(arg_device)
model_in_GPU = True
from .matanyone.model.matanyone import MatAnyone
matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone")
# matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone")
matanyone_model = MatAnyone.from_pretrained("ckpts/mask")
# pipe ={"mat" : matanyone_model, "sam" :model.samcontroler.sam_controler.model }
# offload.profile(pipe)
matanyone_model = matanyone_model.to("cpu").eval()
@ -717,27 +720,33 @@ def load_unload_models(selected):
def get_vmc_event_handler():
return load_unload_models
def export_to_vace_video_input(foreground_video_output):
gr.Info("Masked Video Input transferred to Vace For Inpainting")
return "V#" + str(time.time()), foreground_video_output
def export_image(state, image_output):
ui_settings = get_current_model_settings(state)
image_refs = ui_settings["image_refs"]
if image_refs == None:
image_refs =[]
image_refs.append( image_output)
ui_settings["image_refs"] = image_refs
gr.Info("Masked Image transferred to Current Image Generator")
return time.time()
def export_image_mask(state, image_input, image_mask):
ui_settings = get_current_model_settings(state)
ui_settings["image_guide"] = Image.fromarray(image_input)
ui_settings["image_mask"] = image_mask
gr.Info("Input Image & Mask transferred to Current Image Generator")
return time.time()
def export_to_current_video_engine(state, foreground_video_output, alpha_video_output):
ui_settings = get_current_model_settings(state)
ui_settings["video_guide"] = foreground_video_output
ui_settings["video_mask"] = alpha_video_output
gr.Info("Original Video and Full Mask have been transferred")
# return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
return time.time()
def teleport_to_video_tab(tab_state):
@ -746,15 +755,29 @@ def teleport_to_video_tab(tab_state):
return gr.Tabs(selected="video_gen")
def display(tabs, tab_state, state, refresh_form_trigger, server_config, get_current_model_settings_fn): #, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs):
# my_tab.select(fn=load_unload_models, inputs=[], outputs=[])
global image_output_codec, video_output_codec, get_current_model_settings
get_current_model_settings = get_current_model_settings_fn
image_output_codec = server_config.get("image_output_codec", None)
video_output_codec = server_config.get("video_output_codec", None)
media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/"
click_brush_js = """
() => {
setTimeout(() => {
const brushButton = document.querySelector('button[aria-label="Brush"]');
if (brushButton) {
brushButton.click();
console.log('Brush button clicked');
} else {
console.log('Brush button not found');
}
}, 1000);
} """
# download assets
gr.Markdown("<B>Mast Edition is provided by MatAnyone and VRAM optimized by DeepBeepMeep</B>")
@ -871,7 +894,7 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input,
template_frame = gr.Image(label="Start Frame", type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
with gr.Row():
clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False, min_width=100)
add_mask_button = gr.Button(value="Set Mask", interactive=True, visible=False, min_width=100)
add_mask_button = gr.Button(value="Add Mask", interactive=True, visible=False, min_width=100)
remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False, min_width=100) # no use
matting_button = gr.Button(value="Generate Video Matting", interactive=True, visible=False, min_width=100)
with gr.Row():
@ -892,7 +915,7 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input,
with gr.Row(visible= True):
export_to_current_video_engine_btn = gr.Button("Export to Control Video Input and Video Mask Input", visible= False)
export_to_current_video_engine_btn.click( fn=export_to_current_video_engine, inputs= [state, foreground_video_output, alpha_video_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger,
fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])
@ -1089,10 +1112,10 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input,
# with gr.Column(scale=2, visible= True):
export_image_mask_btn = gr.Button(value="Set to Control Image & Mask", visible=False, elem_classes="new_button")
export_image_btn.click( fn=export_image, inputs= [state, foreground_image_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger,
fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])
export_image_mask_btn.click( fn=export_image_mask, inputs= [state, image_input, alpha_image_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger,
fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs]).then(fn=None, inputs=None, outputs=None, js=click_brush_js)
# first step: get the image information
extract_frames_button.click(
@ -1148,5 +1171,21 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input,
outputs=[foreground_image_output, alpha_image_output,foreground_image_output, alpha_image_output,bbox_info, export_image_btn, export_image_mask_btn]
)
nada = gr.State({})
# clear input
gr.on(
triggers=[image_input.clear], #image_input.change,
fn=restart,
inputs=[],
outputs=[
image_state,
interactive_state,
click_state,
foreground_image_output, alpha_image_output,
template_frame,
image_selection_slider, image_selection_slider, track_pause_number_slider,point_prompt, export_image_btn, export_image_mask_btn, bbox_info, clear_button_click,
add_mask_button, matting_button, template_frame, foreground_image_output, alpha_image_output, remove_mask_button, export_image_btn, export_image_mask_btn, mask_dropdown, nada, step2_title
],
queue=False,
show_progress=False)

View File

@ -2,7 +2,6 @@ import math
import torch
from typing import Optional, Union, Tuple
# @torch.jit.script
def get_similarity(mk: torch.Tensor,
ms: torch.Tensor,
@ -59,6 +58,7 @@ def get_similarity(mk: torch.Tensor,
del two_ab
# similarity = (-a_sq + two_ab)
similarity =similarity.float()
if ms is not None:
similarity *= ms
similarity /= math.sqrt(CK)

View File

@ -73,5 +73,5 @@ def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10):
if ti > (n_warmup-1):
frames.append((com_np*255).astype(np.uint8))
phas.append((pha*255).astype(np.uint8))
# phas.append(np.clip(pha * 255, 0, 255).astype(np.uint8))
return frames, phas

View File

@ -100,7 +100,7 @@ class OptimizedPyannote31SpeakerSeparator:
self.hf_token = hf_token
self._overlap_pipeline = None
def separate_audio(self, audio_path: str, output1, output2 ) -> Dict[str, str]:
def separate_audio(self, audio_path: str, output1, output2, audio_original_path: str = None ) -> Dict[str, str]:
"""Optimized main separation function with memory management."""
xprint("Starting optimized audio separation...")
self._current_audio_path = os.path.abspath(audio_path)
@ -128,7 +128,11 @@ class OptimizedPyannote31SpeakerSeparator:
gc.collect()
# Save outputs efficiently
if audio_original_path is None:
waveform_original = waveform
else:
waveform_original, sample_rate = self.load_audio(audio_original_path)
output_paths = self._save_outputs_optimized(waveform_original, final_masks, sample_rate, audio_path, output1, output2)
return output_paths
@ -835,7 +839,7 @@ class OptimizedPyannote31SpeakerSeparator:
for turn, _, speaker in diarization.itertracks(yield_label=True):
xprint(f"{speaker}: {turn.start:.1f}s - {turn.end:.1f}s")
def extract_dual_audio(audio, output1, output2, verbose = False, audio_original = None):
global verbose_output
verbose_output = verbose
separator = OptimizedPyannote31SpeakerSeparator(
@ -848,7 +852,7 @@ def extract_dual_audio(audio, output1, output2, verbose = False):
import time
start_time = time.time()
outputs = separator.separate_audio(audio, output1, output2, audio_original)
elapsed_time = time.time() - start_time
xprint(f"\n=== SUCCESS (completed in {elapsed_time:.2f}s) ===")

View File

@ -20,14 +20,16 @@ soundfile
mutagen
pyloudnorm
librosa==0.11.0
speechbrain==1.0.3
audio-separator==0.36.1
# UI & interaction
gradio==5.29.0
dashscope
loguru
# Vision & segmentation
opencv-python>=4.9.0.80
opencv-python>=4.12.0.88
segment-anything
rembg[gpu]==2.0.65
onnxruntime-gpu
@ -43,14 +45,14 @@ pydantic==2.10.6
# Math & modeling
torchdiffeq>=0.2.5
tensordict>=0.6.1
mmgp==3.6.0
peft==0.15.0
matplotlib
# Utilities
ftfy
piexif
nvidia-ml-py
misaki
# Optional / commented out

View File

@ -1,6 +1,6 @@
# thanks Comfyui for the rgb factors (https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py)
def get_rgb_factors(model_family, model_type = None):
if model_family == "wan":
if model_family in ["wan", "qwen"]:
if model_type =="ti2v_2_2":
latent_channels = 48
latent_dimensions = 3
@ -261,7 +261,7 @@ def get_rgb_factors(model_family, model_type = None):
[ 0.0249, -0.0469, -0.1703]
]
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
else:
latent_rgb_factors_bias = latent_rgb_factors = None
return latent_rgb_factors, latent_rgb_factors_bias

View File

@ -3,6 +3,7 @@ import torch
from importlib.metadata import version
from mmgp import offload
import torch.nn.functional as F
import warnings
major, minor = torch.cuda.get_device_capability(None)
bfloat16_supported = major >= 8
@ -42,34 +43,51 @@ except ImportError:
sageattn_varlen_wrapper = None
import warnings
try:
from sageattention import sageattn
from .sage2_core import sageattn as alt_sageattn, is_sage2_supported
from .sage2_core import sageattn as sageattn2, is_sage2_supported
sage2_supported = is_sage2_supported()
except ImportError:
sageattn = None
alt_sageattn = None
sageattn2 = None
sage2_supported = False
# @torch.compiler.disable()
def sageattn_wrapper(
@torch.compiler.disable()
def sageattn2_wrapper(
qkv_list,
attention_length
):
q,k, v = qkv_list
if True:
qkv_list = [q,k,v]
del q, k ,v
o = alt_sageattn(qkv_list, tensor_layout="NHD")
else:
o = sageattn(q, k, v, tensor_layout="NHD")
del q, k ,v
qkv_list = [q,k,v]
del q, k ,v
o = sageattn2(qkv_list, tensor_layout="NHD")
qkv_list.clear()
return o
try:
from sageattn import sageattn_blackwell as sageattn3
except ImportError:
sageattn3 = None
@torch.compiler.disable()
def sageattn3_wrapper(
qkv_list,
attention_length
):
q,k, v = qkv_list
# qkv_list = [q,k,v]
# del q, k ,v
# o = sageattn3(qkv_list, tensor_layout="NHD")
q = q.transpose(1,2)
k = k.transpose(1,2)
v = v.transpose(1,2)
o = sageattn3(q, k, v)
o = o.transpose(1,2)
qkv_list.clear()
return o
# try:
# if True:
# from .sage2_core import sageattn_qk_int8_pv_fp8_window_cuda
@ -94,7 +112,7 @@ def sageattn_wrapper(
# return o
# except ImportError:
# sageattn = sageattn_qk_int8_pv_fp8_window_cuda
# sageattn2 = sageattn_qk_int8_pv_fp8_window_cuda
@torch.compiler.disable()
def sdpa_wrapper(
@ -124,21 +142,28 @@ def get_attention_modes():
ret.append("xformers")
if sageattn_varlen_wrapper != None:
ret.append("sage")
if sageattn != None and version("sageattention").startswith("2") :
if sageattn2 != None and version("sageattention").startswith("2") :
ret.append("sage2")
if sageattn3 != None: # and version("sageattention").startswith("3") :
ret.append("sage3")
return ret
def get_supported_attention_modes():
ret = get_attention_modes()
major, minor = torch.cuda.get_device_capability()
if major < 10:
if "sage3" in ret:
ret.remove("sage3")
if not sage2_supported:
if "sage2" in ret:
ret.remove("sage2")
major, minor = torch.cuda.get_device_capability()
if major < 7:
if "sage" in ret:
ret.remove("sage")
return ret
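# Illustrative only (not part of this module): a caller could pick the strongest available
# backend from the filtered list, falling back to PyTorch SDPA. The priority order below is
# an assumption for the sketch, not something this module defines.
preferred_order = ["sage3", "sage2", "sage", "xformers", "sdpa"]
supported_modes = get_supported_attention_modes()
attention_mode = next((m for m in preferred_order if m in supported_modes), "sdpa")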
__all__ = [
@ -201,7 +226,7 @@ def pay_attention(
from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG
if b > 1 and k_lens != None and attn in ("sage2", "sdpa"):
if b > 1 and k_lens != None and attn in ("sage2", "sage3", "sdpa"):
assert attention_mask == None
# Poor's man var k len attention
assert q_lens == None
@ -234,7 +259,7 @@ def pay_attention(
q_chunks, k_chunks, v_chunks = None, None, None
o = torch.cat(o, dim = 0)
return o
elif (q_lens != None or k_lens != None) and attn in ("sage2", "sdpa"):
elif (q_lens != None or k_lens != None) and attn in ("sage2", "sage3", "sdpa"):
assert b == 1
szq = q_lens[0].item() if q_lens != None else lq
szk = k_lens[0].item() if k_lens != None else lk
@ -284,13 +309,19 @@ def pay_attention(
max_seqlen_q=lq,
max_seqlen_kv=lk,
).unflatten(0, (b, lq))
elif attn=="sage3":
import math
if cross_attn or True:
qkv_list = [q,k,v]
del q,k,v
x = sageattn3_wrapper(qkv_list, lq)
elif attn=="sage2":
import math
if cross_attn or True:
qkv_list = [q,k,v]
del q,k,v
x = sageattn_wrapper(qkv_list, lq) #.unsqueeze(0)
x = sageattn2_wrapper(qkv_list, lq) #.unsqueeze(0)
# else:
# layer = offload.shared_state["layer"]
# embed_sizes = offload.shared_state["embed_sizes"]

View File

@ -0,0 +1,342 @@
#!/usr/bin/env python3
"""
Convert a Flux model from Diffusers (folder or single-file) into the original
single-file Flux transformer checkpoint used by Black Forest Labs / ComfyUI.
Input : /path/to/diffusers (root or .../transformer) OR /path/to/*.safetensors (single file)
Output : /path/to/flux1-your-model.safetensors (transformer only)
Usage:
python diffusers_to_flux_transformer.py /path/to/diffusers /out/flux1-dev.safetensors
python diffusers_to_flux_transformer.py /path/to/diffusion_pytorch_model.safetensors /out/flux1-dev.safetensors
# optional quantization:
# --fp8 (float8_e4m3fn, simple)
# --fp8-scaled (scaled float8 for 2D weights; adds .scale_weight tensors)
"""
import argparse
import json
from pathlib import Path
from collections import OrderedDict
import torch
from safetensors import safe_open
import safetensors.torch
from tqdm import tqdm
def parse_args():
ap = argparse.ArgumentParser()
ap.add_argument("diffusers_path", type=str,
help="Path to Diffusers checkpoint folder OR a single .safetensors file.")
ap.add_argument("output_path", type=str,
help="Output .safetensors path for the Flux transformer.")
ap.add_argument("--fp8", action="store_true",
help="Experimental: write weights as float8_e4m3fn via stochastic rounding (transformer only).")
ap.add_argument("--fp8-scaled", action="store_true",
help="Experimental: scaled float8_e4m3fn for 2D weight tensors; adds .scale_weight tensors.")
return ap.parse_args()
# Mapping from original Flux keys -> list of Diffusers keys (per block where applicable).
DIFFUSERS_MAP = {
# global embeds
"time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
"time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
"time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
"time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],
"vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
"vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
"vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
"vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],
"guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
"guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
"guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
"guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],
"txt_in.weight": ["context_embedder.weight"],
"txt_in.bias": ["context_embedder.bias"],
"img_in.weight": ["x_embedder.weight"],
"img_in.bias": ["x_embedder.bias"],
# dual-stream (image/text) blocks
"double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
"double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
"double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
"double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],
"double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
"double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
"double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
"double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],
"double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
"double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
"double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
"double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],
"double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
"double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
"double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
"double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],
"double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
"double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
"double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
"double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],
"double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
"double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
"double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
"double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],
# single-stream blocks
"single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
"single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
"single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
"single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
"single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
"single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
"single_blocks.().linear2.weight": ["proj_out.weight"],
"single_blocks.().linear2.bias": ["proj_out.bias"],
# final
"final_layer.linear.weight": ["proj_out.weight"],
"final_layer.linear.bias": ["proj_out.bias"],
# these two are built from norm_out.linear.{weight,bias} by swapping [shift,scale] -> [scale,shift]
"final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
"final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
}
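# Illustrative example (not used by the script): per-block template keys are expanded by
# substituting the block index, and multi-source entries are concatenated along dim 0 to
# rebuild the fused tensor of the original Flux layout. The shapes below are dummies.
okey = "double_blocks.().img_attn.qkv.weight".replace("()", "3")  # -> "double_blocks.3.img_attn.qkv.weight"
q_w, k_w, v_w = (torch.randn(3072, 3072) for _ in range(3))       # hypothetical per-projection weights
fused_qkv = torch.cat([q_w, k_w, v_w], dim=0)                     # (9216, 3072) fused QKV matrix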
class DiffusersSource:
"""
Uniform interface over:
1) Folder with index JSON + shards
2) Folder with exactly one .safetensors (no index)
3) Single .safetensors file
Provides .has(key), .get(key)->Tensor, .base_keys (keys with 'model.' stripped for scanning)
"""
POSSIBLE_PREFIXES = ["", "model."] # try in this order
def __init__(self, path: Path):
p = Path(path)
if p.is_dir():
# use 'transformer' subfolder if present
if (p / "transformer").is_dir():
p = p / "transformer"
self._init_from_dir(p)
elif p.is_file() and p.suffix == ".safetensors":
self._init_from_single_file(p)
else:
raise FileNotFoundError(f"Invalid path: {p}")
# ---------- common helpers ----------
@staticmethod
def _strip_prefix(k: str) -> str:
return k[6:] if k.startswith("model.") else k
def _resolve(self, want: str):
"""
Return the actual stored key matching `want` by trying known prefixes.
"""
for pref in self.POSSIBLE_PREFIXES:
k = pref + want
if k in self._all_keys:
return k
return None
def has(self, want: str) -> bool:
return self._resolve(want) is not None
def get(self, want: str) -> torch.Tensor:
real_key = self._resolve(want)
if real_key is None:
raise KeyError(f"Missing key: {want}")
return self._get_by_real_key(real_key).to("cpu")
@property
def base_keys(self):
# keys without 'model.' prefix for scanning
return [self._strip_prefix(k) for k in self._all_keys]
# ---------- modes ----------
def _init_from_single_file(self, file_path: Path):
self._mode = "single"
self._file = file_path
self._handle = safe_open(file_path, framework="pt", device="cpu")
self._all_keys = list(self._handle.keys())
def _get_by_real_key(real_key: str):
return self._handle.get_tensor(real_key)
self._get_by_real_key = _get_by_real_key
def _init_from_dir(self, dpath: Path):
index_json = dpath / "diffusion_pytorch_model.safetensors.index.json"
if index_json.exists():
with open(index_json, "r", encoding="utf-8") as f:
index = json.load(f)
weight_map = index["weight_map"] # full mapping
self._mode = "sharded"
self._dpath = dpath
self._weight_map = {k: dpath / v for k, v in weight_map.items()}
self._all_keys = list(self._weight_map.keys())
self._open_handles = {}
def _get_by_real_key(real_key: str):
fpath = self._weight_map[real_key]
h = self._open_handles.get(fpath)
if h is None:
h = safe_open(fpath, framework="pt", device="cpu")
self._open_handles[fpath] = h
return h.get_tensor(real_key)
self._get_by_real_key = _get_by_real_key
return
# no index: try exactly one safetensors in folder
files = sorted(dpath.glob("*.safetensors"))
if len(files) != 1:
raise FileNotFoundError(
f"No index found and {dpath} does not contain exactly one .safetensors file."
)
self._init_from_single_file(files[0])
def main():
args = parse_args()
src = DiffusersSource(Path(args.diffusers_path))
# Count blocks by scanning base keys (with any 'model.' prefix removed)
num_dual = 0
num_single = 0
for k in src.base_keys:
if k.startswith("transformer_blocks."):
try:
i = int(k.split(".")[1])
num_dual = max(num_dual, i + 1)
except Exception:
pass
elif k.startswith("single_transformer_blocks."):
try:
i = int(k.split(".")[1])
num_single = max(num_single, i + 1)
except Exception:
pass
print(f"Found {num_dual} dual-stream blocks, {num_single} single-stream blocks")
# Swap [shift, scale] -> [scale, shift] (weights are concatenated along dim=0)
def swap_scale_shift(vec: torch.Tensor) -> torch.Tensor:
shift, scale = vec.chunk(2, dim=0)
return torch.cat([scale, shift], dim=0)
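# e.g. swap_scale_shift(torch.tensor([1., 2., 3., 4.])) -> tensor([3., 4., 1., 2.])
# (Diffusers stores [shift | scale]; the original Flux layout expects [scale | shift])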
orig = {}
# Per-block (dual)
for b in range(num_dual):
prefix = f"transformer_blocks.{b}."
for okey, dvals in DIFFUSERS_MAP.items():
if not okey.startswith("double_blocks."):
continue
dkeys = [prefix + v for v in dvals]
if not all(src.has(k) for k in dkeys):
continue
if len(dkeys) == 1:
orig[okey.replace("()", str(b))] = src.get(dkeys[0])
else:
orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0)
# Per-block (single)
for b in range(num_single):
prefix = f"single_transformer_blocks.{b}."
for okey, dvals in DIFFUSERS_MAP.items():
if not okey.startswith("single_blocks."):
continue
dkeys = [prefix + v for v in dvals]
if not all(src.has(k) for k in dkeys):
continue
if len(dkeys) == 1:
orig[okey.replace("()", str(b))] = src.get(dkeys[0])
else:
orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0)
# Globals (non-block)
for okey, dvals in DIFFUSERS_MAP.items():
if okey.startswith(("double_blocks.", "single_blocks.")):
continue
dkeys = dvals
if not all(src.has(k) for k in dkeys):
continue
if len(dkeys) == 1:
orig[okey] = src.get(dkeys[0])
else:
orig[okey] = torch.cat([src.get(k) for k in dkeys], dim=0)
# Fix final_layer.adaLN_modulation.1.{weight,bias} by swapping scale/shift halves
if "final_layer.adaLN_modulation.1.weight" in orig:
orig["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(
orig["final_layer.adaLN_modulation.1.weight"]
)
if "final_layer.adaLN_modulation.1.bias" in orig:
orig["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(
orig["final_layer.adaLN_modulation.1.bias"]
)
# Optional FP8 variants (experimental; not required for ComfyUI/BFL)
if args.fp8 or args.fp8_scaled:
dtype = torch.float8_e4m3fn # noqa
minv, maxv = torch.finfo(dtype).min, torch.finfo(dtype).max
def stochastic_round_to(t):
t = t.float().clamp(minv, maxv)
lower = torch.floor(t * 256) / 256
upper = torch.ceil(t * 256) / 256
prob = torch.where(upper != lower, (t - lower) / (upper - lower), torch.zeros_like(t))
rnd = torch.rand_like(t)
out = torch.where(rnd < prob, upper, lower)
return out.to(dtype)
def scale_to_8bit(weight, target_max=416.0):
absmax = weight.abs().max()
scale = absmax / target_max if absmax > 0 else torch.tensor(1.0)
scaled = (weight / scale).clamp(minv, maxv).to(dtype)
return scaled, scale
scales = {}
for k in tqdm(list(orig.keys()), desc="Quantizing to fp8"):
t = orig[k]
if args.fp8:
orig[k] = stochastic_round_to(t)
else:
if k.endswith(".weight") and t.dim() == 2:
qt, s = scale_to_8bit(t)
orig[k] = qt
scales[k[:-len(".weight")] + ".scale_weight"] = s
else:
orig[k] = t.clamp(minv, maxv).to(dtype)
if args.fp8_scaled:
orig.update(scales)
orig["scaled_fp8"] = torch.tensor([], dtype=dtype)
else:
# Default: save in bfloat16
for k in list(orig.keys()):
orig[k] = orig[k].to(torch.bfloat16).cpu()
out_path = Path(args.output_path)
out_path.parent.mkdir(parents=True, exist_ok=True)
meta = OrderedDict()
meta["format"] = "pt"
meta["modelspec.date"] = __import__("datetime").date.today().strftime("%Y-%m-%d")
print(f"Saving transformer to: {out_path}")
safetensors.torch.save_file(orig, str(out_path), metadata=meta)
print("Done.")
if __name__ == "__main__":
main()
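# Illustrative sketch (not part of the converter): how a --fp8-scaled checkpoint could be read
# back and dequantized for inspection. Assumes a PyTorch build with float8 support and that
# consumers multiply each fp8 weight by its companion ".scale_weight" tensor.
# with safe_open("/out/flux1-dev.safetensors", framework="pt", device="cpu") as f:
#     w8 = f.get_tensor("double_blocks.0.img_attn.qkv.weight")           # float8_e4m3fn
#     s = f.get_tensor("double_blocks.0.img_attn.qkv.scale_weight")
#     w_approx = w8.to(torch.float32) * s                                # approximate original weight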

532
shared/gradio/gallery.py Normal file
View File

@ -0,0 +1,532 @@
from __future__ import annotations
import os, io, tempfile, mimetypes
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Literal
import gradio as gr
import PIL
import time
from PIL import Image as PILImage
FilePath = str
ImageLike = Union["PIL.Image.Image", Any]
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"}
VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".m4v", ".mpeg", ".mpg", ".ogv"}
def get_state(state):
return state if isinstance(state, dict) else state.value
def get_list( objs):
if objs is None:
return []
return [ obj[0] if isinstance(obj, tuple) else obj for obj in objs]
def record_last_action(st, last_action):
st["last_action"] = last_action
st["last_time"] = time.time()
class AdvancedMediaGallery:
def __init__(
self,
label: str = "Media",
*,
media_mode: Literal["image", "video"] = "image",
height = None,
columns: Union[int, Tuple[int, ...]] = 6,
show_label: bool = True,
initial: Optional[Sequence[Union[FilePath, ImageLike]]] = None,
elem_id: Optional[str] = None,
elem_classes: Optional[Sequence[str]] = ("adv-media-gallery",),
accept_filter: bool = True, # restrict Add-button dialog to allowed extensions
single_image_mode: bool = False, # start in single-image mode (Add replaces)
):
assert media_mode in ("image", "video")
self.label = label
self.media_mode = media_mode
self.height = height
self.columns = columns
self.show_label = show_label
self.elem_id = elem_id
self.elem_classes = list(elem_classes) if elem_classes else None
self.accept_filter = accept_filter
items = self._normalize_initial(initial or [], media_mode)
# Components (filled on mount)
self.container: Optional[gr.Column] = None
self.gallery: Optional[gr.Gallery] = None
self.upload_btn: Optional[gr.UploadButton] = None
self.btn_remove: Optional[gr.Button] = None
self.btn_left: Optional[gr.Button] = None
self.btn_right: Optional[gr.Button] = None
self.btn_clear: Optional[gr.Button] = None
# Single dict state
self.state: Optional[gr.State] = None
self._initial_state: Dict[str, Any] = {
"items": items,
"selected": (len(items) - 1) if items else 0, # None,
"single": bool(single_image_mode),
"mode": self.media_mode,
"last_action": "",
}
# ---------------- helpers ----------------
def _normalize_initial(self, items: Sequence[Union[FilePath, ImageLike]], mode: str) -> List[Any]:
out: List[Any] = []
if mode == "image":
for it in items:
p = self._ensure_image_item(it)
if p is not None:
out.append(p)
else:
for it in items:
if isinstance(it, tuple): it = it[0]
if isinstance(it, str) and self._is_video_path(it):
out.append(os.path.abspath(it))
return out
def _ensure_image_item(self, item: Union[FilePath, ImageLike]) -> Optional[Any]:
# Accept a path to an image, or a PIL.Image/np.ndarray -> save temp PNG and return its path
if isinstance(item, tuple): item = item[0]
if isinstance(item, str):
return os.path.abspath(item) if self._is_image_path(item) else None
if PILImage is None:
return None
try:
if isinstance(item, PILImage.Image):
img = item
else:
import numpy as np # type: ignore
if isinstance(item, np.ndarray):
img = PILImage.fromarray(item)
elif hasattr(item, "read"):
data = item.read()
img = PILImage.open(io.BytesIO(data)).convert("RGBA")
else:
return None
tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
img.save(tmp.name)
return tmp.name
except Exception:
return None
@staticmethod
def _extract_path(obj: Any) -> Optional[str]:
# Try to get a filesystem path (for mode filtering); otherwise None.
if isinstance(obj, str):
return obj
try:
import pathlib
if isinstance(obj, pathlib.Path): # type: ignore
return str(obj)
except Exception:
pass
if isinstance(obj, dict):
return obj.get("path") or obj.get("name")
for attr in ("path", "name"):
if hasattr(obj, attr):
try:
val = getattr(obj, attr)
if isinstance(val, str):
return val
except Exception:
pass
return None
@staticmethod
def _is_image_path(p: str) -> bool:
ext = os.path.splitext(p)[1].lower()
if ext in IMAGE_EXTS:
return True
mt, _ = mimetypes.guess_type(p)
return bool(mt and mt.startswith("image/"))
@staticmethod
def _is_video_path(p: str) -> bool:
ext = os.path.splitext(p)[1].lower()
if ext in VIDEO_EXTS:
return True
mt, _ = mimetypes.guess_type(p)
return bool(mt and mt.startswith("video/"))
def _filter_items_by_mode(self, items: List[Any]) -> List[Any]:
# Enforce image-only or video-only collection regardless of how files were added.
out: List[Any] = []
if self.media_mode == "image":
for it in items:
p = self._extract_path(it)
if p is None:
# No path: likely an image object added programmatically => keep
out.append(it)
elif self._is_image_path(p):
out.append(os.path.abspath(p))
else:
for it in items:
p = self._extract_path(it)
if p is not None and self._is_video_path(p):
out.append(os.path.abspath(p))
return out
@staticmethod
def _concat_and_optionally_dedupe(cur: List[Any], add: List[Any]) -> List[Any]:
# Keep it simple: dedupe by path when available, else allow duplicates.
seen_paths = set()
def key(x: Any) -> Optional[str]:
if isinstance(x, str): return os.path.abspath(x)
try:
import pathlib
if isinstance(x, pathlib.Path): # type: ignore
return os.path.abspath(str(x))
except Exception:
pass
if isinstance(x, dict):
p = x.get("path") or x.get("name")
return os.path.abspath(p) if isinstance(p, str) else None
for attr in ("path", "name"):
if hasattr(x, attr):
try:
v = getattr(x, attr)
return os.path.abspath(v) if isinstance(v, str) else None
except Exception:
pass
return None
out: List[Any] = []
for lst in (cur, add):
for it in lst:
k = key(it)
if k is None or k not in seen_paths:
out.append(it)
if k is not None:
seen_paths.add(k)
return out
@staticmethod
def _paths_from_payload(payload: Any) -> List[Any]:
# Return as raw objects (paths/dicts/UploadedFile) to feed Gallery directly.
if payload is None:
return []
if isinstance(payload, (list, tuple, set)):
return list(payload)
return [payload]
# ---------------- event handlers ----------------
def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) :
# Mirror the selected index into state and the gallery (server-side selected_index)
st = get_state(state)
last_time = st.get("last_time", None)
if last_time is not None and abs(time.time() - last_time) < 0.5:  # workaround: ignore select events fired right after our own gallery updates (the gallery re-emits select)
# print(f"ignored:{time.time()}, real {st['selected']}")
return gr.update(selected_index=st["selected"]), st
idx = None
if evt is not None and hasattr(evt, "index"):
ix = evt.index
if isinstance(ix, int):
idx = ix
elif isinstance(ix, (tuple, list)) and ix and isinstance(ix[0], int):
if isinstance(self.columns, int) and len(ix) >= 2:
idx = ix[0] * max(1, int(self.columns)) + ix[1]
else:
idx = ix[0]
n = len(get_list(gallery))
sel = idx if (idx is not None and 0 <= idx < n) else None
# print(f"image selected evt index:{sel}/{evt.selected}")
st["selected"] = sel
return gr.update(), st
def _on_upload(self, value: List[Any], state: Dict[str, Any]) :
# Fires when users upload via the Gallery itself.
# items_filtered = self._filter_items_by_mode(list(value or []))
items_filtered = list(value or [])
st = get_state(state)
new_items = self._paths_from_payload(items_filtered)
st["items"] = new_items
new_sel = len(new_items) - 1
st["selected"] = new_sel
record_last_action(st,"add")
return gr.update(selected_index=new_sel), st
def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) :
# Fires when users add/drag/drop/delete via the Gallery itself.
# items_filtered = self._filter_items_by_mode(list(value or []))
items_filtered = list(value or [])
st = get_state(state)
st["items"] = items_filtered
# Keep selection if still valid, else default to last
old_sel = st.get("selected", None)
if old_sel is None or not (0 <= old_sel < len(items_filtered)):
new_sel = (len(items_filtered) - 1) if items_filtered else None
else:
new_sel = old_sel
st["selected"] = new_sel
st["last_action"] ="gallery_change"
# print(f"gallery change: set sel {new_sel}")
return gr.update(selected_index=new_sel), st
def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery):
"""
Insert added items right AFTER the currently selected index.
Keeps the same ordering as chosen in the file picker, dedupes by path,
and re-selects the last inserted item.
"""
# New items (respect image/video mode)
# new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload))
new_items = self._paths_from_payload(files_payload)
st = get_state(state)
cur: List[Any] = get_list(gallery)
sel = st.get("selected", None)
if sel is None:
sel = (len(cur) -1) if len(cur)>0 else 0
single = bool(st.get("single", False))
# Nothing to add: keep as-is
if not new_items:
return gr.update(value=cur, selected_index=st.get("selected")), st
# Single-image mode: replace
if single:
st["items"] = [new_items[-1]]
st["selected"] = 0
return gr.update(value=st["items"], selected_index=0), st
# ---------- helpers ----------
def key_of(it: Any) -> Optional[str]:
# Prefer class helper if present
if hasattr(self, "_extract_path"):
p = self._extract_path(it) # type: ignore
else:
p = it if isinstance(it, str) else None
if p is None and isinstance(it, dict):
p = it.get("path") or it.get("name")
if p is None and hasattr(it, "path"):
try: p = getattr(it, "path")
except Exception: p = None
if p is None and hasattr(it, "name"):
try: p = getattr(it, "name")
except Exception: p = None
return os.path.abspath(p) if isinstance(p, str) else None
# Dedupe the incoming batch by path, preserve order
seen_new = set()
incoming: List[Any] = []
for it in new_items:
k = key_of(it)
if k is None or k not in seen_new:
incoming.append(it)
if k is not None:
seen_new.add(k)
insert_pos = min(sel, len(cur) -1)
cur_clean = cur
# Build final list and selection
merged = cur_clean[:insert_pos+1] + incoming + cur_clean[insert_pos+1:]
new_sel = insert_pos + len(incoming) # select the last inserted item
st["items"] = merged
st["selected"] = new_sel
record_last_action(st,"add")
# print(f"gallery add: set sel {new_sel}")
return gr.update(value=merged, selected_index=new_sel), st
def _on_remove(self, state: Dict[str, Any], gallery) :
st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None)
if sel is None or not (0 <= sel < len(items)):
return gr.update(value=items, selected_index=st.get("selected")), st
items.pop(sel)
if not items:
st["items"] = []; st["selected"] = None
return gr.update(value=[], selected_index=None), st
new_sel = min(sel, len(items) - 1)
st["items"] = items; st["selected"] = new_sel
record_last_action(st,"remove")
# print(f"gallery del: new sel {new_sel}")
return gr.update(value=items, selected_index=new_sel), st
def _on_move(self, delta: int, state: Dict[str, Any], gallery) :
st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None)
if sel is None or not (0 <= sel < len(items)):
return gr.update(value=items, selected_index=sel), st
j = sel + delta
if j < 0 or j >= len(items):
return gr.update(value=items, selected_index=sel), st
items[sel], items[j] = items[j], items[sel]
st["items"] = items; st["selected"] = j
record_last_action(st,"move")
# print(f"gallery move: set sel {j}")
return gr.update(value=items, selected_index=j), st
def _on_clear(self, state: Dict[str, Any]) :
st = {"items": [], "selected": None, "single": get_state(state).get("single", False), "mode": self.media_mode}
record_last_action(st,"clear")
# print(f"Clear all")
return gr.update(value=[], selected_index=None), st
def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) :
st = get_state(state); st["single"] = bool(to_single)
items: List[Any] = list(st["items"]); sel = st.get("selected", None)
if st["single"]:
keep = items[sel] if (sel is not None and 0 <= sel < len(items)) else (items[-1] if items else None)
items = [keep] if keep is not None else []
sel = 0 if items else None
st["items"] = items; st["selected"] = sel
upload_update = gr.update(file_count=("single" if st["single"] else "multiple"))
left_update = gr.update(visible=not st["single"])
right_update = gr.update(visible=not st["single"])
clear_update = gr.update(visible=not st["single"])
gallery_update= gr.update(value=items, selected_index=sel)
return upload_update, left_update, right_update, clear_update, gallery_update, st
# ---------------- build & wire ----------------
def mount(self, parent: Optional[gr.Blocks | gr.Group | gr.Row | gr.Column] = None, update_form = False):
if parent is not None:
with parent:
col = self._build_ui(update_form)
else:
col = self._build_ui(update_form)
if not update_form:
self._wire_events()
return col
def _build_ui(self, update = False) -> gr.Column:
with gr.Column(elem_id=self.elem_id, elem_classes=self.elem_classes) as col:
self.container = col
self.state = gr.State(dict(self._initial_state))
if update:
self.gallery = gr.update(
value=self._initial_state["items"],
selected_index=self._initial_state["selected"], # server-side selection
label=self.label,
show_label=self.show_label,
)
else:
self.gallery = gr.Gallery(
value=self._initial_state["items"],
label=self.label,
height=self.height,
columns=self.columns,
show_label=self.show_label,
preview= True,
# type="pil", # very slow
file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS),
selected_index=self._initial_state["selected"], # server-side selection
)
# One-line controls
exts = sorted(IMAGE_EXTS if self.media_mode == "image" else VIDEO_EXTS) if self.accept_filter else None
with gr.Row(equal_height=True, elem_classes=["amg-controls"]):
self.upload_btn = gr.UploadButton(
"Set" if self._initial_state["single"] else "Add",
file_types=exts,
file_count=("single" if self._initial_state["single"] else "multiple"),
variant="primary",
size="sm",
min_width=1,
)
self.btn_remove = gr.Button(" Remove ", size="sm", min_width=1)
self.btn_left = gr.Button("◀ Left", size="sm", visible=not self._initial_state["single"], min_width=1)
self.btn_right = gr.Button("Right ▶", size="sm", visible=not self._initial_state["single"], min_width=1)
self.btn_clear = gr.Button(" Clear ", variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1)
return col
def _wire_events(self):
# Selection: mirror into state and keep gallery.selected_index in sync
self.gallery.select(
self._on_select,
inputs=[self.state, self.gallery],
outputs=[self.gallery, self.state],
trigger_mode="always_last",
)
# Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.)
self.gallery.upload(
self._on_upload,
inputs=[self.gallery, self.state],
outputs=[self.gallery, self.state],
trigger_mode="always_last",
)
# Gallery value changed by other user actions (drag-drop reorder, internal delete, etc.)
self.gallery.change(
self._on_gallery_change,
inputs=[self.gallery, self.state],
outputs=[self.gallery, self.state],
trigger_mode="always_last",
)
# Add via UploadButton
self.upload_btn.upload(
self._on_add,
inputs=[self.upload_btn, self.state, self.gallery],
outputs=[self.gallery, self.state],
trigger_mode="always_last",
)
# Remove selected
self.btn_remove.click(
self._on_remove,
inputs=[self.state, self.gallery],
outputs=[self.gallery, self.state],
trigger_mode="always_last",
)
# Reorder using selected index, keep same item selected
self.btn_left.click(
lambda st, gallery: self._on_move(-1, st, gallery),
inputs=[self.state, self.gallery],
outputs=[self.gallery, self.state],
trigger_mode="always_last",
)
self.btn_right.click(
lambda st, gallery: self._on_move(+1, st, gallery),
inputs=[self.state, self.gallery],
outputs=[self.gallery, self.state],
trigger_mode="always_last",
)
# Clear all
self.btn_clear.click(
self._on_clear,
inputs=[self.state],
outputs=[self.gallery, self.state],
trigger_mode="always_last",
)
# ---------------- public API ----------------
def set_one_image_mode(self, enabled: bool = True):
"""Toggle single-image mode at runtime."""
return (
self._on_toggle_single,
[gr.State(enabled), self.state],
[self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state],
)
def get_toggable_elements(self):
return [self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state]
# import gradio as gr
# with gr.Blocks() as demo:
# amg = AdvancedMediaGallery(media_mode="image", height=190, columns=8)
# amg.mount()
# g = amg.gallery
# # buttons to switch modes live (optional)
# def process(g):
# pass
# with gr.Row():
# gr.Button("toto").click(process, g)
# gr.Button("ONE image").click(*amg.set_one_image_mode(True))
# gr.Button("MULTI image").click(*amg.set_one_image_mode(False))
# demo.launch()

View File

@ -0,0 +1,240 @@
import torch
from .utils import *
from functools import partial
# Many thanks to the LanPaint team for this implementation (https://github.com/scraed/LanPaint/)
def _pack_latents(latents):
batch_size, num_channels_latents, _, height, width = latents.shape
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
def _unpack_latents(latents, height, width, vae_scale_factor=8):
batch_size, num_patches, channels = latents.shape
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
return latents
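# A minimal shape check of the pack/unpack round trip used throughout this module
# (illustrative; assumes Flux-style 16-channel latents and vae_scale_factor=8).
_lat = torch.randn(1, 16, 1, 32, 32)                                  # e.g. a 256x256 image
_packed = _pack_latents(_lat)                                         # -> (1, 256, 64) tokens
_restored = _unpack_latents(_packed, height=256, width=256, vae_scale_factor=8)
assert _restored.shape == _lat.shape and torch.equal(_restored, _lat)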
class LanPaint():
def __init__(self, NSteps = 5, Friction = 15, Lambda = 8, Beta = 1, StepSize = 0.15, IS_FLUX = True, IS_FLOW = False):
self.n_steps = NSteps
self.chara_lamb = Lambda
self.IS_FLUX = IS_FLUX
self.IS_FLOW = IS_FLOW
self.step_size = StepSize
self.friction = Friction
self.chara_beta = Beta
self.img_dim_size = None
def add_none_dims(self, array):
# Build an index that keeps the first dimension and appends None (broadcast) axes for the remaining img_dim_size - 1 dimensions
index = (slice(None),) + (None,) * (self.img_dim_size-1)
return array[index]
def remove_none_dims(self, array):
# Build an index that keeps the first dimension and selects index 0 along the remaining img_dim_size - 1 dimensions
index = (slice(None),) + (0,) * (self.img_dim_size-1)
return array[index]
def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_image, noise, sigma, latent_mask, n_steps=None, height =720, width = 1280, vae_scale_factor = 8):
latent_image = _unpack_latents(latent_image, height=height, width=width, vae_scale_factor=vae_scale_factor)
noise = _unpack_latents(noise, height=height, width=width, vae_scale_factor=vae_scale_factor)
x = _unpack_latents(x, height=height, width=width, vae_scale_factor=vae_scale_factor)
latent_mask = _unpack_latents(latent_mask, height=height, width=width, vae_scale_factor=vae_scale_factor)
self.height = height
self.width = width
self.vae_scale_factor = vae_scale_factor
self.img_dim_size = len(x.shape)
self.latent_image = latent_image
self.noise = noise
if n_steps is None:
n_steps = self.n_steps
out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, self.IS_FLUX, self.IS_FLOW)
out = _pack_latents(out)
return out
def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, IS_FLUX, IS_FLOW):
if IS_FLUX:
cfg_BIG = 1.0
def double_denoise(latents, t):
latents = _pack_latents(latents)
noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
if noise_pred == None: return None, None
predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor)
if true_cfg_scale == cfg_BIG:
predict_big = predict_std
else:
predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t)
predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor)
return predict_std, predict_big
if len(sigma.shape) == 0:
sigma = torch.tensor([sigma.item()])
latent_mask = 1 - latent_mask
if IS_FLUX or IS_FLOW:
Flow_t = sigma
abt = (1 - Flow_t)**2 / ((1 - Flow_t)**2 + Flow_t**2 )
VE_Sigma = Flow_t / (1 - Flow_t)
#print("t", torch.mean( sigma ).item(), "VE_Sigma", torch.mean( VE_Sigma ).item())
else:
VE_Sigma = sigma
abt = 1/( 1+VE_Sigma**2 )
Flow_t = (1-abt)**0.5 / ( (1-abt)**0.5 + abt**0.5 )
# VE_Sigma, abt, Flow_t = current_times
current_times = (VE_Sigma, abt, Flow_t)
step_size = self.step_size * (1 - abt)
step_size = self.add_none_dims(step_size)
# self.inner_model.inner_model.scale_latent_inpaint returns variance exploding x_t values
# This is the replace step
# x = x * (1 - latent_mask) + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image)* latent_mask
noisy_image = self.latent_image * (1.0 - sigma) + self.noise * sigma
x = x * (1 - latent_mask) + noisy_image * latent_mask
if IS_FLUX or IS_FLOW:
x_t = x * ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 )
else:
x_t = x / ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance preserving x_t values
############ LanPaint Iterations Start ###############
# after noise_scaling, noise = latent_image + noise * sigma, which is x_t in the variance exploding diffusion model notation for the known region.
args = None
for i in range(n_steps):
score_func = partial( self.score_model, y = self.latent_image, mask = latent_mask, abt = self.add_none_dims(abt), sigma = self.add_none_dims(VE_Sigma), tflow = self.add_none_dims(Flow_t), denoise_func = double_denoise )
if score_func is None: return None
x_t, args = self.langevin_dynamics(x_t, score_func , latent_mask, step_size , current_times, sigma_x = self.add_none_dims(self.sigma_x(abt)), sigma_y = self.add_none_dims(self.sigma_y(abt)), args = args)
if IS_FLUX or IS_FLOW:
x = x_t / ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 )
else:
x = x_t * ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance preserving x_t values
############ LanPaint Iterations End ###############
# out is x_0
# out, _ = self.inner_model(x, sigma, model_options=model_options, seed=seed)
# out = out * (1-latent_mask) + self.latent_image * latent_mask
# return out
return x
def score_model(self, x_t, y, mask, abt, sigma, tflow, denoise_func):
lamb = self.chara_lamb
if self.IS_FLUX or self.IS_FLOW:
# compute t for flow model, with a small epsilon compensating for numerical error.
x = x_t / ( abt**0.5 + (1-abt)**0.5 ) # switch to Gaussian flow matching
x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(tflow))
if x_0 is None: return None
else:
x = x_t * ( 1+sigma**2 )**0.5 # switch to variance exploding
x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(sigma))
if x_0 is None: return None
score_x = -(x_t - x_0)
score_y = - (1 + lamb) * ( x_t - y ) + lamb * (x_t - x_0_BIG)
return score_x * (1 - mask) + score_y * mask
def sigma_x(self, abt):
# the time scale for the x_t update
return abt**0
def sigma_y(self, abt):
beta = self.chara_beta * abt ** 0
return beta
def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=1, sigma_y=0, args=None):
# prepare the step size and time parameters
with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
step_sizes = self.prepare_step_size(current_times, step_size, sigma_x, sigma_y)
sigma, abt, dtx, dty, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y = step_sizes
# print('mask',mask.device)
if torch.mean(dtx) <= 0.:
return x_t, args
# -------------------------------------------------------------------------
# Compute the Langevin dynamics update in variance preserving notation
# -------------------------------------------------------------------------
#x0 = self.x0_evalutation(x_t, score, sigma, args)
#C = abt**0.5 * x0 / (1-abt)
A = A_x * (1-mask) + A_y * mask
D = D_x * (1-mask) + D_y * mask
dt = dtx * (1-mask) + dty * mask
Gamma = Gamma_x * (1-mask) + Gamma_y * mask
def Coef_C(x_t):
x0 = self.x0_evalutation(x_t, score, sigma, args)
C = (abt**0.5 * x0 - x_t )/ (1-abt) + A * x_t
return C
def advance_time(x_t, v, dt, Gamma, A, C, D):
dtype = x_t.dtype
with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
osc = StochasticHarmonicOscillator(Gamma, A, C, D )
x_t, v = osc.dynamics(x_t, v, dt )
x_t = x_t.to(dtype)
v = v.to(dtype)
return x_t, v
if args is None:
#v = torch.zeros_like(x_t)
v = None
C = Coef_C(x_t)
#print(torch.squeeze(dtx), torch.squeeze(dty))
x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D)
else:
v, C = args
x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
C_new = Coef_C(x_t)
v = v + Gamma**0.5 * ( C_new - C) *dt
x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
C = C_new
return x_t, (v, C)
def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y):
# -------------------------------------------------------------------------
# Unpack current times parameters (sigma and abt)
sigma, abt, flow_t = current_times
sigma = self.add_none_dims(sigma)
abt = self.add_none_dims(abt)
# Compute time step (dtx, dty) for x and y branches.
dtx = 2 * step_size * sigma_x
dty = 2 * step_size * sigma_y
# -------------------------------------------------------------------------
# Define friction parameter Gamma_hat for each branch.
# Using dtx**0 provides a tensor of the proper device/dtype.
Gamma_hat_x = self.friction **2 * self.step_size * sigma_x / 0.1 * sigma**0
Gamma_hat_y = self.friction **2 * self.step_size * sigma_y / 0.1 * sigma**0
#print("Gamma_hat_x", torch.mean(Gamma_hat_x).item(), "Gamma_hat_y", torch.mean(Gamma_hat_y).item())
# adjust dt to match denoise-addnoise steps sizes
Gamma_hat_x /= 2.
Gamma_hat_y /= 2.
A_t_x = (1) / ( 1 - abt ) * dtx / 2
A_t_y = (1+self.chara_lamb) / ( 1 - abt ) * dty / 2
A_x = A_t_x / (dtx/2)
A_y = A_t_y / (dty/2)
Gamma_x = Gamma_hat_x / (dtx/2)
Gamma_y = Gamma_hat_y / (dty/2)
#D_x = (2 * (1 + sigma**2) )**0.5
#D_y = (2 * (1 + sigma**2) )**0.5
D_x = (2 * abt**0 )**0.5
D_y = (2 * abt**0 )**0.5
return sigma, abt, dtx/2, dty/2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y
def x0_evalutation(self, x_t, score, sigma, args):
x0 = x_t + score(x_t)
return x0

301
shared/inpainting/utils.py Normal file
View File

@ -0,0 +1,301 @@
import torch
def epxm1_x(x):
# Compute the (exp(x) - 1) / x term with a small value to avoid division by zero.
result = torch.special.expm1(x) / x
# replace NaN or inf values with 0
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
mask = torch.abs(x) < 1e-2
result = torch.where(mask, 1 + x/2. + x**2 / 6., result)
return result
def epxm1mx_x2(x):
# Compute the (exp(x) - 1 - x) / x**2 term with a small value to avoid division by zero.
result = (torch.special.expm1(x) - x) / x**2
# replace NaN or inf values with 0
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
mask = torch.abs(x**2) < 1e-2
result = torch.where(mask, 1/2. + x/6 + x**2 / 24 + x**3 / 120, result)
return result
def expm1mxmhx2_x3(x):
# Compute the (exp(x) - 1 - x - x**2 / 2) / x**3 term with a small value to avoid division by zero.
result = (torch.special.expm1(x) - x - x**2 / 2) / x**3
# replace NaN or inf values with 0
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
mask = torch.abs(x**3) < 1e-2
result = torch.where(mask, 1/6 + x/24 + x**2 / 120 + x**3 / 720 + x**4 / 5040, result)
return result
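# Quick illustrative check: each helper matches its Taylor limit near 0 and the exact formula away from 0, e.g.
#   epxm1_x(torch.tensor([1e-8, 1.0]))    ~ [1.0000, 1.7183]    ((e^x - 1) / x -> 1 as x -> 0)
#   epxm1mx_x2(torch.tensor([1e-8, 1.0])) ~ [0.5000, 0.7183]    ((e^x - 1 - x) / x**2 -> 1/2 as x -> 0)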
def exp_1mcosh_GD(gamma_t, delta):
"""
Compute e^(-Γt) * (1 - cosh(Γt√Δ)) / ((Γt)**2 Δ)
Parameters:
gamma_t: Γ*t term (could be a scalar or tensor)
delta: Δ term (could be a scalar or tensor)
Returns:
Result of the computation with numerical stability handling
"""
# Main computation
is_positive = delta > 0
sqrt_abs_delta = torch.sqrt(torch.abs(delta))
gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta
numerator_pos = torch.exp(-gamma_t) - (torch.exp(gamma_t * (sqrt_abs_delta - 1)) + torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2
numerator_neg = torch.exp(-gamma_t) * ( 1 - torch.cos(gamma_t * sqrt_abs_delta ) )
numerator = torch.where(is_positive, numerator_pos, numerator_neg)
result = numerator / (delta * gamma_t**2 )
# Handle NaN/inf cases
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
# Handle numerical instability for small delta
mask = torch.abs(gamma_t_sqrt_delta**2) < 5e-2
taylor = ( -0.5 - gamma_t**2 / 24 * delta - gamma_t**4 / 720 * delta**2 ) * torch.exp(-gamma_t)
result = torch.where(mask, taylor, result)
return result
def exp_sinh_GsqrtD(gamma_t, delta):
"""
Compute e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ)
Parameters:
gamma_t: Γ*t term (could be a scalar or tensor)
delta: Δ term (could be a scalar or tensor)
Returns:
Result of the computation with numerical stability handling
"""
# Main computation
is_positive = delta > 0
sqrt_abs_delta = torch.sqrt(torch.abs(delta))
gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta
numerator_pos = (torch.exp(gamma_t * (sqrt_abs_delta - 1)) - torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2
denominator_pos = gamma_t_sqrt_delta
result_pos = numerator_pos / gamma_t_sqrt_delta
result_pos = torch.where(torch.isfinite(result_pos), result_pos, torch.zeros_like(result_pos))
# Taylor expansion for small gamma_t_sqrt_delta
mask = torch.abs(gamma_t_sqrt_delta) < 1e-2
taylor = ( 1 + gamma_t**2 / 6 * delta + gamma_t**4 / 120 * delta**2 ) * torch.exp(-gamma_t)
result_pos = torch.where(mask, taylor, result_pos)
# Handle negative delta
result_neg = torch.exp(-gamma_t) * torch.special.sinc(gamma_t_sqrt_delta/torch.pi)
result = torch.where(is_positive, result_pos, result_neg)
return result
def exp_cosh(gamma_t, delta):
"""
Compute e^(-Γt) * cosh(Γt√Δ)
Parameters:
gamma_t: Γ*t term (could be a scalar or tensor)
delta: Δ term (could be a scalar or tensor)
Returns:
Result of the computation with numerical stability handling
"""
exp_1mcosh_GD_result = exp_1mcosh_GD(gamma_t, delta) # e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ )
result = torch.exp(-gamma_t) - gamma_t**2 * delta * exp_1mcosh_GD_result
return result
def exp_sinh_sqrtD(gamma_t, delta):
"""
Compute e^(-Γt) * sinh(Γt√Δ) / √Δ
Parameters:
gamma_t: Γ*t term (could be a scalar or tensor)
delta: Δ term (could be a scalar or tensor)
Returns:
Result of the computation with numerical stability handling
"""
exp_sinh_GsqrtD_result = exp_sinh_GsqrtD(gamma_t, delta) # e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ)
result = gamma_t * exp_sinh_GsqrtD_result
return result
def zeta1(gamma_t, delta):
# Compute hyperbolic terms and exponential
half_gamma_t = gamma_t / 2
exp_cosh_term = exp_cosh(half_gamma_t, delta)
exp_sinh_term = exp_sinh_sqrtD(half_gamma_t, delta)
# Main computation
numerator = 1 - (exp_cosh_term + exp_sinh_term)
denominator = gamma_t * (1 - delta) / 4
result = 1 - numerator / denominator
# Handle numerical instability
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
# Taylor expansion for small x (same approach as the epxm1_x helpers above)
mask = torch.abs(denominator) < 5e-3
term1 = epxm1_x(-gamma_t)
term2 = epxm1mx_x2(-gamma_t)
term3 = expm1mxmhx2_x3(-gamma_t)
taylor = term1 + (1/2.+ term1-3*term2)*denominator + (-1/6. + term1/2 - 4 * term2 + 10 * term3) * denominator**2
result = torch.where(mask, taylor, result)
return result
def exp_cosh_minus_terms(gamma_t, delta):
"""
Compute e^(-Γt) * ((cosh(Γt) - 1) - (cosh(Γt√Δ) - 1)/Δ) / (Γt (1 - Δ))
Parameters:
gamma_t: Γ*t term (could be a scalar or tensor)
delta: Δ term (could be a scalar or tensor)
Returns:
Result of the computation with numerical stability handling
"""
exp_term = torch.exp(-gamma_t)
# Compute individual terms
exp_cosh_term = exp_cosh(gamma_t, gamma_t**0) - exp_term # E^(-tΓ) (Cosh[tΓ] - 1) term
exp_cosh_delta_term = - gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) # E^(-tΓ) (Cosh[tΓ√Δ] - 1)/Δ term
#exp_1mcosh_GD e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ )
# Main computation
numerator = exp_cosh_term - exp_cosh_delta_term
denominator = gamma_t * (1 - delta)
result = numerator / denominator
# Handle numerical instability
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
# Taylor expansion for small gamma_t and delta near 1
mask = (torch.abs(denominator) < 1e-1)
exp_1mcosh_GD_term = exp_1mcosh_GD(gamma_t, delta**0)
taylor = (
gamma_t*exp_1mcosh_GD_term + 0.5 * gamma_t * exp_sinh_GsqrtD(gamma_t, delta**0)
- denominator / 4 * ( 0.5 * exp_cosh(gamma_t, delta**0) - 4 * exp_1mcosh_GD_term - 5 /2 * exp_sinh_GsqrtD(gamma_t, delta**0) )
)
result = torch.where(mask, taylor, result)
return result
def zeta2(gamma_t, delta):
half_gamma_t = gamma_t / 2
return exp_sinh_GsqrtD(half_gamma_t, delta)
def sig11(gamma_t, delta):
return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta)
def Zcoefs(gamma_t, delta):
Zeta1 = zeta1(gamma_t, delta)
Zeta2 = zeta2(gamma_t, delta)
sq_total = 1 - Zeta1 + gamma_t * (delta - 1) * (Zeta1 - 1)**2 / 8
amplitude = torch.sqrt(sq_total)
Zcoef1 = ( gamma_t**0.5 * Zeta2 / 2 **0.5 ) / amplitude
Zcoef2 = Zcoef1 * gamma_t *( - 2 * exp_1mcosh_GD(gamma_t, delta) / sig11(gamma_t, delta) ) ** 0.5
#cterm = exp_cosh_minus_terms(gamma_t, delta)
#sterm = exp_sinh_sqrtD(gamma_t, delta**0) + exp_sinh_sqrtD(gamma_t, delta)
#Zcoef3 = 2 * torch.sqrt( cterm / ( gamma_t * (1 - delta) * cterm + sterm ) )
Zcoef3 = torch.sqrt( torch.maximum(1 - Zcoef1**2 - Zcoef2**2, sq_total.new_zeros(sq_total.shape)) )
return Zcoef1 * amplitude, Zcoef2 * amplitude, Zcoef3 * amplitude, amplitude
def Zcoefs_asymp(gamma_t, delta):
A_t = (gamma_t * (1 - delta) )/4
return epxm1_x(- 2 * A_t)
class StochasticHarmonicOscillator:
"""
Simulates a stochastic harmonic oscillator governed by the equations:
dy(t) = q(t) dt
dq(t) = -Γ A y(t) dt + Γ C dt + Γ D dw(t) - Γ q(t) dt
Also define v(t) = q(t) / Γ, which is numerically more stable.
Where:
y(t) - Position variable
q(t) - Velocity variable
Γ - Damping coefficient
A - Harmonic potential strength
C - Constant force term
D - Noise amplitude
dw(t) - Wiener process (Brownian motion)
"""
def __init__(self, Gamma, A, C, D):
self.Gamma = Gamma
self.A = A
self.C = C
self.D = D
self.Delta = 1 - 4 * A / Gamma
def sig11(self, gamma_t, delta):
return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta)
def sig22(self, gamma_t, delta):
return 1- zeta1(2*gamma_t, delta) + 2 * gamma_t * exp_1mcosh_GD(gamma_t, delta)
def dynamics(self, y0, v0, t):
"""
Calculates the position and velocity variables at time t.
Parameters:
y0 (float): Initial position
v0 (float): Initial velocity v(0) = q(0) / Γ
t (float): Time at which to evaluate the dynamics
Returns:
tuple: (y(t), v(t))
"""
dummyzero = y0.new_zeros(1) # convert scalar to tensor with same device and dtype as y0
Delta = self.Delta + dummyzero
Gamma_hat = self.Gamma * t + dummyzero
A = self.A + dummyzero
C = self.C + dummyzero
D = self.D + dummyzero
Gamma = self.Gamma + dummyzero
zeta_1 = zeta1( Gamma_hat, Delta)
zeta_2 = zeta2( Gamma_hat, Delta)
EE = 1 - Gamma_hat * zeta_2
if v0 is None:
v0 = torch.randn_like(y0) * D / 2 ** 0.5
#v0 = (C - A * y0)/Gamma**0.5
# Calculate mean position and velocity
term1 = (1 - zeta_1) * (C * t - A * t * y0) + zeta_2 * (Gamma ** 0.5) * v0 * t
y_mean = term1 + y0
v_mean = (1 - EE)*(C - A * y0) / (Gamma ** 0.5) + (EE - A * t * (1 - zeta_1)) * v0
cov_yy = D**2 * t * self.sig22(Gamma_hat, Delta)
cov_vv = D**2 * self.sig11(Gamma_hat, Delta) / 2
cov_yv = (zeta2(Gamma_hat, Delta) * Gamma_hat * D ) **2 / 2 / (Gamma ** 0.5)
# sample new position and velocity with multivariate normal distribution
batch_shape = y0.shape
cov_matrix = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype)
cov_matrix[..., 0, 0] = cov_yy
cov_matrix[..., 0, 1] = cov_yv
cov_matrix[..., 1, 0] = cov_yv # symmetric
cov_matrix[..., 1, 1] = cov_vv
# Compute the Cholesky decomposition to get scale_tril
#scale_tril = torch.linalg.cholesky(cov_matrix)
scale_tril = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype)
tol = 1e-8
cov_yy = torch.clamp( cov_yy, min = tol )
sd_yy = torch.sqrt( cov_yy )
inv_sd_yy = 1/(sd_yy)
scale_tril[..., 0, 0] = sd_yy
scale_tril[..., 0, 1] = 0.
scale_tril[..., 1, 0] = cov_yv * inv_sd_yy
scale_tril[..., 1, 1] = torch.clamp( cov_vv - cov_yv**2 / cov_yy, min = tol ) ** 0.5
# check if it matches torch.linalg.
#assert torch.allclose(torch.linalg.cholesky(cov_matrix), scale_tril, atol = 1e-4, rtol = 1e-4 )
# Sample correlated noise from multivariate normal
mean = torch.zeros(*batch_shape, 2, device=y0.device, dtype=y0.dtype)
mean[..., 0] = y_mean
mean[..., 1] = v_mean
new_yv = torch.distributions.MultivariateNormal(
loc=mean,
scale_tril=scale_tril
).sample()
return new_yv[...,0], new_yv[...,1]
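# Illustrative usage sketch of the oscillator above (scalar coefficients broadcasting over a
# batch of 8 positions); the values are arbitrary and only meant to show the expected shapes.
_Gamma, _A, _C, _D = torch.tensor(4.0), torch.tensor(0.5), torch.tensor(0.0), torch.tensor(1.0)
_osc = StochasticHarmonicOscillator(_Gamma, _A, _C, _D)
_y1, _v1 = _osc.dynamics(torch.zeros(8), None, torch.tensor(0.1))   # v0=None samples a random initial velocity
# _y1.shape == _v1.shape == torch.Size([8])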

View File

@ -232,6 +232,9 @@ def save_video(tensor,
retry=5):
"""Save tensor as video with configurable codec and container options."""
if torch.is_tensor(tensor) and len(tensor.shape) == 4:
tensor = tensor.unsqueeze(0)
suffix = f'.{container}'
cache_file = osp.join('/tmp', rand_name(suffix=suffix)) if save_file is None else save_file
if not cache_file.endswith(suffix):

110
shared/utils/download.py Normal file
View File

@ -0,0 +1,110 @@
import sys, time
# Global variables to track download progress
_start_time = None
_last_time = None
_last_downloaded = 0
_speed_history = []
_update_interval = 0.5 # Update speed every 0.5 seconds
def progress_hook(block_num, block_size, total_size, filename=None):
"""
Simple progress bar hook for urlretrieve
Args:
block_num: Number of blocks downloaded so far
block_size: Size of each block in bytes
total_size: Total size of the file in bytes
filename: Name of the file being downloaded (optional)
"""
global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval
current_time = time.time()
downloaded = block_num * block_size
# Initialize timing on first call
if _start_time is None or block_num == 0:
_start_time = current_time
_last_time = current_time
_last_downloaded = 0
_speed_history = []
# Calculate download speed only at specified intervals
speed = 0
if current_time - _last_time >= _update_interval:
if _last_time > 0:
current_speed = (downloaded - _last_downloaded) / (current_time - _last_time)
_speed_history.append(current_speed)
# Keep only last 5 speed measurements for smoothing
if len(_speed_history) > 5:
_speed_history.pop(0)
# Average the recent speeds for smoother display
speed = sum(_speed_history) / len(_speed_history)
_last_time = current_time
_last_downloaded = downloaded
elif _speed_history:
# Use the last calculated average speed
speed = sum(_speed_history) / len(_speed_history)
# Format file sizes and speed
def format_bytes(bytes_val):
for unit in ['B', 'KB', 'MB', 'GB']:
if bytes_val < 1024:
return f"{bytes_val:.1f}{unit}"
bytes_val /= 1024
return f"{bytes_val:.1f}TB"
file_display = filename if filename else "Unknown file"
if total_size <= 0:
# If total size is unknown, show downloaded bytes
speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}"
# Clear any trailing characters by padding with spaces
sys.stdout.write(line.ljust(80))
sys.stdout.flush()
return
downloaded = block_num * block_size
percent = min(100, (downloaded / total_size) * 100)
# Create progress bar (40 characters wide to leave room for other info)
bar_length = 40
filled = int(bar_length * percent / 100)
bar = '' * filled + '' * (bar_length - filled)
# Format file sizes and speed
def format_bytes(bytes_val):
for unit in ['B', 'KB', 'MB', 'GB']:
if bytes_val < 1024:
return f"{bytes_val:.1f}{unit}"
bytes_val /= 1024
return f"{bytes_val:.1f}TB"
speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
# Display progress with filename first
line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}"
# Clear any trailing characters by padding with spaces
sys.stdout.write(line.ljust(100))
sys.stdout.flush()
# Print newline when complete
if percent >= 100:
print()
# Wrapper function to include filename in progress hook
def create_progress_hook(filename):
"""Creates a progress hook with the filename included"""
global _start_time, _last_time, _last_downloaded, _speed_history
# Reset timing variables for new download
_start_time = None
_last_time = None
_last_downloaded = 0
_speed_history = []
def hook(block_num, block_size, total_size):
return progress_hook(block_num, block_size, total_size, filename)
return hook
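# For reference, this is how the hook is meant to be consumed through urllib's reporthook
# interface (block_num, block_size, total_size); the URL and filename below are placeholders.
# from urllib.request import urlretrieve
# urlretrieve("https://example.com/model.safetensors", "model.safetensors",
#             reporthook=create_progress_hook("model.safetensors"))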

View File

@ -99,7 +99,7 @@ def parse_loras_multipliers(loras_multipliers, nb_loras, num_inference_steps, me
return loras_list_mult_choices_nums, slists_dict, ""
def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_step = None, phase_switch_step2 = None ):
def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_step = None, phase_switch_step2 = None):
from mmgp import offload
sz = len(slists_dict["phase1"])
slists = [ expand_slist(slists_dict, i, num_inference_steps, phase_switch_step, phase_switch_step2 ) for i in range(sz) ]
@ -108,7 +108,8 @@ def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_st
def get_model_switch_steps(timesteps, total_num_steps, guide_phases, model_switch_phase, switch_threshold, switch2_threshold ):
def get_model_switch_steps(timesteps, guide_phases, model_switch_phase, switch_threshold, switch2_threshold ):
total_num_steps = len(timesteps)
model_switch_step = model_switch_step2 = None
for i, t in enumerate(timesteps):
if guide_phases >=2 and model_switch_step is None and t <= switch_threshold: model_switch_step = i

View File

@ -9,7 +9,6 @@ import os
import sys
import time
import warnings
from functools import lru_cache
from io import BytesIO
import requests
@ -257,7 +256,6 @@ VIDEO_READER_BACKENDS = {
FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
if FORCE_QWENVL_VIDEO_READER is not None:
video_reader_backend = FORCE_QWENVL_VIDEO_READER

View File

@ -1,4 +1,3 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os
import os.path as osp
@ -18,11 +17,12 @@ import os
import tempfile
import subprocess
import json
from functools import lru_cache
os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")
from PIL import Image
video_info_cache = []
def seed_everything(seed: int):
random.seed(seed)
np.random.seed(seed)
@ -32,6 +32,14 @@ def seed_everything(seed: int):
if torch.backends.mps.is_available():
torch.mps.manual_seed(seed)
def has_video_file_extension(filename):
extension = os.path.splitext(filename)[-1].lower()
return extension in [".mp4"]
def has_image_file_extension(filename):
extension = os.path.splitext(filename)[-1].lower()
return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"]
def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ):
import math
@ -77,7 +85,9 @@ def truncate_for_filesystem(s, max_bytes=255):
else: r = m - 1
return s[:l]
@lru_cache(maxsize=100)
def get_video_info(video_path):
global video_info_cache
import cv2
cap = cv2.VideoCapture(video_path)
@ -92,7 +102,7 @@ def get_video_info(video_path):
return fps, width, height, frame_count
def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, return_PIL = True) -> torch.Tensor:
def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, target_fps = None, return_PIL = True) -> torch.Tensor:
"""Extract nth frame from video as PyTorch tensor normalized to [-1, 1]."""
cap = cv2.VideoCapture(file_name)
@ -100,7 +110,10 @@ def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool
raise ValueError(f"Cannot open video: {file_name}")
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = round(cap.get(cv2.CAP_PROP_FPS))
if target_fps is not None:
frame_no = round(target_fps * frame_no /fps)
# Handle out of bounds
if frame_no >= total_frames or frame_no < 0:
if return_last_if_missing:
@ -173,9 +186,15 @@ def remove_background(img, session=None):
def convert_image_to_tensor(image):
return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0)
def convert_tensor_to_image(t, frame_no = -1):
t = t[:, frame_no] if frame_no >= 0 else t
return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
def convert_tensor_to_image(t, frame_no = 0, mask_levels = False):
if len(t.shape) == 4:
t = t[:, frame_no]
if t.shape[0]== 1:
t = t.expand(3,-1,-1)
if mask_levels:
return Image.fromarray(t.clone().mul_(255).permute(1,2,0).to(torch.uint8).cpu().numpy())
else:
return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
def save_image(tensor_image, name, frame_no = -1):
convert_tensor_to_image(tensor_image, frame_no).save(name)
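A small round-trip sketch (not part of the diff; the image is a placeholder): convert_image_to_tensor maps an RGB image to a (3, H, W) float tensor in [-1, 1], and convert_tensor_to_image reverses it, indexing a frame only when it receives a 4-D (C, F, H, W) tensor.

from PIL import Image

img = Image.new("RGB", (64, 48), (200, 120, 40))   # placeholder image
t = convert_image_to_tensor(img)                   # shape (3, 48, 64), values in [-1, 1]
restored = convert_tensor_to_image(t)              # 3-D input: frame indexing is skipped
assert restored.size == img.size

# For a video tensor shaped (C, F, H, W), pick a frame explicitly:
# first_frame = convert_tensor_to_image(video_tensor, frame_no=0)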
@ -186,6 +205,14 @@ def get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_d
frame_width = int(frame_width * (100 + outpainting_left + outpainting_right) / 100)
return frame_height, frame_width
def rgb_bw_to_rgba_mask(img, thresh=127):
a = img.convert('L').point(lambda p: 255 if p > thresh else 0) # alpha
out = Image.new('RGBA', img.size, (255, 255, 255, 0)) # white, transparent
out.putalpha(a) # white where alpha=255
return out
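A quick sketch of the mask helper (not part of the diff): a black-and-white mask becomes a transparent RGBA overlay that is white and opaque only where the input is brighter than the threshold.

from PIL import Image

bw = Image.new("L", (8, 8), 0)       # placeholder mask: all black
bw.paste(255, (0, 0, 4, 8))          # left half white
overlay = rgb_bw_to_rgba_mask(bw)
# overlay.getpixel((0, 0)) -> (255, 255, 255, 255)   inside the mask
# overlay.getpixel((7, 0)) -> (255, 255, 255, 0)     outside: fully transparent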
def get_outpainting_frame_location(final_height, final_width, outpainting_dims, block_size = 8):
outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= outpainting_dims
raw_height = int(final_height / ((100 + outpainting_top + outpainting_bottom) / 100))
@ -205,30 +232,64 @@ def get_outpainting_frame_location(final_height, final_width, outpainting_dims
if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
return height, width, margin_top, margin_left
def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16):
if fit_into_canvas == None:
def rescale_and_crop(img, w, h):
ow, oh = img.size
target_ratio = w / h
orig_ratio = ow / oh
if orig_ratio > target_ratio:
# Crop width first
nw = int(oh * target_ratio)
img = img.crop(((ow - nw) // 2, 0, (ow + nw) // 2, oh))
else:
# Crop height first
nh = int(ow / target_ratio)
img = img.crop((0, (oh - nh) // 2, ow, (oh + nh) // 2))
return img.resize((w, h), Image.LANCZOS)
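A quick illustration (not part of the diff; the synthetic image is a placeholder): rescale_and_crop center-crops to the target aspect ratio first and only then resizes, so a 1920x1080 source requested at 512x512 is cropped to its central 1080x1080 region before the LANCZOS resize.

from PIL import Image

src = Image.new("RGB", (1920, 1080))     # placeholder 16:9 frame
out = rescale_and_crop(src, 512, 512)    # crop to 1080x1080 around the center, then resize
assert out.size == (512, 512)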
def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16):
if fit_into_canvas == None or fit_into_canvas == 2:
# return image_height, image_width
return canvas_height, canvas_width
if fit_into_canvas:
if fit_into_canvas == 1:
scale1 = min(canvas_height / image_height, canvas_width / image_width)
scale2 = min(canvas_width / image_height, canvas_height / image_width)
scale = max(scale1, scale2)
else:
else: #0 or #2 (crop)
scale = (canvas_height * canvas_width / (image_height * image_width))**(1/2)
new_height = round( image_height * scale / block_size) * block_size
new_width = round( image_width * scale / block_size) * block_size
return new_height, new_width
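As a worked example (not part of the diff): with fit_into_canvas=0 the scale preserves total pixel area and the result is rounded to multiples of block_size, while None or 2 simply returns the canvas size.

# Area-preserving fit with block rounding (fit_into_canvas=0):
h, w = calculate_new_dimensions(720, 1280, 600, 1000, 0)
# scale = sqrt(720*1280 / (600*1000)) ~= 1.239
# new_height = round(600*1.239 / 16) * 16 = 736, new_width = round(1000*1.239 / 16) * 16 = 1232
assert (h, w) == (736, 1232)

# fit_into_canvas=None (or 2) keeps the canvas dimensions:
assert calculate_new_dimensions(720, 1280, 600, 1000, None) == (720, 1280)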
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, ignore_first, fit_into_canvas = False ):
def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fit_into_canvas, fit_crop, block_size = 16):
if fit_crop:
image = rescale_and_crop(image, canvas_width, canvas_height)
new_width, new_height = image.size
else:
image_width, image_height = image.size
new_height, new_width = calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = block_size )
image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
return image, new_height, new_width
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5, return_tensor = False ):
if rm_background:
session = new_session()
output_list =[]
output_mask_list =[]
for i, img in enumerate(img_list):
width, height = img.size
if fit_into_canvas:
resized_mask = None
if any_background_ref == 1 and i==0 or any_background_ref == 2:
if outpainting_dims is not None and background_ref_outpainted:
resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True)
elif img.size != (budget_width, budget_height):
resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS)
else:
resized_image =img
elif fit_into_canvas == 1:
white_canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255
scale = min(budget_height / height, budget_width / width)
new_height = int(height * scale)
@ -240,152 +301,112 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
resized_image = Image.fromarray(white_canvas)
else:
scale = (budget_height * budget_width / (height * width))**(1/2)
new_height = int( round(height * scale / 16) * 16)
new_width = int( round(width * scale / 16) * 16)
new_height = int( round(height * scale / block_size) * block_size)
new_width = int( round(width * scale / block_size) * block_size)
resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
if rm_background and not (ignore_first and i == 0) :
if rm_background and not (any_background_ref and i==0 or any_background_ref == 2) :
# resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,
return output_list
if return_tensor:
output_list.append(convert_image_to_tensor(resized_image).unsqueeze(1))
else:
output_list.append(resized_image)
output_mask_list.append(resized_mask)
return output_list, output_mask_list
def str2bool(v):
    """
    Convert a string to a boolean.
    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'
    Args:
        v (str): String to convert.
    Returns:
        bool: Converted boolean value.
    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to boolean.
    """
    if isinstance(v, bool):
        return v
    v_lower = v.lower()
    if v_lower in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v_lower in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected (True/False)')

def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu", full_frame = False, outpainting_dims = None, return_mask = False, return_image = False):
    from shared.utils.utils import save_image
    inpaint_color = canvas_tf_bg / 127.5 - 1
    ref_width, ref_height = ref_img.size
    if (ref_height, ref_width) == image_size and outpainting_dims == None:
        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
        canvas = torch.zeros_like(ref_img[:1]) if return_mask else None
    else:
if outpainting_dims != None:
final_height, final_width = image_size
canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
else:
canvas_height, canvas_width = image_size
if full_frame:
new_height = canvas_height
new_width = canvas_width
top = left = 0
else:
# if fill_max and (canvas_height - new_height) < 16:
# new_height = canvas_height
# if fill_max and (canvas_width - new_width) < 16:
# new_width = canvas_width
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
if outpainting_dims != None:
canvas = torch.full((3, 1, final_height, final_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
else:
canvas = torch.full((3, 1, canvas_height, canvas_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
ref_img = canvas
canvas = None
if return_mask:
if outpainting_dims != None:
canvas = torch.ones((1, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0
else:
canvas = torch.ones((1, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, top:top + new_height, left:left + new_width] = 0
canvas = canvas.to(device)
if return_image:
return convert_tensor_to_image(ref_img), canvas
return ref_img.to(device), canvas
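A minimal call sketch (not part of the diff; sizes and the reference image are placeholders): without outpainting and with return_mask=True, the reference is letterboxed onto a canvas filled with canvas_tf_bg and returned as a (3, 1, H, W) tensor in [-1, 1], together with a (1, 1, H, W) mask that is 0 where the image was placed and 1 over the padding.

from PIL import Image

ref = Image.new("RGB", (640, 480))                                  # placeholder reference image
img_tensor, mask = fit_image_into_canvas(ref, (720, 1280), return_mask=True)
# img_tensor: (3, 1, 720, 1280), values in [-1, 1]; the reference is scaled to fit, centered,
# and the remaining area filled with the default grey (127.5 -> 0.0 after normalization)
# mask: (1, 1, 720, 1280); 0 over the placed image, 1 over the padded margins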
def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None, device ="cpu"):
src_videos, src_masks = [], []
inpaint_color_compressed = guide_inpaint_color/127.5 - 1
prepend_count = pre_video_guide.shape[1] if pre_video_guide is not None else 0
for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)):
src_video, src_mask = cur_video_guide, cur_video_mask
if pre_video_guide is not None:
src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1)
if any_mask:
src_mask = torch.zeros_like(pre_video_guide[:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[:1]), src_mask], dim=1)
if any_guide_padding:
if src_video is None:
src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color_compressed, dtype = torch.float, device= device)
elif src_video.shape[1] < current_video_length:
src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color_compressed, dtype = src_video.dtype, device= src_video.device) ], dim=1)
elif src_video is not None:
new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
src_video = src_video[:, :new_num_frames]
if any_mask and src_video is not None:
if src_mask is None:
src_mask = torch.ones_like(src_video[:1])
elif src_mask.shape[1] < src_video.shape[1]:
src_mask = torch.cat([src_mask, torch.full( (1, src_video.shape[1]- src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1)
else:
src_mask = src_mask[:, :src_video.shape[1]]
if src_video is not None :
for k, keep in enumerate(keep_video_guide_frames):
if not keep:
pos = prepend_count + k
src_video[:, pos:pos+1] = inpaint_color_compressed
if any_mask: src_mask[:, pos:pos+1] = 1
for k, frame in enumerate(inject_frames):
if frame != None:
pos = prepend_count + k
src_video[:, pos:pos+1], msk = fit_image_into_canvas(frame, image_size, guide_inpaint_color, device, True, outpainting_dims, return_mask= any_mask)
if any_mask: src_mask[:, pos:pos+1] = msk
src_videos.append(src_video)
src_masks.append(src_mask)
return src_videos, src_masks
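To make the calling convention concrete, a minimal sketch (not part of the diff; all argument values are illustrative): each entry of video_guides / video_masks may be None, and with any_guide_padding=True a blank, inpaint-colored guide of the requested length is synthesized for empty slots.

# One guide slot with no guide video and no mask: padding synthesizes a blank guide.
guides, masks = prepare_video_guide_and_mask(
    [None], [None], pre_video_guide=None, image_size=(480, 832),
    current_video_length=81, any_mask=False, any_guide_padding=True,
    keep_video_guide_frames=[], inject_frames=[],
)
# guides[0]: tensor of shape (3, 81, 480, 832) filled with 127.5/127.5 - 1 = 0.0
# masks[0]: None (no mask was requested)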

wgp.py: 2044 changed lines (file diff suppressed because it is too large)