Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 06:15:17 +00:00)

Merge branch 'main' into feature_add-cuda-docker-runner

Commit b28cb446bb

README.md (230 lines changed)

@@ -19,222 +19,96 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep

## 🔥 Latest Updates:

### August 29 2025: WanGP v8.2 - Here Goes Your Weekend

### September 25 2025: WanGP v8.73 - ~~Here Are Two Three New Contenders in the Vace Arena!~~ The Never Ending Release

- **InfiniteTalk Video to Video**: this feature can be used for Video Dubbing. Keep in mind that it is a *Sparse Video to Video*, that is, internally only one image is used per Sliding Window. However, thanks to the new *Smooth Transition* mode, each new clip is connected to the previous one and all the camera work is done by InfiniteTalk. If you don't get any transition, increase the number of frames of a Sliding Window (81 frames recommended).

So in ~~today's~~ this release you will find two Vace wannabes that each cover only a subset of Vace features but offer some interesting advantages:

- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion* transfers, and it does that very well. You can use it either to *Replace* a person in a Video or to *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone*?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be, under the hood, a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also, as a WanGP exclusivity, you will find support for *Outpainting*.

- **StandIn**: a very light model specialized in Identity Transfer. I have provided two versions of StandIn: a basic one derived from the text 2 video model and another based on Vace. If used with Vace, the last reference frame given to Vace will also be used for StandIn.

In order to use Wan 2.2 Animate you will first need to stop by the embedded *Matanyone* tool to extract the *Video Mask* of the person whose motion you want to capture.

- **Flux ESO**: a new Flux derived *Image Editing tool*, this one specialized in both *Identity Transfer* and *Style Transfer*. Style has to be understood in its widest meaning: give it a reference picture of a person and another one of sushi and it will turn this person into sushi.

With version WanGP 8.74 there is an extra option that lets you apply *Relighting* when Replacing a person. Also, you can now Animate a person without providing a Video Mask to target the source of the motion (with the risk that it will be less precise).

### August 24 2025: WanGP v8.1 - the RAM Liberator

- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video, ask it to change it (it is specialized in clothes changing) and voila! The nice thing is that it is based on the *Wan 2.2 5B* model and is therefore very fast, especially if you use the *FastWan* finetune that is also part of the package.

- **Reserved RAM entirely freed when switching models**: you should get far fewer RAM-related out-of-memory errors. I have also added a button in *Configuration / Performance* that releases most of the RAM used by WanGP if you want to use another application without quitting WanGP.

- **InfiniteTalk** support: an improved version of Multitalk that supposedly supports very long video generations driven by an audio track. It exists in two flavors (*Single Speaker* and *Multi Speakers*) but doesn't seem to be compatible with Vace. One key new feature compared to Multitalk is that you can have different visual shots associated with the same audio: each Reference Frame you provide will be associated with a new Sliding Window. If only one Reference Frame is provided, it will be used for all windows. When Continuing a video, you can either continue the current shot (no Reference Frame) or add new shots (one or more Reference Frames).\
If you are not into audio, you can still use this model to generate infinitely long image2video: just select "no speaker". Last but not least, InfiniteTalk works with all the Loras accelerators.

- **Flux Chroma 1 HD** support: an uncensored Flux-based model, lighter than Flux (8.9B versus 12B), that can fit entirely in VRAM with only 16 GB of VRAM. Unfortunately it is not distilled, so you will need CFG and a minimum of 20 steps.

Also, because I wanted to spoil you:

- **Qwen Edit Plus**: also known as the *Qwen Edit 25th September Update*, specialized in combining multiple Objects / People. There is also new support for *Pose transfer* & *Recolorisation*. All of this is made easy to use in WanGP. For now you will find only the quantized version, since HF crashes when uploading the unquantized one.

### August 21 2025: WanGP v8.01 - the killer of seven

- **T2V Video 2 Video Masking**: ever wanted to apply a Lora, a process (for instance Upsampling) or a Text Prompt on only a (moving) part of a Source Video? Look no further, I have added *Masked Video 2 Video* (which also works in image2image) to the *Text 2 Video* models. As usual you just need to use *Matanyone* to create the mask.

- **Qwen Image Edit**: a Flux Kontext challenger (prompt-driven image editing). Best results (including Identity preservation) will be obtained at 720p. Beyond that you may get image outpainting and / or lose identity preservation; below 720p prompt adherence will be worse. Qwen Image Edit works with the Qwen Lora Lightning 4 steps. I have also unlocked all the resolutions for Qwen models. Bonus Zone: support for multiple image compositions, but identity preservation won't be as good.

- **On demand Prompt Enhancer** (needs to be enabled in the Configuration Tab) that you can use to Enhance a Text Prompt before starting a Generation. You can refine the Enhanced Prompt or change the original Prompt.

- Choice of a **Non censored Prompt Enhancer**. Beware, this one is VRAM hungry and will require 12 GB of VRAM to work.

- **Memory Profile customizable per model**: useful to set for instance Profile 3 (preload the model entirely in VRAM) only for Image Generation models if you have 24 GB of VRAM. In that case Generation will be much faster, because with Image generators (contrary to Video generators) a lot of time is otherwise wasted in offloading.

- **Expert Guidance Mode**: change the Guidance up to 2 times during the generation. Very useful with Wan 2.2 Lightning to reduce the slow motion effect. The idea is to insert a CFG phase before the 2 accelerated phases that follow and have no Guidance. I have added the finetune *Wan2.2 Vace Lightning 3 Phases 14B* with a prebuilt configuration. Please note that it is an 8-step process although the Lightning lora is 4 steps. This expert guidance mode is also available with Wan 2.1. A rough sketch of such a phase schedule is shown below.
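For illustration only, here is a minimal sketch of what a multi-phase guidance schedule could look like; the function name, phase boundaries and CFG values are assumptions made for the example, not WanGP's actual implementation.

```python
def guidance_for_step(step: int, total_steps: int = 8) -> float:
    """Toy multi-phase schedule: a short CFG phase, then accelerated steps with no guidance."""
    phases = [
        (2, 3.5),            # steps 0-1: regular CFG to fix composition and motion
        (total_steps, 1.0),  # remaining steps: guidance disabled (Lightning-style)
    ]
    for end_step, cfg in phases:
        if step < end_step:
            return cfg
    return 1.0

print([guidance_for_step(s) for s in range(8)])
# [3.5, 3.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
```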
*WanGP 8.01 update: improved Qwen Image Edit Identity Preservation*

### August 12 2025: WanGP v7.7777 - Lucky Day(s)

*Update 8.71*: fixed Fast Lucy Edit which didn't contain the lora

*Update 8.72*: shadow drop of Qwen Edit Plus

*Update 8.73*: Qwen Preview & InfiniteTalk Start image

*Update 8.74*: Animate Relighting / Nomask mode, t2v Masked Video to Video

This is your lucky day! Thanks to new configuration options that let you store generated Videos and Images in lossless compressed formats, you will find that they in fact look two times better without you doing anything!

### September 15 2025: WanGP v8.6 - Attack of the Clones

Just kidding, they will be only marginally better, but at least this opens the way to professional editing.

- The long awaited **Vace for Wan 2.2** is at last here, or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fun Cocktail**).

Supported formats:

- Video: x264, x264 lossless, x265
- Images: jpeg, png, webp, webp lossless

Generation Settings are stored in each of the above regardless of the format (that was the hard part).
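As an illustration of the general technique only (an assumption for the example, not WanGP's actual code), settings can be embedded as a text chunk in a PNG with Pillow and read back later:

```python
import json
from PIL import Image
from PIL.PngImagePlugin import PngInfo

def save_png_with_settings(image: Image.Image, path: str, settings: dict) -> None:
    # Store the generation settings as a JSON string inside a PNG text chunk
    meta = PngInfo()
    meta.add_text("generation_settings", json.dumps(settings))
    image.save(path, pnginfo=meta)

def load_settings(path: str) -> dict:
    # Read the settings back from the PNG text chunk
    with Image.open(path) as im:
        return json.loads(im.text["generation_settings"])
```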
- **First Frame / Last Frame for Vace**: Vace models are so powerful that they could do *First frame / Last frame* since day one using the *Injected Frames* feature. However, this required computing by hand the location of each end frame, since this feature expects frame positions. I made these locations easier to compute with the "L" alias:

Also, you can now choose different output directories for images and videos.

For a video Gen from scratch, *"1 L L L"* means the 4 Injected Frames will be injected like this: frame no 1 at the first position, the next frame at the end of the first window, then the following frame at the end of the next window, and so on.

If you *Continue a Video*, you just need *"L L L"* since the first frame is the last frame of the *Source Video*. In any case, remember that numeric frame positions (like "1") are aligned by default to the beginning of the source window, so low values such as 1 will be considered in the past unless you change this behaviour in *Sliding Window Tab / Control Video, Injected Frames alignment*.
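For intuition, here is a toy sketch of how a spec such as "1 L L L" could be expanded into absolute frame positions. The window length, overlap handling and 1-based indexing are assumptions made for the example; WanGP's actual arithmetic may differ.

```python
def expand_injected_frames(spec: str, window_len: int = 81, overlap: int = 0) -> list[int]:
    """Expand an Injected Frames spec like '1 L L L' into absolute frame positions."""
    positions, window_index = [], 0
    for token in spec.split():
        if token.upper() == "L":
            window_index += 1
            # "L" = end of the next sliding window; consecutive windows share `overlap` frames
            positions.append(window_len + (window_index - 1) * (window_len - overlap))
        else:
            positions.append(int(token))  # an explicit numeric position
    return positions

print(expand_injected_frames("1 L L L"))  # [1, 81, 162, 243] with the defaults
```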
Unexpected luck: fixed Lightning 8 steps for Qwen and Lightning 4 steps for Wan 2.2; now you just need a 1x multiplier, no weird numbers.

*update 7.777: oops, got a crash with FastWan? Luck comes and goes, try a new update, maybe you will have better luck this time*

*update 7.7777: Sometimes good luck seems to last forever. For instance, what if Qwen Lightning 4 steps could also work with WanGP?*

- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors (Qwen Lightning 4 steps)
- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors (new improved version of Qwen Lightning 8 steps)

- **Qwen Edit Inpainting** now exists in two versions: the original version of the previous release and a Lora-based version. Each version has its pros and cons. For instance the Lora version also supports **Outpainting**! However, it tends to slightly change the original image even outside the outpainted area.

- **Better Lipsync with all the Audio to Video models**: you probably noticed that *Multitalk*, *InfiniteTalk* or *Hunyuan Avatar* had so-so lipsync when the provided audio contained some background music. The problem should be solved now thanks to automated background music removal, all done by AI. Don't worry, you will still hear the music as it is added back to the generated Video.

### August 10 2025: WanGP v7.76 - Faster than the VAE ...

We have a funny one here today: FastWan 2.2 5B, the Fastest Video Generator, only 20s to generate 121 frames at 720p. The snag is that the VAE is twice as slow...

Thanks to Kijai for extracting the Lora that is used to build the corresponding finetune.

### September 11 2025: WanGP v8.5/8.55 - Wanna be a Cropper or a Painter?

*WanGP 7.76: fixed the mess I made with i2v models (the loras path was wrong for Wan 2.2 and Clip was broken)*

I have done some intensive internal refactoring of the generation pipeline to make it easier to support existing models and add new ones. Nothing really visible, but this makes WanGP a little more future proof.

### August 9 2025: WanGP v7.74 - Qwen Rebirth part 2

Added support for the Qwen Lightning lora for 8-step generation (https://huggingface.co/lightx2v/Qwen-Image-Lightning/blob/main/Qwen-Image-Lightning-8steps-V1.0.safetensors). The Lora is not normalized, so use a multiplier around 0.1.

Otherwise in the news:

- **Cropped Input Image Prompts**: quite often the *Image Prompts* you provide (*Start Image, Input Video, Reference Image, Control Video, ...*) do not match your requested *Output Resolution*. In that case the resolution you gave is used either as a *Pixels Budget* or as an *Outer Canvas* for the Generated Video. However, on some occasions you really want the requested Output Resolution and nothing else. Besides, some models deliver much better Generations if you stick to one of their supported resolutions. To address this I have added a new Output Resolution choice in the *Configuration Tab*: **Dimensions correspond to the Output Width & Height and the Prompt Images will be Cropped to fit exactly these dimensions**. In short, if needed the *Input Prompt Images* will be cropped (center cropped for the moment). You will see this can make quite a difference for some models.
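A minimal sketch of the cropping idea (scale then center crop to an exact target resolution, using Pillow); this illustrates the general technique and is not WanGP's actual code:

```python
from PIL import Image

def center_crop_to(image: Image.Image, target_w: int, target_h: int) -> Image.Image:
    # Scale so the image fully covers the target box, then crop the excess around the center.
    scale = max(target_w / image.width, target_h / image.height)
    resized = image.resize((round(image.width * scale), round(image.height * scale)), Image.LANCZOS)
    left = (resized.width - target_w) // 2
    top = (resized.height - target_h) // 2
    return resized.crop((left, top, left + target_w, top + target_h))
```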
Mag Cache support for all the Wan 2.2 models. Don't forget to set guidance to 1 and 8 denoising steps; your gen will be 7x faster!

- *Qwen Edit* now has a new sub tab called **Inpainting** that lets you target with a brush which part of the *Image Prompt* you want to modify. This is quite convenient if you find that Qwen Edit usually modifies too many things. Of course, as there are more constraints for Qwen Edit, don't be surprised if sometimes it returns the original image unchanged. A piece of advice: describe in your *Text Prompt* where the parts you want to modify are located (for instance *left of the man*, *top*, ...).

### August 8 2025: WanGP v7.73 - Qwen Rebirth

Ever wondered what impact not using Guidance has on a model that expects it? Just look at Qwen Image in WanGP 7.71, whose outputs were erratic. Somehow I had convinced myself that Qwen was a distilled model. In fact Qwen was dying for a negative prompt, and in WanGP 7.72 there is at last one for him.

The mask inpainting is fully compatible with the *Matanyone Mask generator*: first generate an *Image Mask* with Matanyone, transfer it to the current Image Generator and modify the mask with the *Paint Brush*. Talking about Matanyone, I have fixed a bug that caused mask degradation with long videos (WanGP Matanyone is now as good as the original app and still requires 3 times less VRAM).

As Qwen is not so picky after all, I have also added a quantized text encoder which reduces the RAM requirements of Qwen by 10 GB (the quantized text encoder produced garbage before).

- This **Inpainting Mask Editor** has also been added to *Vace Image Mode*. Vace is probably still one of the best Image Editors today. Here is a very simple & efficient workflow that does marvels with Vace:

Select *Vace Cocktail > Control Image Process = Perform Inpainting & Area Processed = Masked Area > Upload a Control Image, then draw your mask directly on top of the image & enter a text Prompt that describes the expected change > Generate > Below the Video Gallery click 'To Control Image' > Keep on doing more changes*.

Unfortunately the Sage bug for older GPU architectures is still there. Added an Sdpa fallback for these architectures.

For more sophisticated things the Vace Image Editor works very well too: try Image Outpainting, Pose transfer, ...

*7.73 update: still the Sage / Sage2 bug for GPUs older than the RTX 40xx series. I have added a detection mechanism that forces Sdpa attention in that case*

For the best quality I recommend setting in the *Quality Tab* the option "*Generate a 9 Frames Long video...*".
**update 8.55**: Flux Festival

- **Inpainting Mode** also added for *Flux Kontext*
- **Flux SRPO**: a new finetune with 3x better quality vs Flux Dev according to its authors. I have also created a *Flux SRPO USO* finetune which is certainly the best open source *Style Transfer* tool available
- **Flux UMO**: a model specialized in combining multiple reference objects / people together. Works quite well at 768x768

### August 6 2025: WanGP v7.71 - Picky, picky

Good luck finding your way through all the Flux model names!

This release comes with two new models:

- Qwen Image: a commercial grade Image generator capable of injecting full sentences into the generated Image while still offering incredible visuals
- Wan 2.2 TextImage to Video 5B: the last Wan 2.2 model needed if you want to complete your Wan 2.2 collection (loras for this model can be stored in "\loras\5B")

### September 5 2025: WanGP v8.4 - Take me to Outer Space

You have probably seen these short AI generated movies created using *Nano Banana* and the *First Frame - Last Frame* feature of *Kling 2.0*. The idea is to generate an image, modify a part of it with Nano Banana and give these two images to Kling, which will generate the Video between them; now use the previous Last Frame as the new First Frame, rinse and repeat, and you get a full movie.

There is a catch though, they are very picky if you want to get good generations: first, they both need lots of steps (50?) to show what they have to offer. Then, for Qwen Image I had to hardcode the supported resolutions, because if you try anything else you will get garbage. Likewise Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p.

I have made it easier to do just that with *Qwen Edit* and *Wan*:

- **End Frames can now be combined with Continue a Video** (and not just a Start Frame)
- **Multiple End Frames can be input**; each End Frame will be used for a different Sliding Window

*7.71 update: Added VAE Tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, for low VRAM usage during a whole gen.*

You can plan all your shots in advance (one shot = one Sliding Window): I recommend using Wan 2.2 Image to Image with multiple End Frames (one for each shot / Sliding Window) and a different Text Prompt for each shot / Sliding Window (remember to enable *Sliding Windows / Text Prompts Will be used for a new Sliding Window of the same Video Generation*).

The results can be quite impressive. However, Wan 2.1 & 2.2 Image 2 Image are restricted to a single overlap frame when using Sliding Windows, which means only one frame is reused for the motion. This may be insufficient if you are trying to connect two shots with fast movement.

### August 4 2025: WanGP v7.6 - Remuxed
This is where *InfiniteTalk* comes into play. Besides being one of the best models for generating animated audio driven avatars, InfiniteTalk internally uses more than one motion frame, so it is quite good at maintaining the motion between two shots. I have tweaked InfiniteTalk so that **its motion engine can be used even if no audio is provided**.

So here is how to use InfiniteTalk: enable *Sliding Windows / Text Prompts Will be used for a new Sliding Window of the same Video Generation*, and if you continue an existing Video, set *Misc / Override Frames per Second* to *Source Video*. Each Reference Frame input will play the same role as an End Frame, except it won't be exactly an End Frame (it corresponds more to a middle frame; the actual End Frame will differ but will be close).

With this new version you won't have any excuse if there is no sound in your video.

*Continue Video* now works with any video that already has some sound (hint: Multitalk).

You will find below a 33s movie I have created using these two methods. Quality could be much better as I haven't tuned the settings at all (I couldn't be bothered; I used 10-step generation without Loras Accelerators for most of the gens).

Also, on top of MMAudio and the various sound driven models, I have added the ability to use your own soundtrack.

### September 2 2025: WanGP v8.31 - At last the pain stops

As a result you can apply a different sound source to each new video segment when doing a *Continue Video*.

- This single new feature should give you the strength to face all the potential bugs of this new release:

**Images Management (multiple additions or deletions, reordering) for Start Images / End Images / Image References.**

For instance:

- first video part: use Multitalk with two people speaking
- second video part: apply your own soundtrack, which will gently follow the Multitalk conversation
- third video part: use a Vace effect and its corresponding control audio will be concatenated to the rest of the audio

- Unofficial **Video to Video (Non Sparse this time) for InfiniteTalk**. Use the Strength Noise slider to decide how much motion of the original window you want to keep. I have also *greatly reduced the VRAM requirements for Multitalk / InfiniteTalk* (especially the multi-speakers version & when generating at 1080p).

To multiply the combinations I have also implemented *Continue Video* with the various image2video models.

- **Experimental Sage 3 Attention support**: you will need to earn this one: first you need a Blackwell GPU (RTX 50xx) and to request access to the Sage 3 GitHub repo, then you will have to compile Sage 3, install it and cross your fingers...

Also:

- End Frame support added for LTX Video models
- Loras can now be targeted specifically at the High noise or Low noise models with Wan 2.2, check the Loras and Finetune guides
- Flux Krea Dev support
### July 30 2025: WanGP v7.5: Just another release ... Wan 2.2 part 2

Here is now Wan 2.2 image2video, a very good model if you want to set Start and End frames. Two Wan 2.2 models delivered, only one to go ...

*update 8.31: one shouldn't talk about bugs if one doesn't want to attract bugs*

Please note that although it is an image2video model, it is structurally very close to Wan 2.2 text2video (same layers with only a different initial projection). Given that Wan 2.1 image2video loras don't work too well on it (half of their tensors are not supported), I have decided that this model will look for its loras in the text2video loras folder instead of the image2video folder.

I have also optimized RAM management with Wan 2.2 so that loras and modules are loaded only once in RAM and Reserved RAM; this saves up to 5 GB of RAM, which can make a difference...

And this time I really removed Vace Cocktail Light, which produced blurry output.

### July 29 2025: WanGP v7.4: Just another release ... Wan 2.2 Preview

Wan 2.2 is here. The good news is that WanGP won't require a single byte of extra VRAM to run it and it will be as fast as Wan 2.1. The bad news is that you will need much more RAM if you want to fully leverage this new model, since it has twice as many parameters.

So here is a preview version of Wan 2.2, that is, without the 5B model and Wan 2.2 image to video for the moment.

However, as I felt bad delivering only half of the wares, I give you instead ... **Wan 2.2 Vace Experimental Cocktail**!

A very good surprise indeed: the loras and Vace partially work with Wan 2.2. We will still need to wait for the official Vace 2.2 release, since some Vace features such as identity preservation are broken.

Bonus zone: Flux multi-image conditioning has been added, or maybe not if I broke everything, as I have been distracted by Wan...

7.4 update: I forgot to update the version number. I also removed Vace Cocktail Light which didn't work well.

### July 27 2025: WanGP v7.3 : Interlude

While waiting for Wan 2.2, you will appreciate the model selection hierarchy, which is very useful for collecting even more models. You will also appreciate that WanGP remembers which model you used last in each model family.

### July 26 2025: WanGP v7.2 : Ode to Vace

I am really convinced that Vace can do everything the other models can do, and in a better way, especially as Vace can be combined with Multitalk.

Here are some new Vace improvements:

- I have provided a default finetune named *Vace Cocktail*, which is a model created on the fly using the Wan text 2 video model and the Loras used to build FusioniX. The weight of the *Detail Enhancer* Lora has been reduced to improve identity preservation. Copy the model definition in *defaults/vace_14B_cocktail.json* into the *finetunes/* folder to change the Cocktail composition. Cocktail already contains some Loras accelerators, so there is no need to add an Accvid, Causvid or FusioniX Lora on top. The whole point of Cocktail is to be able to build your own FusioniX (which originally is a combination of 4 loras) but without the inconvenients of FusioniX.
- Talking about identity preservation, it tends to go away when one generates a single Frame instead of a Video, which is a shame for our Vace photoshop. But there is a solution: I have added an Advanced Quality option that tells WanGP to generate a little more than one frame (it will still keep only the first frame). It will be a little slower, but you will be amazed how Vace Cocktail combined with this option preserves identities (bye bye *Phantom*).
- As in practice I have observed that one switches frequently between *Vace text2video* and *Vace text2image*, I have put them in the same place: they are now just one tab away, no need to reload the model. Likewise *Wan text2video* and *Wan text2image* have been merged.
- Color fixing when using Sliding Windows. A new postprocessing step, *Color Correction*, applied automatically by default (you can disable it in the *Sliding Window* Advanced tab), will try to match the colors of the new window with those of the previous window. It doesn't fix all the unwanted artifacts of the new window, but at least it makes the transition smoother. Thanks to the Multitalk team for the original code.

Also, you will enjoy our new real time statistics (CPU / GPU usage, RAM / VRAM used, ...). Many thanks to **Redtash1** for providing the framework for this new feature! You need to go to the Config tab to enable real time stats.
### July 21 2025: WanGP v7.12

- Flux Family Reunion: *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate that, Loras support for the Flux *diffusers* format has also been added.

- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video!) in one go without a sliding window. With the distilled model it will take only 5 minutes on an RTX 4090 (you will need 22 GB of VRAM though). I have added options to select a higher number of frames if you want to experiment (go to Configuration Tab / General / Increase the Max Number of Frames, change the value and restart the App).

- LTX Video ControlNet: a Control Net that allows you, for instance, to transfer Human motion or Depth from a control video. It is not as powerful as Vace but can produce interesting things, especially as you can now quickly generate a 1 min video. Under the hood, IC-Loras (see below) for Pose, Depth and Canny are automatically loaded for you, no need to add them.

- LTX IC-Lora support: these are special Loras that consume a conditional image or video.
Besides the pose, depth and canny IC-Loras that are transparently loaded, there is the *detailer* (https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8) which is basically an upsampler. Add the *detailer* as a Lora and use LTX Raw Format as the control net choice to use it.

- Matanyone is now also for the GPU Poor as its VRAM requirements have been divided by 2! (7.12 shadow update)

- Easier way to select the video resolution

### July 15 2025: WanGP v7.0 is an AI Powered Photoshop

This release turns the Wan models into Image Generators. This goes way beyond just allowing you to generate a video made of a single frame:

- Multiple Images generated at the same time so that you can choose the one you like best. It is highly VRAM optimized so that you can generate, for instance, 4 720p Images at the same time with less than 10 GB
- With *image2image* the original text2video WanGP becomes an image upsampler / restorer
- *Vace image2image* comes out of the box with image outpainting, person / object replacement, ...
- You can use in one click a newly generated Image as a Start Image or Reference Image for a Video generation

And to complete the full suite of AI Image Generators, Ladies and Gentlemen please welcome for the first time in WanGP: **Flux Kontext**.\
As a reminder, Flux Kontext is an image editor: give it an image and a prompt and it will do the change for you.\
This highly optimized version of Flux Kontext will make you feel that you have been cheated all this time, as WanGP Flux Kontext requires only 8 GB of VRAM to generate 4 images at the same time with no need for quantization.

WanGP v7 comes with *Image2image* vanilla and *Vace FusioniX*. However, you can build your own finetune where you combine a text2video or Vace model with any combination of Loras.

Also in the news:

- You can now enter the *Bbox* for each speaker in *Multitalk* to precisely locate who is speaking. And to save some headaches the *Image Mask generator* will give you the *Bbox* coordinates of an area you have selected.
- *Film Grain* post processing to add a vintage look to your video
- The *First Last Frame to Video* model should work much better now, as I recently discovered its implementation was not complete
- More power for the finetuners: you can now embed Loras directly in the finetune definition. You can also override the default models (titles, visibility, ...) with your own finetunes. Check the doc, which has been updated.
### July 10 2025: WanGP v6.7, is NAG a game changer? You tell me

Maybe you knew that already, but most *Loras accelerators* we use today (Causvid, FusioniX) don't use *Guidance* at all (that is, *CFG* is set to 1). This helps to get much faster generations, but the downside is that *Negative Prompts* are completely ignored (including the default ones set by the models). **NAG** (https://github.com/ChenDarYen/Normalized-Attention-Guidance) aims to solve that by injecting the *Negative Prompt* during the *attention* processing phase.

So WanGP 6.7 gives you NAG, but not just any NAG: a *Low VRAM* implementation, as the default one ends up being VRAM greedy. You will find NAG in the *General* advanced tab for most Wan models.

Use NAG especially when Guidance is set to 1. To turn it on, set the **NAG scale** to something around 10. There are other NAG parameters, **NAG tau** and **NAG alpha**, which I recommend changing only if you don't get good results by just playing with the NAG scale. Don't hesitate to share on the Discord server the best combinations of these 3 parameters.

The authors of NAG claim that NAG can also be used with Guidance (CFG > 1) to improve prompt adherence.
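For intuition only, here is a very rough sketch of what attention-level negative guidance in the spirit of NAG looks like: extrapolate the attention output away from the negative-prompt branch, bound the drift, then blend back toward the positive branch. The exact formulation of scale, tau and alpha is defined in the NAG repository linked above; every detail below is an assumption made for illustration.

```python
import torch

def nag_attention_mix(z_pos: torch.Tensor, z_neg: torch.Tensor,
                      scale: float = 10.0, tau: float = 2.5, alpha: float = 0.25) -> torch.Tensor:
    # z_pos / z_neg: attention outputs computed with the positive and negative prompts.
    # 1) Extrapolate away from the negative prompt.
    z_ext = z_pos + scale * (z_pos - z_neg)
    # 2) Cap how far the extrapolation may drift (mean-magnitude ratio clipped at tau).
    ratio = z_ext.abs().mean() / (z_pos.abs().mean() + 1e-6)
    if ratio > tau:
        z_ext = z_ext * (tau / ratio)
    # 3) Blend back toward the positive branch.
    return alpha * z_ext + (1 - alpha) * z_pos
```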
### July 8 2025: WanGP v6.6, WanGP offers you **Vace Multitalk Dual Voices Fusionix Infinite**:

**Vace**, our beloved super Control Net, has been combined with **Multitalk**, the new king in town that can animate up to two people speaking (**Dual Voices**). It is accelerated by the **Fusionix** model, and thanks to *Sliding Windows* support and *Adaptive Projected Guidance* (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for a very long time (which is an **Infinite** amount of time in the field of video generation).

Of course you will also get *Multitalk* vanilla and *Multitalk 720p* as a bonus.

And since I am mister nice guy I have enclosed, as an exclusivity, an *Audio Separator* that will save you time isolating each voice when using Multitalk with two people.

As I feel like resting a bit, I haven't yet produced a nice sample Video to illustrate all these new capabilities. But here is the thing: I am sure you will publish your *Master Pieces* in the *Share Your Best Video* channel. The best ones will be added to the *Announcements Channel* and will bring eternal fame to their authors.

But wait, there is more:

- Sliding Windows support has been added everywhere with Wan models, so, with text2video recently upgraded in 6.5 into a video2video, you can now upsample very long videos regardless of your VRAM. The good old image2video model can now reuse the last image to produce new videos (as requested by many of you).
- I have also added the capability to transfer the audio of the original control video (Misc. advanced tab) and an option to preserve the fps in the generated video, so from now on you will be able to upsample / restore your old family videos and keep the audio at its original pace. Be aware that the duration will be limited to 1000 frames, as I still need to add streaming support for unlimited video sizes.

Also of interest:

- Extract video info from Videos that have not been generated by WanGP; even better, you can also apply post processing (Upsampling / MMAudio) to non-WanGP videos
- Force the generated video fps to your liking, works very well with Vace when using a Control Video
- Ability to chain URLs of Finetune models (for instance put the URLs of a model in your main finetune and reference this finetune in other finetune models to save time)

### July 2 2025: WanGP v6.5.1, WanGP takes care of you: lots of quality of life features:

- View directly inside WanGP the properties (seed, resolution, length, most settings...) of past generations
- In one click, use a newly generated video as a Control Video or as a Source Video to be continued
- Manage multiple settings for the same model and switch between them using a dropdown box
- WanGP will keep the last generated videos in the Gallery and will remember the last model you used if you restart the app but keep the Web page open
- Custom resolutions: add a file in the WanGP folder with the list of resolutions you want to see in WanGP (look at the instruction readme in this folder)

Taking care of your life is not enough, you want new stuff to play with?

- MMAudio directly inside WanGP: add an audio soundtrack that matches the content of your video. By the way, it is a low VRAM MMAudio and 6 GB of VRAM should be sufficient. You will need to go to the *Extensions* tab of the WanGP *Configuration* to enable MMAudio
- Forgot to upsample your video during the generation? Want to try another MMAudio variation? Fear not, you can also apply upsampling or add an MMAudio track once the video generation is done. Even better, you can ask WanGP for multiple variations of MMAudio to pick the one you like best
- MagCache support: a new step skipping approach, supposed to be better than TeaCache. Makes a difference if you usually generate with a high number of steps
- SageAttention2++ support: not just compatibility but also slightly reduced VRAM usage
- Video2Video in Wan Text2Video: this is the paradox, a text2video can become a video2video if you start the denoising process later on an existing video
- FusioniX upsampler: this is an illustration of Video2Video in Text2Video. Use the FusioniX text2video model with an output resolution of 1080p and a denoising strength of 0.25 and you will get one of the best upsamplers (in only 2/3 steps; you will need lots of VRAM though). Increase the denoising strength and you will get one of the best Video Restorers (see the sketch after this list for how the strength maps to a starting step)
- Choice of Wan Samplers / Schedulers
- More Lora formats support
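As a rough illustration of the video2video idea above (an assumption about the general SDEdit-style technique, not WanGP's exact code), the denoising strength decides how far into the schedule the sampler starts on the existing video:

```python
def video2video_start_step(num_steps: int, strength: float) -> int:
    # strength = 0.25 keeps most of the source video and only refines details;
    # strength = 1.0 ignores the source entirely (plain text2video).
    return max(0, min(num_steps, round(num_steps * (1.0 - strength))))

# With 10 steps and strength 0.25 the sampler skips to step 8 and runs only the
# last couple of steps, which is why the FusioniX upsampler above needs just 2/3 steps.
print(video2video_start_step(10, 0.25))  # 8
```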
**If you had upgraded to v6.5 please upgrade again to 6.5.1 as this will fix a bug that ignored Loras beyond the first one**

See full changelog: **[Changelog](docs/CHANGELOG.md)**
configs/animate.json (new file)
@@ -0,0 +1,15 @@
{
  "_class_name": "WanModel",
  "_diffusers_version": "0.30.0",
  "dim": 5120,
  "eps": 1e-06,
  "ffn_dim": 13824,
  "freq_dim": 256,
  "in_dim": 36,
  "model_type": "i2v",
  "num_heads": 40,
  "num_layers": 40,
  "out_dim": 16,
  "text_len": 512,
  "motion_encoder_dim": 512
}
configs/lucy_edit.json (new file)
@@ -0,0 +1,14 @@
{
  "_class_name": "WanModel",
  "_diffusers_version": "0.33.0",
  "dim": 3072,
  "eps": 1e-06,
  "ffn_dim": 14336,
  "freq_dim": 256,
  "in_dim": 96,
  "model_type": "ti2v2_2",
  "num_heads": 24,
  "num_layers": 30,
  "out_dim": 48,
  "text_len": 512
}
defaults/animate.json (new file)
@@ -0,0 +1,17 @@
{
    "model": {
        "name": "Wan2.2 Animate",
        "architecture": "animate",
        "description": "Wan-Animate takes a video and a character image as input, and generates a video in either 'Animation' or 'Replacement' mode. Sliding Windows of at least 81 frames are recommended to obtain the best Style continuity.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_fp16_int8.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_bf16_int8.safetensors"
        ],
        "preload_URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_relighting_lora.safetensors"
        ],
        "group": "wan2_2"
    }
}
@@ -7,8 +7,6 @@
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors"
        ],
        "image_outputs": true,
        "reference_image": true,
        "flux-model": "flux-dev-kontext"
    },
    "prompt": "add a hat",
defaults/flux_dev_umo.json (new file)
@@ -0,0 +1,24 @@
{
    "model": {
        "name": "Flux 1 Dev UMO 12B",
        "architecture": "flux",
        "description": "FLUX.1 Dev UMO is a model that can Edit Images with a specialization in combining multiple image references (resized internally at 512x512 max) to produce an Image output. Best Image preservation at 768x768 Resolution Output.",
        "URLs": "flux",
        "flux-model": "flux-dev-umo",
        "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-UMO_dit_lora_bf16.safetensors"],
        "resolutions": [ ["1024x1024 (1:1)", "1024x1024"],
                        ["768x1024 (3:4)", "768x1024"],
                        ["1024x768 (4:3)", "1024x768"],
                        ["512x1024 (1:2)", "512x1024"],
                        ["1024x512 (2:1)", "1024x512"],
                        ["768x768 (1:1)", "768x768"],
                        ["768x512 (3:2)", "768x512"],
                        ["512x768 (2:3)", "512x768"]]
    },
    "prompt": "the man is wearing a hat",
    "embedded_guidance_scale": 4,
    "resolution": "768x768",
    "batch_size": 1
}
@@ -2,15 +2,13 @@
    "model": {
        "name": "Flux 1 Dev USO 12B",
        "architecture": "flux",
        "description": "FLUX.1 Dev USO is a model specialized to Edit Images with a specialization in Style Transfers (up to two).",
        "description": "FLUX.1 Dev USO is a model that can Edit Images with a specialization in Style Transfers (up to two).",
        "modules": [ ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_projector_bf16.safetensors"]],
        "URLs": "flux",
        "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_dit_lora_bf16.safetensors"],
        "image_outputs": true,
        "reference_image": true,
        "flux-model": "flux-dev-uso"
    },
    "prompt": "add a hat",
    "prompt": "the man is wearing a hat",
    "embedded_guidance_scale": 4,
    "resolution": "1024x1024",
    "batch_size": 1
defaults/flux_srpo.json (new file)
@@ -0,0 +1,15 @@
{
    "model": {
        "name": "Flux 1 SRPO Dev 12B",
        "architecture": "flux",
        "description": "By fine-tuning the FLUX.1.dev model with optimized denoising and online reward adjustment, SRPO improves its human-evaluated realism and aesthetic quality by over 3x.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-srpo-dev_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-srpo-dev_quanto_bf16_int8.safetensors"
        ],
        "flux-model": "flux-dev"
    },
    "prompt": "draw a hat",
    "resolution": "1024x1024",
    "batch_size": 1
}
defaults/flux_srpo_uso.json (new file)
@@ -0,0 +1,17 @@
{
    "model": {
        "name": "Flux 1 SRPO USO 12B",
        "architecture": "flux",
        "description": "FLUX.1 SRPO USO is a model that can Edit Images with a specialization in Style Transfers (up to two). It leverages the improved Image quality brought by the SRPO process",
        "modules": [ "flux_dev_uso"],
        "URLs": "flux_srpo",
        "loras": "flux_dev_uso",
        "flux-model": "flux-dev-uso"
    },
    "prompt": "the man is wearing a hat",
    "embedded_guidance_scale": 4,
    "resolution": "1024x1024",
    "batch_size": 1
}
defaults/lucy_edit.json (new file)
@@ -0,0 +1,19 @@
{
    "model": {
        "name": "Wan2.2 Lucy Edit 5B",
        "architecture": "lucy_edit",
        "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mfp16_int8.safetensors"
        ],
        "group": "wan2_2"
    },
    "prompt": "change the clothes to red",
    "video_length": 81,
    "guidance_scale": 5,
    "flow_shift": 5,
    "num_inference_steps": 30,
    "resolution": "1280x720"
}
defaults/lucy_edit_fastwan.json (new file)
@@ -0,0 +1,16 @@
{
    "model": {
        "name": "Wan2.2 FastWan Lucy Edit 5B",
        "architecture": "lucy_edit",
        "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. This is the FastWan version for faster generation.",
        "URLs": "lucy_edit",
        "group": "wan2_2",
        "loras": "ti2v_2_2_fastwan"
    },
    "prompt": "change the clothes to red",
    "video_length": 81,
    "guidance_scale": 1,
    "flow_shift": 3,
    "num_inference_steps": 5,
    "resolution": "1280x720"
}
@@ -7,11 +7,10 @@
            "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_20B_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_20B_quanto_bf16_int8.safetensors"
        ],
        "preload_URLs": ["https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_inpainting.safetensors"],
        "attention": {
            "<89": "sdpa"
        },
        "reference_image": true,
        "image_outputs": true
        }
    },
    "prompt": "add a hat",
    "resolution": "1280x720",
defaults/qwen_image_edit_plus_20B.json (new file)
@@ -0,0 +1,17 @@
{
    "model": {
        "name": "Qwen Image Edit Plus 20B",
        "architecture": "qwen_image_edit_plus_20B",
        "description": "Qwen Image Edit Plus is a generative model that can generate very high quality images with long texts in it. Best results will be at 720p. This model is optimized to combine multiple Subjects & Objects.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_plus_20B_quanto_bf16_int8.safetensors"
        ],
        "preload_URLs": "qwen_image_edit_20B",
        "attention": {
            "<89": "sdpa"
        }
    },
    "prompt": "add a hat",
    "resolution": "1024x1024",
    "batch_size": 1
}
@@ -4,7 +4,7 @@
        "name": "Wan2.1 Standin 14B",
        "modules": [ ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Stand-In_wan2.1_T2V_14B_ver1.0_bf16.safetensors"]],
        "architecture" : "standin",
        "description": "The original Wan Text 2 Video model combined with the StandIn module to improve Identity Preservation. You need to provide a Reference Image with white background which is a close up of person face to transfer this person in the Video.",
        "description": "The original Wan Text 2 Video model combined with the StandIn module to improve Identity Preservation. You need to provide a Reference Image with white background which is a close up of a person face to transfer this person in the Video.",
        "URLs": "t2v"
    }
}
@@ -7,6 +7,7 @@
        "loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"],
        "group": "wan2_2"
    },
    "prompt": "Put the person into a clown outfit.",
    "video_length": 121,
    "guidance_scale": 1,
    "flow_shift": 3,
defaults/vace_fun_14B_2_2.json (new file)
@@ -0,0 +1,24 @@
{
    "model": {
        "name": "Wan2.2 Vace Fun 14B",
        "architecture": "vace_14B",
        "description": "This is the Fun Vace 2.2 version, that is not the official Vace 2.2",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_mbf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_quanto_mbf16_int8.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_quanto_mfp16_int8.safetensors"
        ],
        "URLs2": [
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_mbf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_quanto_mbf16_int8.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_quanto_mfp16_int8.safetensors"
        ],
        "group": "wan2_2"
    },
    "guidance_phases": 2,
    "num_inference_steps": 30,
    "guidance_scale": 1,
    "guidance2_scale": 1,
    "flow_shift": 2,
    "switch_threshold": 875
}
defaults/vace_fun_14B_cocktail_2_2.json (new file, 28 lines)
@ -0,0 +1,28 @@
 | 
			
		||||
{
 | 
			
		||||
    "model": {
 | 
			
		||||
        "name": "Wan2.2 Vace Fun Cocktail 14B",
 | 
			
		||||
        "architecture": "vace_14B",
 | 
			
		||||
        "description": "This model has been created on the fly using the Wan text 2.2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. This is the Fun Vace 2.2, that is not the official Vace 2.2",
 | 
			
		||||
        "URLs": "vace_fun_14B_2_2",
 | 
			
		||||
        "URLs2": "vace_fun_14B_2_2",
 | 
			
		||||
        "loras": [
 | 
			
		||||
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
 | 
			
		||||
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors",
 | 
			
		||||
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors",
 | 
			
		||||
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors"
 | 
			
		||||
        ],
 | 
			
		||||
        "loras_multipliers": [
 | 
			
		||||
            1,
 | 
			
		||||
            0.2,
 | 
			
		||||
            0.5,
 | 
			
		||||
            0.5
 | 
			
		||||
        ],
 | 
			
		||||
        "group": "wan2_2"
 | 
			
		||||
    },
 | 
			
		||||
    "guidance_phases": 2,
 | 
			
		||||
    "num_inference_steps": 10,
 | 
			
		||||
    "guidance_scale": 1,
 | 
			
		||||
    "guidance2_scale": 1,
 | 
			
		||||
    "flow_shift": 2,
 | 
			
		||||
    "switch_threshold": 875
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
docs/AMD-INSTALLATION.md (new file, 146 lines)
@ -0,0 +1,146 @@
 | 
			
		||||
# Installation Guide
 | 
			
		||||
 | 
			
		||||
This guide covers installation for specific RDNA3 and RDNA3.5 AMD CPUs (APUs) and GPUs
 | 
			
		||||
running under Windows. 
 | 
			
		||||
 | 
			
		||||
tl;dr: Radeon RX 7900 GOOD, RX 9700 BAD, RX 6800 BAD. (I know, life isn't fair).
 | 
			
		||||
 | 
			
		||||
Currently supported (but not necessarily tested):
 | 
			
		||||
 | 
			
		||||
**gfx110x**:
 | 
			
		||||
 | 
			
		||||
* Radeon RX 7600
 | 
			
		||||
* Radeon RX 7700 XT
 | 
			
		||||
* Radeon RX 7800 XT
 | 
			
		||||
* Radeon RX 7900 GRE
 | 
			
		||||
* Radeon RX 7900 XT
 | 
			
		||||
* Radeon RX 7900 XTX
 | 
			
		||||
 | 
			
		||||
**gfx1151**:
 | 
			
		||||
 | 
			
		||||
* Ryzen 7000 series APUs (Phoenix)
 | 
			
		||||
* Ryzen Z1 (e.g., handheld devices like the ROG Ally)
 | 
			
		||||
 | 
			
		||||
**gfx1201**:
 | 
			
		||||
 | 
			
		||||
* Ryzen 8000 series APUs (Strix Point) 
 | 
			
		||||
* A [frame.work](https://frame.work/au/en/desktop) desktop/laptop
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
## Requirements
 | 
			
		||||
 | 
			
		||||
- Python 3.11 (3.12 might work, 3.10 definitely will not!)
 | 
			
		||||
 | 
			
		||||
## Installation Environment
 | 
			
		||||
 | 
			
		||||
This installation uses PyTorch 2.7.0 because that's what's currently available in
 | 
			
		||||
terms of pre-compiled wheels.
 | 
			
		||||
 | 
			
		||||
### Installing Python
 | 
			
		||||
 | 
			
		||||
Download Python 3.11 from [python.org/downloads/windows](https://www.python.org/downloads/windows/). Hit Ctrl+F and search for "3.11". Don't use this direct link: [https://www.python.org/ftp/python/3.11.9/python-3.11.9-amd64.exe](https://www.python.org/ftp/python/3.11.9/python-3.11.9-amd64.exe) -- that was an IQ test.
 | 
			
		||||
 | 
			
		||||
After installing, make sure `python --version` works in your terminal and returns 3.11.x
 | 
			
		||||
 | 
			
		||||
If not, you probably need to fix your PATH. Go to:
 | 
			
		||||
 | 
			
		||||
* Windows + Pause/Break
 | 
			
		||||
* Advanced System Settings
 | 
			
		||||
* Environment Variables
 | 
			
		||||
* Edit your `Path` under User Variables
 | 
			
		||||
 | 
			
		||||
Example correct entries:
 | 
			
		||||
 | 
			
		||||
```cmd
 | 
			
		||||
C:\Users\YOURNAME\AppData\Local\Programs\Python\Launcher\
 | 
			
		||||
C:\Users\YOURNAME\AppData\Local\Programs\Python\Python311\Scripts\
 | 
			
		||||
C:\Users\YOURNAME\AppData\Local\Programs\Python\Python311\
 | 
			
		||||
```
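
To confirm that Windows now resolves the right interpreter, a quick check (these are standard Windows commands, nothing WanGP-specific):

```cmd
:: Show which python.exe wins on PATH, then its version
where python
python --version
```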
 | 
			
		||||
 | 
			
		||||
If that doesn't work, scream into a bucket.
 | 
			
		||||
 | 
			
		||||
### Installing Git
 | 
			
		||||
 | 
			
		||||
Get Git from [git-scm.com/downloads/win](https://git-scm.com/downloads/win). Default install is fine.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
## Install (Windows, using `venv`)
 | 
			
		||||
 | 
			
		||||
### Step 1: Download and Set Up Environment
 | 
			
		||||
 | 
			
		||||
```cmd
 | 
			
		||||
:: Navigate to your desired install directory
 | 
			
		||||
cd \your-path-to-wan2gp
 | 
			
		||||
 | 
			
		||||
:: Clone the repository
 | 
			
		||||
git clone https://github.com/deepbeepmeep/Wan2GP.git
 | 
			
		||||
cd Wan2GP
 | 
			
		||||
 | 
			
		||||
:: Create the virtual environment (uses the Python 3.11 you just installed)
 | 
			
		||||
python -m venv wan2gp-env
 | 
			
		||||
 | 
			
		||||
:: Activate the virtual environment
 | 
			
		||||
wan2gp-env\Scripts\activate
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Step 2: Install PyTorch
 | 
			
		||||
 | 
			
		||||
The pre-compiled wheels you need are hosted at [scottt's rocm-TheRock releases](https://github.com/scottt/rocm-TheRock/releases). Find the heading that says:
 | 
			
		||||
 | 
			
		||||
**Pytorch wheels for gfx110x, gfx1151, and gfx1201**
 | 
			
		||||
 | 
			
		||||
Don't click this link: [https://github.com/scottt/rocm-TheRock/releases/tag/v6.5.0rc-pytorch-gfx110x](https://github.com/scottt/rocm-TheRock/releases/tag/v6.5.0rc-pytorch-gfx110x). It's just here to check if you're skimming.
 | 
			
		||||
 | 
			
		||||
Copy the links of the binaries closest to the ones in the example below (adjust them if you're not running Python 3.11), then hit enter.
 | 
			
		||||
 | 
			
		||||
```cmd
 | 
			
		||||
pip install ^
 | 
			
		||||
    https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torch-2.7.0a0+rocm_git3f903c3-cp311-cp311-win_amd64.whl ^
 | 
			
		||||
    https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchaudio-2.7.0a0+52638ef-cp311-cp311-win_amd64.whl ^
 | 
			
		||||
    https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchvision-0.22.0+9eb57cd-cp311-cp311-win_amd64.whl
 | 
			
		||||
```
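
As a quick sanity check (a minimal sketch; ROCm builds of PyTorch expose the GPU through the regular `torch.cuda` API, so the second value should print `True` on a working setup):

```cmd
:: Verify the ROCm build of PyTorch imports and can see your GPU
python -c "import torch; print(torch.__version__); print(torch.cuda.is_available())"
```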
 | 
			
		||||
 | 
			
		||||
### Step 3: Install Dependencies
 | 
			
		||||
 | 
			
		||||
```cmd
 | 
			
		||||
:: Install core dependencies
 | 
			
		||||
pip install -r requirements.txt
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Attention Modes
 | 
			
		||||
 | 
			
		||||
WanGP supports several attention implementations, only one of which will work for you:
 | 
			
		||||
 | 
			
		||||
- **SDPA** (default): Available out of the box with PyTorch. This uses the built-in aotriton acceleration library, so it is actually pretty fast (see the launch sketch below if you ever need to force it).
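
If a different backend ever gets picked up and crashes, forcing SDPA at launch usually helps. This assumes the flag is named `--attention`, as on other setups; run `python wgp.py --help` to confirm on your version:

```cmd
:: Force the SDPA attention backend (assumed flag name, check --help)
python wgp.py --attention sdpa
```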
 | 
			
		||||
 | 
			
		||||
## Performance Profiles
 | 
			
		||||
 | 
			
		||||
Choose a profile based on your hardware:
 | 
			
		||||
 | 
			
		||||
- **Profile 3 (LowRAM_HighVRAM)**: Loads entire model in VRAM, requires 24GB VRAM for 8-bit quantized 14B model
 | 
			
		||||
- **Profile 4 (LowRAM_LowVRAM)**: Default, loads model parts as needed, slower but lower VRAM requirement
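
The profile can also be set at launch instead of in the web UI. This assumes the flag is named `--profile` and matches the numbering above; `python wgp.py --help` will confirm:

```cmd
:: Example: keep the whole quantized model in VRAM if you have 24GB
python wgp.py --profile 3
```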
 | 
			
		||||
 | 
			
		||||
## Running Wan2GP
 | 
			
		||||
 | 
			
		||||
In the future, you will have to do this:
 | 
			
		||||
 | 
			
		||||
```cmd
 | 
			
		||||
cd \path-to\wan2gp
 | 
			
		||||
wan2gp-env\Scripts\activate.bat
 | 
			
		||||
python wgp.py
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
For now, you should just be able to type `python wgp.py` (because you're already in the virtual environment)
 | 
			
		||||
 | 
			
		||||
## Troubleshooting
 | 
			
		||||
 | 
			
		||||
- If you use a HIGH VRAM mode, don't be a fool.  Make sure you use VAE Tiled Decoding.
 | 
			
		||||
 | 
			
		||||
### Memory Issues
 | 
			
		||||
 | 
			
		||||
- Use lower resolution or shorter videos
 | 
			
		||||
- Enable quantization (default)
 | 
			
		||||
- Use Profile 4 for lower VRAM usage
 | 
			
		||||
- Consider using 1.3B models instead of 14B models
 | 
			
		||||
 | 
			
		||||
For more troubleshooting, see [TROUBLESHOOTING.md](TROUBLESHOOTING.md) 
 | 
			
		||||
@ -1,20 +1,154 @@
 | 
			
		||||
# Changelog
 | 
			
		||||
 | 
			
		||||
## 🔥 Latest News
 | 
			
		||||
### July 21 2025: WanGP v7.1 
 | 
			
		||||
### August 29 2025: WanGP v8.21 -  Here Goes Your Weekend
 | 
			
		||||
 | 
			
		||||
- **InfiniteTalk Video to Video**: this feature can be used for Video Dubbing. Keep in mind that it is a *Sparse Video to Video*, that is internally only image is used by Sliding Window. However thanks to the new *Smooth Transition* mode, each new clip is connected to the previous and all the camera work is done by InfiniteTalk. If you dont get any transition, increase the number of frames of a Sliding Window (81 frames recommended)
 | 
			
		||||
 | 
			
		||||
- **StandIn**: very light model specialized in Identity Transfer. I have provided two versions of Standin: a basic one derived from the text 2 video model and another based on Vace. If used with Vace, the last reference frame given to Vace will be also used for StandIn
 | 
			
		||||
 | 
			
		||||
- **Flux ESO**: a new Flux dervied *Image Editing tool*, but this one is specialized both in *Identity Transfer* and *Style Transfer*. Style has to be understood in its wide meaning: give a reference picture of a person and another one of Sushis and you will turn this person into Sushis
 | 
			
		||||
 | 
			
		||||
### August 24 2025: WanGP v8.1 -  the RAM Liberator
 | 
			
		||||
 | 
			
		||||
- **Reserved RAM entirely freed when switching models**: you should get far fewer RAM-related out-of-memory errors. I have also added a button in *Configuration / Performance* that will release most of the RAM used by WanGP if you want to use another application without quitting WanGP 
 | 
			
		||||
- **InfiniteTalk** support: an improved version of Multitalk that supposedly supports very long video generations based on an audio track. It exists in two flavors (*Single Speaker* and *Multi Speakers*) but doesn't seem to be compatible with Vace. One key new feature compared to Multitalk is that you can have different visual shots associated with the same audio: each Reference Frame you provide will be associated with a new Sliding Window. If only one Reference Frame is provided, it will be used for all windows. When Continuing a video, you can either continue the current shot (no Reference Frame) or add new shots (one or more Reference Frames).\
 | 
			
		||||
If you are not into audio, you can still use this model to generate infinitely long image2video, just select "no speaker". Last but not least, InfiniteTalk works with all the Loras accelerators.
 | 
			
		||||
- **Flux Chroma 1 HD** support: an uncensored Flux-based model that is lighter than Flux (8.9B versus 12B) and can fit entirely in VRAM with only 16 GB of VRAM. Unfortunately it is not distilled, so you will need CFG and a minimum of 20 steps
 | 
			
		||||
 | 
			
		||||
### August 21 2025: WanGP v8.01 - the killer of seven
 | 
			
		||||
 | 
			
		||||
- **Qwen Image Edit** : Flux Kontext challenger (prompt driven image edition). Best results (including Identity preservation) will be obtained at 720p. Beyond you may get image outpainting and / or lose identity preservation. Below 720p prompt adherence will be worse. Qwen Image Edit works with Qwen Lora Lightning 4 steps. I have also unlocked all the resolutions for Qwen models. Bonus Zone: support for multiple image compositions but identity preservation won't be as good.
 | 
			
		||||
- **On demand Prompt Enhancer** (needs to be enabled in Configuration Tab) that you can use to Enhance a Text Prompt before starting a Generation. You can refine the Enhanced Prompt or change the original Prompt.
 | 
			
		||||
- Choice of a **Non censored Prompt Enhancer**. Beware, this one is VRAM hungry and will require 12 GB of VRAM to work
 | 
			
		||||
- **Memory Profile customizable per model** : useful, for instance, to set Profile 3 (preload the model entirely in VRAM) only for Image Generation models if you have 24 GB of VRAM. In that case Generation will be much faster, because with Image generators (contrary to Video generators) a lot of time is otherwise wasted in offloading 
 | 
			
		||||
- **Expert Guidance Mode**: change the Guidance during the generation up to 2 times. Very useful with Wan 2.2 Lightning to reduce the slow motion effect. The idea is to insert a CFG phase before the 2 accelerated phases that follow and have no Guidance. I have added the finetune *Wan2.2 Vace Lightning 3 Phases 14B* with a prebuilt configuration. Please note that it is an 8 steps process although the Lightning lora is 4 steps. This expert guidance mode is also available with Wan 2.1.
 | 
			
		||||
 | 
			
		||||
*WanGP 8.01 update, improved Qwen Image Edit Identity Preservation*
 | 
			
		||||
### August 12 2025: WanGP v7.7777 - Lucky Day(s)
 | 
			
		||||
 | 
			
		||||
This is your lucky day! Thanks to new configuration options that let you store generated Videos and Images in lossless compressed formats, you will find that they in fact look two times better without doing anything!
 | 
			
		||||
 | 
			
		||||
Just kidding, they will be only marginally better, but at least this opens the way to professional editing.
 | 
			
		||||
 | 
			
		||||
Support:
 | 
			
		||||
- Video: x264, x264 lossless, x265
 | 
			
		||||
- Images: jpeg, png, webp, webp lossless
 | 
			
		||||
Generation Settings are stored in each of the above regardless of the format (that was the hard part).
 | 
			
		||||
 | 
			
		||||
Also you can now choose different output directories for images and videos.
 | 
			
		||||
 | 
			
		||||
Unexpected luck: fixed Lightning 8 steps for Qwen and Lightning 4 steps for Wan 2.2; now you just need a 1x multiplier, no weird numbers. 
 | 
			
		||||
*update 7.777 : oops, got a crash with FastWan? Luck comes and goes, try a new update, maybe you will have a better chance this time*
 | 
			
		||||
*update 7.7777 : Sometimes good luck seems to last forever. For instance, what if Qwen Lightning 4 steps could also work with WanGP?*
 | 
			
		||||
- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors (Qwen Lightning 4 steps)
 | 
			
		||||
- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors (new improved version of Qwen Lightning 8 steps)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
### August 10 2025: WanGP v7.76 - Faster than the VAE ...
 | 
			
		||||
We have a funny one here today: FastWan 2.2 5B, the Fastest Video Generator, only 20s to generate 121 frames at 720p. The snag is that VAE is twice as slow... 
 | 
			
		||||
Thanks to Kijai for extracting the Lora that is used to build the corresponding finetune.
 | 
			
		||||
 | 
			
		||||
*WanGP 7.76: fixed the mess I made of the i2v models (the loras path was wrong for Wan 2.2 and Clip was broken)*
 | 
			
		||||
 | 
			
		||||
### August 9 2025: WanGP v7.74 - Qwen Rebirth part 2
 | 
			
		||||
Added support for the Qwen Lightning lora for an 8 steps generation (https://huggingface.co/lightx2v/Qwen-Image-Lightning/blob/main/Qwen-Image-Lightning-8steps-V1.0.safetensors). The lora is not normalized and you can use a multiplier around 0.1.
 | 
			
		||||
 | 
			
		||||
Mag Cache support for all the Wan 2.2 models. Don't forget to set guidance to 1 and 8 denoising steps; your gen will be 7x faster!
 | 
			
		||||
 | 
			
		||||
### August 8 2025: WanGP v7.73 - Qwen Rebirth
 | 
			
		||||
Ever wondered what impact not using Guidance has on a model that expects it? Just look at Qwen Image in WanGP 7.71, whose outputs were erratic. Somehow I had convinced myself that Qwen was a distilled model. In fact Qwen was dying for a negative prompt. And in WanGP 7.72 there is at last one for it.
 | 
			
		||||
 | 
			
		||||
As Qwen is not so picky after all, I have also added a quantized text encoder, which reduces the RAM requirements of Qwen by 10 GB (the quantized text encoder produced garbage before) 
 | 
			
		||||
 | 
			
		||||
Unfortunately the Sage bug is still there for older GPU architectures. An Sdpa fallback has been added for these architectures.
 | 
			
		||||
 | 
			
		||||
*7.73 update: still Sage / Sage2 bug for GPUs before RTX40xx. I have added a detection mechanism that forces Sdpa attention if that's the case*
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
### August 6 2025: WanGP v7.71 - Picky, picky
 | 
			
		||||
 | 
			
		||||
This release comes with two new models :
 | 
			
		||||
- Qwen Image: a Commercial grade Image generator capable of injecting full sentences into the generated Image while still offering incredible visuals
 | 
			
		||||
- Wan 2.2 TextImage to Video 5B: the last Wan 2.2 model needed if you want to complete your Wan 2.2 collection (loras for this model can be stored in "\loras\5B")
 | 
			
		||||
 | 
			
		||||
There is a catch though: they are very picky if you want to get good generations. First, they both need lots of steps (50?) to show what they have to offer. Then, for Qwen Image I had to hardcode the supported resolutions, because if you try anything else you will get garbage. Likewise, Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p. 
 | 
			
		||||
 | 
			
		||||
*7.71 update: Added VAE Tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, for low VRAM during a whole gen.*
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
### August 4 2025: WanGP v7.6 - Remuxed
 | 
			
		||||
 | 
			
		||||
With this new version you won't have any excuse if there is no sound in your video.
 | 
			
		||||
 | 
			
		||||
*Continue Video* now works with any video that has already some sound (hint: Multitalk ).
 | 
			
		||||
 | 
			
		||||
Also, on top of MMaudio and the various sound driven models I have added the ability to use your own soundtrack.
 | 
			
		||||
 | 
			
		||||
As a result you can apply a different sound source on each new video segment when doing a *Continue Video*. 
 | 
			
		||||
 | 
			
		||||
For instance:
 | 
			
		||||
- first video part: use Multitalk with two people speaking
 | 
			
		||||
- second video part: you apply your own soundtrack which will gently follow the multitalk conversation
 | 
			
		||||
- third video part: you use Vace effect and its corresponding control audio will be concatenated to the rest of the audio
 | 
			
		||||
 | 
			
		||||
To multiply the combinations I have also implemented *Continue Video* with the various image2video models.
 | 
			
		||||
 | 
			
		||||
Also:
 | 
			
		||||
- End Frame support added for LTX Video models
 | 
			
		||||
- Loras can now be targeted specifically at the High noise or Low noise models with Wan 2.2, check the Loras and Finetune guides
 | 
			
		||||
- Flux Krea Dev support
 | 
			
		||||
 | 
			
		||||
### July 30 2025: WanGP v7.5:  Just another release ... Wan 2.2 part 2
 | 
			
		||||
Here is now Wan 2.2 image2video, a very good model if you want to set Start and End frames. Two Wan 2.2 models delivered, only one to go...
 | 
			
		||||
 | 
			
		||||
Please note that although it is an image2video model it is structurally very close to Wan 2.2 text2video (same layers with only a different initial projection). Given that Wan 2.1 image2video loras don't work too well (half of their tensors are not supported), I have decided that this model will look for its loras in the text2video loras folder instead of the image2video folder.
 | 
			
		||||
 | 
			
		||||
I have also optimized RAM management with Wan 2.2 so that loras and modules will be loaded only once in RAM and Reserved RAM, this saves up to 5 GB of RAM which can make a difference...
 | 
			
		||||
 | 
			
		||||
And this time I really removed Vace Cocktail Light which gave a blurry vision.
 | 
			
		||||
 | 
			
		||||
### July 29 2025: WanGP v7.4:  Just another release ... Wan 2.2 Preview
 | 
			
		||||
Wan 2.2 is here. The good news is that WanGP won't require a single byte of extra VRAM to run it and it will be as fast as Wan 2.1. The bad news is that you will need much more RAM if you want to leverage this new model entirely, since it has twice as many parameters.
 | 
			
		||||
 | 
			
		||||
So here is a preview version of Wan 2.2 that is without the 5B model and Wan 2.2 image to video for the moment.
 | 
			
		||||
 | 
			
		||||
However, as I felt bad about delivering only half of the wares, I gave you instead ..... **Wan 2.2 Vace Experimental Cocktail**!
 | 
			
		||||
 | 
			
		||||
Very good surprise indeed: the loras and Vace partially work with Wan 2.2. We will need to wait for the official Vace 2.2 release since some Vace features, like identity preservation, are broken.
 | 
			
		||||
 | 
			
		||||
Bonus zone: Flux multi image conditioning has been added, or maybe not if I broke everything while being distracted by Wan...
 | 
			
		||||
 | 
			
		||||
7.4 update: I forgot to update the version number. I also removed Vace Cocktail light which didn't work well.
 | 
			
		||||
 | 
			
		||||
### July 27 2025: WanGP v7.3 : Interlude
 | 
			
		||||
While waiting for Wan 2.2, you will appreciate the model selection hierarchy which is very useful to collect even more models. You will also appreciate that WanGP remembers which model you used last in each model family.
 | 
			
		||||
 | 
			
		||||
### July 26 2025: WanGP v7.2 : Ode to Vace
 | 
			
		||||
I am really convinced that Vace can do everything the other models can do and in a better way especially as Vace can be combined with Multitalk.
 | 
			
		||||
 | 
			
		||||
Here are some new Vace improvements:
 | 
			
		||||
- I have provided a default finetune named *Vace Cocktail*, which is a model created on the fly using the Wan text 2 video model and the Loras used to build FusioniX. The weight of the *Detail Enhancer* Lora has been reduced to improve identity preservation. Copy the model definition in *defaults/vace_14B_cocktail.json* into the *finetunes/* folder to change the Cocktail composition (see the sketch after this list). Cocktail already contains some Loras accelerators, so there is no need to add a Lora Accvid, Causvid or FusioniX on top. The whole point of Cocktail is to be able to build your own FusioniX (which originally is a combination of 4 loras) but without the inconveniences of FusioniX.
 | 
			
		||||
- Talking about identity preservation, it tends to go away when one generates a single Frame instead of a Video, which is a shame for our Vace photoshop. But there is a solution: I have added an Advanced Quality option that tells WanGP to generate a little more than a frame (it will still keep only the first frame). It will be a little slower, but you will be amazed how well Vace Cocktail combined with this option preserves identities (bye bye *Phantom*). 
 | 
			
		||||
- As in practice I have observed that one switches frequently between *Vace text2video* and *Vace text2image*, I have put them in the same place: they are now just one tab away, with no need to reload the model. Likewise *Wan text2video* and *Wan text2image* have been merged.
 | 
			
		||||
- Color fixing when using Sliding Windows. A new postprocessing step, *Color Correction*, applied automatically by default (you can disable it in the *Advanced tab Sliding Window*), will try to match the colors of the new window with those of the previous window. It doesn't fix all the unwanted artifacts of the new window but at least it makes the transition smoother. Thanks to the multitalk team for the original code.
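
For reference, a minimal sketch of what a tweaked copy in *finetunes/* could look like. The file name, the `"URLs"` reference and the multiplier values below are illustrative only; use the shipped *defaults/vace_14B_cocktail.json* as the authoritative template for the exact keys:

```json
{
    "model": {
        "name": "Vace Cocktail (my mix)",
        "architecture": "vace_14B",
        "description": "Personal cocktail with a lighter Detail Enhancer",
        "URLs": "vace_14B",
        "loras": [
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors"
        ],
        "loras_multipliers": [1, 0.1]
    },
    "num_inference_steps": 10,
    "guidance_scale": 1
}
```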
 | 
			
		||||
 | 
			
		||||
Also you will enjoy our new real time statistics (CPU / GPU usage, RAM / VRAM used, ... ). Many thanks to **Redtash1** for providing the framework for this new feature ! You need to go in the Config tab to enable real time stats.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
### July 21 2025: WanGP v7.12
 | 
			
		||||
- Flux Family Reunion : *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate that, Loras support for the Flux *diffusers* format has also been added.
 | 
			
		||||
 | 
			
		||||
- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video!) in one go without a sliding window. With the distilled model it will take only 5 minutes with an RTX 4090 (you will need 22 GB of VRAM though). I have added options to select a higher number of frames if you want to experiment
 | 
			
		||||
- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video!) in one go without a sliding window. With the distilled model it will take only 5 minutes with an RTX 4090 (you will need 22 GB of VRAM though). I have added options to select a higher number of frames if you want to experiment (go to Configuration Tab / General / Increase the Max Number of Frames, change the value and restart the App)
 | 
			
		||||
 | 
			
		||||
- LTX Video ControlNet : a Control Net that allows you, for instance, to transfer Human motion or Depth from a control video. It is not as powerful as Vace but can produce interesting things, especially as you can now quickly generate a 1 min video. Under the hood, IC-Loras (see below) for Pose, Depth and Canny are automatically loaded for you, no need to add them. 
 | 
			
		||||
 | 
			
		||||
- LTX IC-Lora support: these are special Loras that consume a conditional image or video
 | 
			
		||||
Besides the pose, depth and canny IC-Loras that are transparently loaded, there is the *detailer* (https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8), which is basically an upsampler. Add the *detailer* as a Lora and use LTX Raw Format as the control net choice to use it.
 | 
			
		||||
 | 
			
		||||
And Also:
 | 
			
		||||
- easier way to select video resolution 
 | 
			
		||||
- started to optimize Matanyone to reduce VRAM requirements
 | 
			
		||||
- Matanyone is now also for the GPU Poor as its VRAM requirements have been divided by 2! (7.12 shadow update)
 | 
			
		||||
 | 
			
		||||
- Easier way to select video resolution 
 | 
			
		||||
 | 
			
		||||
### July 15 2025: WanGP v7.0 is an AI Powered Photoshop
 | 
			
		||||
This release turns the Wan models into Image Generators. This goes way beyond allowing you to generate a video made of a single frame:
 | 
			
		||||
 | 
			
		||||
@ -27,7 +27,7 @@ conda activate wan2gp
 | 
			
		||||
### Step 2: Install PyTorch
 | 
			
		||||
 | 
			
		||||
```shell
 | 
			
		||||
# Install PyTorch 2.7.0 with CUDA 12.4
 | 
			
		||||
# Install PyTorch 2.7.0 with CUDA 12.8
 | 
			
		||||
pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -13,27 +13,52 @@ class family_handler():
 | 
			
		||||
        flux_schnell = flux_model == "flux-schnell" 
 | 
			
		||||
        flux_chroma = flux_model == "flux-chroma" 
 | 
			
		||||
        flux_uso = flux_model == "flux-dev-uso"
 | 
			
		||||
        model_def_output = {
 | 
			
		||||
        flux_umo = flux_model == "flux-dev-umo"
 | 
			
		||||
        flux_kontext = flux_model == "flux-dev-kontext"
 | 
			
		||||
        
 | 
			
		||||
        extra_model_def = {
 | 
			
		||||
            "image_outputs" : True,
 | 
			
		||||
            "no_negative_prompt" : not flux_chroma,
 | 
			
		||||
        }
 | 
			
		||||
        if flux_chroma:
 | 
			
		||||
            model_def_output["guidance_max_phases"] = 1
 | 
			
		||||
            extra_model_def["guidance_max_phases"] = 1
 | 
			
		||||
        elif not flux_schnell:
 | 
			
		||||
            model_def_output["embedded_guidance"] = True
 | 
			
		||||
            extra_model_def["embedded_guidance"] = True
 | 
			
		||||
        if flux_uso :
 | 
			
		||||
            model_def_output["any_image_refs_relative_size"] = True
 | 
			
		||||
            model_def_output["no_background_removal"] = True
 | 
			
		||||
 | 
			
		||||
            model_def_output["image_ref_choices"] = {
 | 
			
		||||
                "choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "I"),
 | 
			
		||||
                            ("Up to two Images are Style Images", "IJ")],
 | 
			
		||||
                "default": "I",
 | 
			
		||||
                "letters_filter": "IJ",
 | 
			
		||||
            extra_model_def["any_image_refs_relative_size"] = True
 | 
			
		||||
            extra_model_def["no_background_removal"] = True
 | 
			
		||||
            extra_model_def["image_ref_choices"] = {
 | 
			
		||||
                "choices":[("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"),
 | 
			
		||||
                            ("Up to two Images are Style Images", "KIJ")],
 | 
			
		||||
                "default": "KI",
 | 
			
		||||
                "letters_filter": "KIJ",
 | 
			
		||||
                "label": "Reference Images / Style Images"
 | 
			
		||||
            }
 | 
			
		||||
        
 | 
			
		||||
        if flux_kontext:
 | 
			
		||||
            extra_model_def["inpaint_support"] = True
 | 
			
		||||
            extra_model_def["image_ref_choices"] = {
 | 
			
		||||
                "choices": [
 | 
			
		||||
                    ("None", ""),
 | 
			
		||||
                    ("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"),
 | 
			
		||||
                    ("Conditional Images are People / Objects", "I"),
 | 
			
		||||
                    ],
 | 
			
		||||
                "letters_filter": "KI",
 | 
			
		||||
            }
 | 
			
		||||
            extra_model_def["background_removal_label"]= "Remove Backgrounds only behind People / Objects except main Subject / Landscape" 
 | 
			
		||||
        elif flux_umo:
 | 
			
		||||
            extra_model_def["image_ref_choices"] = {
 | 
			
		||||
                "choices": [
 | 
			
		||||
                    ("Conditional Images are People / Objects", "I"),
 | 
			
		||||
                    ],
 | 
			
		||||
                "letters_filter": "I",
 | 
			
		||||
                "visible": False
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        return model_def_output
 | 
			
		||||
 | 
			
		||||
        extra_model_def["fit_into_canvas_image_refs"] = 0
 | 
			
		||||
 | 
			
		||||
        return extra_model_def
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def query_supported_types():
 | 
			
		||||
@ -82,7 +107,7 @@ class family_handler():
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
 | 
			
		||||
    def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None):
 | 
			
		||||
        from .flux_main  import model_factory
 | 
			
		||||
 | 
			
		||||
        flux_model = model_factory(
 | 
			
		||||
@ -107,15 +132,38 @@ class family_handler():
 | 
			
		||||
            pipe["feature_embedder"] = flux_model.feature_embedder 
 | 
			
		||||
        return flux_model, pipe
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
 | 
			
		||||
        flux_model = model_def.get("flux-model", "flux-dev")
 | 
			
		||||
        flux_uso = flux_model == "flux-dev-uso"
 | 
			
		||||
        if flux_uso and settings_version < 2.29:
 | 
			
		||||
            video_prompt_type = ui_defaults.get("video_prompt_type", "")
 | 
			
		||||
            if "I" in video_prompt_type:
 | 
			
		||||
                video_prompt_type = video_prompt_type.replace("I", "KI")
 | 
			
		||||
                ui_defaults["video_prompt_type"] = video_prompt_type 
 | 
			
		||||
 | 
			
		||||
        if settings_version < 2.34:
 | 
			
		||||
            ui_defaults["denoising_strength"] = 1.
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def update_default_settings(base_model_type, model_def, ui_defaults):
 | 
			
		||||
        flux_model = model_def.get("flux-model", "flux-dev")
 | 
			
		||||
        flux_uso = flux_model == "flux-dev-uso"
 | 
			
		||||
        flux_umo = flux_model == "flux-dev-umo"
 | 
			
		||||
        flux_kontext = flux_model == "flux-dev-kontext"
 | 
			
		||||
        ui_defaults.update({
 | 
			
		||||
            "embedded_guidance":  2.5,
 | 
			
		||||
        })            
 | 
			
		||||
        if model_def.get("reference_image", False):
 | 
			
		||||
            ui_defaults.update({
 | 
			
		||||
                "video_prompt_type": "I" if flux_uso else "KI",
 | 
			
		||||
            })
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
        if flux_kontext or flux_uso:
 | 
			
		||||
            ui_defaults.update({
 | 
			
		||||
                "video_prompt_type": "KI",
 | 
			
		||||
                "denoising_strength": 1.,
 | 
			
		||||
            })
 | 
			
		||||
        elif flux_umo:
 | 
			
		||||
            ui_defaults.update({
 | 
			
		||||
                "video_prompt_type": "I",
 | 
			
		||||
                "remove_background_images_ref": 0,
 | 
			
		||||
            })
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -9,6 +9,9 @@ from shared.utils.utils import calculate_new_dimensions
 | 
			
		||||
from .sampling import denoise, get_schedule, prepare_kontext, prepare_prompt, prepare_multi_ip, unpack
 | 
			
		||||
from .modules.layers import get_linear_split_map
 | 
			
		||||
from transformers import SiglipVisionModel, SiglipImageProcessor
 | 
			
		||||
import torchvision.transforms.functional as TVF
 | 
			
		||||
import math
 | 
			
		||||
from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image
 | 
			
		||||
 | 
			
		||||
from .util import (
 | 
			
		||||
    aspect_ratio_to_height_width,
 | 
			
		||||
@ -20,6 +23,35 @@ from .util import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from PIL import Image
 | 
			
		||||
def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
 | 
			
		||||
    # Get the width and height of the original image
 | 
			
		||||
    image_w, image_h = raw_image.size
 | 
			
		||||
 | 
			
		||||
    # Compute the long side and the short side
 | 
			
		||||
    if image_w >= image_h:
 | 
			
		||||
        new_w = long_size
 | 
			
		||||
        new_h = int((long_size / image_w) * image_h)
 | 
			
		||||
    else:
 | 
			
		||||
        new_h = long_size
 | 
			
		||||
        new_w = int((long_size / image_h) * image_w)
 | 
			
		||||
 | 
			
		||||
    # Resize proportionally to the new width and height
 | 
			
		||||
    raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
 | 
			
		||||
    target_w = new_w // 16 * 16
 | 
			
		||||
    target_h = new_h // 16 * 16
 | 
			
		||||
 | 
			
		||||
    # Compute the crop start coordinates for a center crop
 | 
			
		||||
    left = (new_w - target_w) // 2
 | 
			
		||||
    top = (new_h - target_h) // 2
 | 
			
		||||
    right = left + target_w
 | 
			
		||||
    bottom = top + target_h
 | 
			
		||||
 | 
			
		||||
    # Apply the center crop
 | 
			
		||||
    raw_image = raw_image.crop((left, top, right, bottom))
 | 
			
		||||
 | 
			
		||||
    # Convert to RGB mode
 | 
			
		||||
    raw_image = raw_image.convert("RGB")
 | 
			
		||||
    return raw_image
 | 
			
		||||
 | 
			
		||||
def stitch_images(img1, img2):
 | 
			
		||||
    # Resize img2 to match img1's height
 | 
			
		||||
@ -64,7 +96,7 @@ class model_factory:
 | 
			
		||||
        # self.name= "flux-schnell"
 | 
			
		||||
        source =  model_def.get("source", None)
 | 
			
		||||
        self.model = load_flow_model(self.name, model_filename[0] if source is None else source, torch_device)
 | 
			
		||||
 | 
			
		||||
        self.model_def = model_def 
 | 
			
		||||
        self.vae = load_ae(self.name, device=torch_device)
 | 
			
		||||
 | 
			
		||||
        siglip_processor = siglip_model = feature_embedder = None
 | 
			
		||||
@ -106,10 +138,12 @@ class model_factory:
 | 
			
		||||
    def generate(
 | 
			
		||||
            self,
 | 
			
		||||
            seed: int | None = None,
 | 
			
		||||
            input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
 | 
			
		||||
            input_prompt: str = "replace the logo with the text 'Black Forest Labs'",            
 | 
			
		||||
            n_prompt: str = None,
 | 
			
		||||
            sampling_steps: int = 20,
 | 
			
		||||
            input_ref_images = None,
 | 
			
		||||
            input_frames= None,
 | 
			
		||||
            input_masks= None,
 | 
			
		||||
            width= 832,
 | 
			
		||||
            height=480,
 | 
			
		||||
            embedded_guidance_scale: float = 2.5,
 | 
			
		||||
@ -120,7 +154,8 @@ class model_factory:
 | 
			
		||||
            batch_size = 1,
 | 
			
		||||
            video_prompt_type = "",
 | 
			
		||||
            joint_pass = False,
 | 
			
		||||
            image_refs_relative_size = 100,       
 | 
			
		||||
            image_refs_relative_size = 100,
 | 
			
		||||
            denoising_strength = 1.,
 | 
			
		||||
            **bbargs
 | 
			
		||||
    ):
 | 
			
		||||
            if self._interrupt:
 | 
			
		||||
@ -129,11 +164,17 @@ class model_factory:
 | 
			
		||||
            if n_prompt is None or len(n_prompt) == 0: n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
 | 
			
		||||
            device="cuda"
 | 
			
		||||
            flux_dev_uso = self.name in ['flux-dev-uso']
 | 
			
		||||
            image_stiching =  not self.name in ['flux-dev-uso']
 | 
			
		||||
            flux_dev_umo = self.name in ['flux-dev-umo']
 | 
			
		||||
            latent_stiching =  self.name in ['flux-dev-uso', 'flux-dev-umo'] 
 | 
			
		||||
 | 
			
		||||
            lock_dimensions=  False
 | 
			
		||||
 | 
			
		||||
            input_ref_images = [] if input_ref_images is None else input_ref_images[:]
 | 
			
		||||
            if flux_dev_umo:
 | 
			
		||||
                ref_long_side = 512 if len(input_ref_images) <= 1 else 320
 | 
			
		||||
                input_ref_images = [preprocess_ref(img, ref_long_side) for img in input_ref_images]
 | 
			
		||||
                lock_dimensions = True
 | 
			
		||||
            ref_style_imgs = []
 | 
			
		||||
 | 
			
		||||
            if "I" in video_prompt_type and len(input_ref_images) > 0: 
 | 
			
		||||
                if flux_dev_uso :
 | 
			
		||||
                    if "J" in video_prompt_type:
 | 
			
		||||
@ -142,33 +183,28 @@ class model_factory:
 | 
			
		||||
                    elif len(input_ref_images) > 1 :
 | 
			
		||||
                        ref_style_imgs = input_ref_images[-1:]
 | 
			
		||||
                        input_ref_images = input_ref_images[:-1]
 | 
			
		||||
                if image_stiching:
 | 
			
		||||
 | 
			
		||||
                if latent_stiching:
 | 
			
		||||
                    # latents stiching with resize 
 | 
			
		||||
                    if not lock_dimensions :
 | 
			
		||||
                        for i in range(len(input_ref_images)):
 | 
			
		||||
                            w, h = input_ref_images[i].size
 | 
			
		||||
                            image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, 0)
 | 
			
		||||
                            input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS) 
 | 
			
		||||
                else:
 | 
			
		||||
                    # image stiching method
 | 
			
		||||
                    stiched = input_ref_images[0]
 | 
			
		||||
                    if "K" in video_prompt_type :
 | 
			
		||||
                        w, h = input_ref_images[0].size
 | 
			
		||||
                        height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
 | 
			
		||||
 | 
			
		||||
                    for new_img in input_ref_images[1:]:
 | 
			
		||||
                        stiched = stitch_images(stiched, new_img)
 | 
			
		||||
                    input_ref_images  = [stiched]
 | 
			
		||||
                else:
 | 
			
		||||
                    first_ref = 0
 | 
			
		||||
                    if "K" in video_prompt_type:
 | 
			
		||||
                        # image latents tiling method
 | 
			
		||||
                        w, h = input_ref_images[0].size
 | 
			
		||||
                        height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
 | 
			
		||||
                        input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS) 
 | 
			
		||||
                        first_ref = 1
 | 
			
		||||
 | 
			
		||||
                    for i in range(first_ref,len(input_ref_images)):
 | 
			
		||||
                        w, h = input_ref_images[i].size
 | 
			
		||||
                        image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
 | 
			
		||||
                        input_ref_images[0] = input_ref_images[0].resize((image_width, image_height), resample=Image.Resampling.LANCZOS) 
 | 
			
		||||
            elif input_frames is not None:
 | 
			
		||||
                input_ref_images = [convert_tensor_to_image(input_frames) ] 
 | 
			
		||||
            else:
 | 
			
		||||
                input_ref_images = None
 | 
			
		||||
            image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True) 
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
            if flux_dev_uso :
 | 
			
		||||
            if self.name in ['flux-dev-uso', 'flux-dev-umo']  :
 | 
			
		||||
                inp, height, width = prepare_multi_ip(
 | 
			
		||||
                    ae=self.vae,
 | 
			
		||||
                    img_cond_list=input_ref_images,
 | 
			
		||||
@ -187,6 +223,7 @@ class model_factory:
 | 
			
		||||
                    bs=batch_size,
 | 
			
		||||
                    seed=seed,
 | 
			
		||||
                    device=device,
 | 
			
		||||
                    img_mask=image_mask,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            inp.update(prepare_prompt(self.t5, self.clip, batch_size, input_prompt))
 | 
			
		||||
@ -208,13 +245,19 @@ class model_factory:
 | 
			
		||||
                return unpack(x.float(), height, width) 
 | 
			
		||||
 | 
			
		||||
            # denoise initial noise
 | 
			
		||||
            x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass)
 | 
			
		||||
            x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass, denoising_strength = denoising_strength)
 | 
			
		||||
            if x==None: return None
 | 
			
		||||
            # decode latents to pixel space
 | 
			
		||||
            x = unpack_latent(x)
 | 
			
		||||
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
 | 
			
		||||
                x = self.vae.decode(x)
 | 
			
		||||
 | 
			
		||||
            if image_mask is not None:
 | 
			
		||||
                from shared.utils.utils import convert_image_to_tensor
 | 
			
		||||
                img_msk_rebuilt = inp["img_msk_rebuilt"]
 | 
			
		||||
                img= input_frames.squeeze(1).unsqueeze(0) # convert_image_to_tensor(image_guide) 
 | 
			
		||||
                x = img * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt 
 | 
			
		||||
 | 
			
		||||
            x = x.clamp(-1, 1)
 | 
			
		||||
            x = x.transpose(0, 1)
 | 
			
		||||
            return x
 | 
			
		||||
 | 
			
		||||
@ -190,6 +190,21 @@ class Flux(nn.Module):
 | 
			
		||||
                            v = swap_scale_shift(v)
 | 
			
		||||
                        k = k.replace("norm_out.linear", "final_layer.adaLN_modulation.1")            
 | 
			
		||||
                new_sd[k] = v
 | 
			
		||||
        # elif not first_key.startswith("diffusion_model.") and not first_key.startswith("transformer."):
 | 
			
		||||
        #     for k,v in sd.items():
 | 
			
		||||
        #         if "double" in k:
 | 
			
		||||
        #             k = k.replace(".processor.proj_lora1.", ".img_attn.proj.lora_")
 | 
			
		||||
        #             k = k.replace(".processor.proj_lora2.", ".txt_attn.proj.lora_")
 | 
			
		||||
        #             k = k.replace(".processor.qkv_lora1.", ".img_attn.qkv.lora_")
 | 
			
		||||
        #             k = k.replace(".processor.qkv_lora2.", ".txt_attn.qkv.lora_")
 | 
			
		||||
        #         else:
 | 
			
		||||
        #             k = k.replace(".processor.qkv_lora.", ".linear1_qkv.lora_")
 | 
			
		||||
        #             k = k.replace(".processor.proj_lora.", ".linear2.lora_")
 | 
			
		||||
 | 
			
		||||
        #         k = "diffusion_model." + k
 | 
			
		||||
        #         new_sd[k] = v
 | 
			
		||||
        #     from mmgp import safetensors2
 | 
			
		||||
        #     safetensors2.torch_write_file(new_sd, "fff.safetensors")
 | 
			
		||||
        else:
 | 
			
		||||
            new_sd = sd
 | 
			
		||||
        return new_sd    
 | 
			
		||||
 | 
			
		||||
@ -138,10 +138,12 @@ def prepare_kontext(
 | 
			
		||||
    target_width: int | None = None,
 | 
			
		||||
    target_height: int | None = None,
 | 
			
		||||
    bs: int = 1,
 | 
			
		||||
 | 
			
		||||
    img_mask = None,
 | 
			
		||||
) -> tuple[dict[str, Tensor], int, int]:
 | 
			
		||||
    # load and encode the conditioning image
 | 
			
		||||
 | 
			
		||||
    res_match_output = img_mask is not None
 | 
			
		||||
 | 
			
		||||
    img_cond_seq = None
 | 
			
		||||
    img_cond_seq_ids = None
 | 
			
		||||
    if img_cond_list == None: img_cond_list = []
 | 
			
		||||
@ -150,10 +152,11 @@ def prepare_kontext(
 | 
			
		||||
    for cond_no, img_cond in enumerate(img_cond_list): 
 | 
			
		||||
        width, height = img_cond.size
 | 
			
		||||
        aspect_ratio = width / height
 | 
			
		||||
 | 
			
		||||
        # Kontext is trained on specific resolutions, using one of them is recommended
 | 
			
		||||
        _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
 | 
			
		||||
 | 
			
		||||
        if res_match_output:
 | 
			
		||||
            width, height = target_width, target_height
 | 
			
		||||
        else:
 | 
			
		||||
            # Kontext is trained on specific resolutions, using one of them is recommended
 | 
			
		||||
            _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
 | 
			
		||||
        width = 2 * int(width / 16)
 | 
			
		||||
        height = 2 * int(height / 16)
 | 
			
		||||
 | 
			
		||||
@ -194,6 +197,19 @@ def prepare_kontext(
 | 
			
		||||
        "img_cond_seq": img_cond_seq,
 | 
			
		||||
        "img_cond_seq_ids": img_cond_seq_ids,
 | 
			
		||||
    }
 | 
			
		||||
    if img_mask is not None:
 | 
			
		||||
        from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image
 | 
			
		||||
        # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
 | 
			
		||||
        image_mask_latents = convert_image_to_tensor(img_mask.resize((target_width // 16, target_height // 16), resample=Image.Resampling.LANCZOS))
 | 
			
		||||
        image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
 | 
			
		||||
        image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0)
 | 
			
		||||
        # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
 | 
			
		||||
        image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device)        
 | 
			
		||||
        return_dict.update({
 | 
			
		||||
            "img_msk_latents": image_mask_latents,
 | 
			
		||||
            "img_msk_rebuilt": image_mask_rebuilt,
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
    img = get_noise(
 | 
			
		||||
        bs,
 | 
			
		||||
        target_height,
 | 
			
		||||
@ -265,6 +281,9 @@ def denoise(
 | 
			
		||||
    loras_slists=None,
 | 
			
		||||
    unpack_latent = None,
 | 
			
		||||
    joint_pass= False,
 | 
			
		||||
    img_msk_latents = None,
 | 
			
		||||
    img_msk_rebuilt = None,
 | 
			
		||||
    denoising_strength = 1,
 | 
			
		||||
):
 | 
			
		||||
 | 
			
		||||
    kwargs = {'pipeline': pipeline, 'callback': callback, "img_len" : img.shape[1], "siglip_embedding": siglip_embedding, "siglip_embedding_ids": siglip_embedding_ids}
 | 
			
		||||
@ -272,19 +291,38 @@ def denoise(
 | 
			
		||||
    if callback != None:
 | 
			
		||||
        callback(-1, None, True)
 | 
			
		||||
 | 
			
		||||
    original_image_latents = None if img_cond_seq is None else img_cond_seq.clone() 
 | 
			
		||||
    original_timesteps = timesteps
 | 
			
		||||
    morph, first_step = False, 0
 | 
			
		||||
    if img_msk_latents is not None:
 | 
			
		||||
        randn = torch.randn_like(original_image_latents)
 | 
			
		||||
        if denoising_strength < 1.:
 | 
			
		||||
            first_step = int(len(timesteps) * (1. - denoising_strength))
 | 
			
		||||
        if not morph:
 | 
			
		||||
            latent_noise_factor = timesteps[first_step]
 | 
			
		||||
            latents  = original_image_latents  * (1.0 - latent_noise_factor) + randn * latent_noise_factor
 | 
			
		||||
            img = latents.to(img)
 | 
			
		||||
            latents = None
 | 
			
		||||
            timesteps = timesteps[first_step:]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    updated_num_steps= len(timesteps) -1
 | 
			
		||||
    if callback != None:
 | 
			
		||||
        from shared.utils.loras_mutipliers import update_loras_slists
 | 
			
		||||
        update_loras_slists(model, loras_slists, updated_num_steps)
 | 
			
		||||
        update_loras_slists(model, loras_slists, len(original_timesteps))
 | 
			
		||||
        callback(-1, None, True, override_num_inference_steps = updated_num_steps)
 | 
			
		||||
    from mmgp import offload
 | 
			
		||||
    # this is ignored for schnell
 | 
			
		||||
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
 | 
			
		||||
    for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
 | 
			
		||||
        offload.set_step_no_for_lora(model, i)
 | 
			
		||||
        offload.set_step_no_for_lora(model, first_step  + i)
 | 
			
		||||
        if pipeline._interrupt:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        if img_msk_latents is not None and denoising_strength <1. and i == first_step and morph:
 | 
			
		||||
            latent_noise_factor = t_curr/1000
 | 
			
		||||
            img  = original_image_latents  * (1.0 - latent_noise_factor) + img * latent_noise_factor 
 | 
			
		||||
 | 
			
		||||
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
 | 
			
		||||
        img_input = img
 | 
			
		||||
        img_input_ids = img_ids
 | 
			
		||||
@ -334,6 +372,14 @@ def denoise(
 | 
			
		||||
            pred = neg_pred + real_guidance_scale * (pred - neg_pred)
 | 
			
		||||
 | 
			
		||||
        img += (t_prev - t_curr) * pred
 | 
			
		||||
 | 
			
		||||
        if img_msk_latents is not None:
 | 
			
		||||
            latent_noise_factor = t_prev
 | 
			
		||||
            # noisy_image  = original_image_latents  * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor 
 | 
			
		||||
            noisy_image  = original_image_latents  * (1.0 - latent_noise_factor) + randn * latent_noise_factor 
 | 
			
		||||
            img  =  noisy_image * (1-img_msk_latents)  + img_msk_latents * img
 | 
			
		||||
            noisy_image = None
 | 
			
		||||
 | 
			
		||||
        if callback is not None:
 | 
			
		||||
            preview = unpack_latent(img).transpose(0,1)
 | 
			
		||||
            callback(i, preview, False)         
 | 
			
		||||
 | 
			
		||||
@ -640,6 +640,38 @@ configs = {
 | 
			
		||||
            shift_factor=0.1159,
 | 
			
		||||
        ),
 | 
			
		||||
    ),
 | 
			
		||||
    "flux-dev-umo": ModelSpec(
 | 
			
		||||
        repo_id="",
 | 
			
		||||
        repo_flow="",
 | 
			
		||||
        repo_ae="ckpts/flux_vae.safetensors",
 | 
			
		||||
        params=FluxParams(
 | 
			
		||||
            in_channels=64,
 | 
			
		||||
            out_channels=64,
 | 
			
		||||
            vec_in_dim=768,
 | 
			
		||||
            context_in_dim=4096,
 | 
			
		||||
            hidden_size=3072,
 | 
			
		||||
            mlp_ratio=4.0,
 | 
			
		||||
            num_heads=24,
 | 
			
		||||
            depth=19,
 | 
			
		||||
            depth_single_blocks=38,
 | 
			
		||||
            axes_dim=[16, 56, 56],
 | 
			
		||||
            theta=10_000,
 | 
			
		||||
            qkv_bias=True,
 | 
			
		||||
            guidance_embed=True,
 | 
			
		||||
            eso= True,
 | 
			
		||||
        ),
 | 
			
		||||
        ae_params=AutoEncoderParams(
 | 
			
		||||
            resolution=256,
 | 
			
		||||
            in_channels=3,
 | 
			
		||||
            ch=128,
 | 
			
		||||
            out_ch=3,
 | 
			
		||||
            ch_mult=[1, 2, 4, 4],
 | 
			
		||||
            num_res_blocks=2,
 | 
			
		||||
            z_channels=16,
 | 
			
		||||
            scale_factor=0.3611,
 | 
			
		||||
            shift_factor=0.1159,
 | 
			
		||||
        ),
 | 
			
		||||
    ),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -861,16 +861,11 @@ class HunyuanVideoSampler(Inference):
 | 
			
		||||
            freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_RIFLEx)
 | 
			
		||||
        else:
 | 
			
		||||
            if self.avatar:
 | 
			
		||||
                w, h = input_ref_images.size
 | 
			
		||||
                target_height, target_width = calculate_new_dimensions(target_height, target_width, h, w, fit_into_canvas)
 | 
			
		||||
                if target_width != w or target_height != h:
 | 
			
		||||
                    input_ref_images = input_ref_images.resize((target_width,target_height), resample=Image.Resampling.LANCZOS) 
 | 
			
		||||
 | 
			
		||||
                concat_dict = {'mode': 'timecat', 'bias': -1} 
 | 
			
		||||
                freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
 | 
			
		||||
            else:
 | 
			
		||||
                if input_frames != None:
 | 
			
		||||
                    target_height, target_width = input_frames.shape[-3:-1]
 | 
			
		||||
                    target_height, target_width = input_frames.shape[-2:]
 | 
			
		||||
                elif input_video != None:
 | 
			
		||||
                    target_height, target_width = input_video.shape[-2:]
 | 
			
		||||
 | 
			
		||||
@ -899,9 +894,10 @@ class HunyuanVideoSampler(Inference):
 | 
			
		||||
            pixel_value_bg = input_video.unsqueeze(0)
 | 
			
		||||
            pixel_value_mask =  torch.zeros_like(input_video).unsqueeze(0)
 | 
			
		||||
        if input_frames != None:
 | 
			
		||||
            pixel_value_video_bg = input_frames.permute(-1,0,1,2).unsqueeze(0).float()
 | 
			
		||||
            pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float()
 | 
			
		||||
            pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
 | 
			
		||||
            pixel_value_video_bg = input_frames.unsqueeze(0) #.permute(-1,0,1,2).unsqueeze(0).float()
 | 
			
		||||
            # pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
 | 
			
		||||
            # pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float()
 | 
			
		||||
            pixel_value_video_mask = input_masks.repeat(3,1,1,1).unsqueeze(0)
 | 
			
		||||
            if input_video != None:
 | 
			
		||||
                pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2)
 | 
			
		||||
                pixel_value_mask = torch.cat([ pixel_value_mask, pixel_value_video_mask], dim=2)
 | 
			
		||||
@ -913,10 +909,11 @@ class HunyuanVideoSampler(Inference):
 | 
			
		||||
            if pixel_value_bg.shape[2] < frame_num:
 | 
			
		||||
                padding_shape = list(pixel_value_bg.shape[0:2]) + [frame_num-pixel_value_bg.shape[2]] +  list(pixel_value_bg.shape[3:])  
 | 
			
		||||
                pixel_value_bg = torch.cat([pixel_value_bg, torch.full(padding_shape, -1, dtype=pixel_value_bg.dtype, device= pixel_value_bg.device ) ], dim=2)
 | 
			
		||||
                pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
 | 
			
		||||
                # pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
 | 
			
		||||
                pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 1, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
 | 
			
		||||
 | 
			
		||||
            bg_latents = self.vae.encode(pixel_value_bg).latent_dist.sample()                
 | 
			
		||||
            pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.)             
            pixel_value_mask = pixel_value_mask.mul_(2).add_(-1.)    # unmasked pixels are -1 (not 0 as usual) and masked pixels are 1
            mask_latents = self.vae.encode(pixel_value_mask).latent_dist.sample()
 | 
			
		||||
            bg_latents = torch.cat([bg_latents, mask_latents], dim=1)
 | 
			
		||||
            bg_latents.mul_(self.vae.config.scaling_factor)
 | 
			
		||||
 | 
			
		||||
@ -51,11 +51,38 @@ class family_handler():
 | 
			
		||||
        extra_model_def["tea_cache"] = True
 | 
			
		||||
        extra_model_def["mag_cache"] = True
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["hunyuan_avatar"]: extra_model_def["no_background_removal"] = True
 | 
			
		||||
        if base_model_type in ["hunyuan_custom_edit"]:
 | 
			
		||||
            extra_model_def["guide_preprocessing"] = {
 | 
			
		||||
                "selection": ["MV", "PV"],
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_audio", "hunyuan_avatar"]:
 | 
			
		||||
            extra_model_def["mask_preprocessing"] = {
 | 
			
		||||
                "selection": ["A", "NA"],
 | 
			
		||||
                "default" : "NA"
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_custom"]:
 | 
			
		||||
            extra_model_def["image_ref_choices"] = {
 | 
			
		||||
                "choices": [("Reference Image", "I")],
 | 
			
		||||
                "letters_filter":"I",
 | 
			
		||||
                "visible": False,
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["hunyuan_avatar"]: 
 | 
			
		||||
            extra_model_def["image_ref_choices"] = {
 | 
			
		||||
                "choices": [("Start Image", "KI")],
 | 
			
		||||
                "letters_filter":"KI",
 | 
			
		||||
                "visible": False,
 | 
			
		||||
            }
 | 
			
		||||
            extra_model_def["no_background_removal"] = True
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]:
 | 
			
		||||
            extra_model_def["one_image_ref_needed"] = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["hunyuan_i2v"]:
 | 
			
		||||
            extra_model_def["image_prompt_types_allowed"] = "S"
 | 
			
		||||
 | 
			
		||||
        return extra_model_def
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
@ -102,7 +129,7 @@ class family_handler():
 | 
			
		||||
        } 
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def load_model(model_filename, model_type = None,  base_model_type = None, model_def = None, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
 | 
			
		||||
    def load_model(model_filename, model_type = None,  base_model_type = None, model_def = None, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None):
 | 
			
		||||
        from .hunyuan import HunyuanVideoSampler
 | 
			
		||||
        from mmgp import offload
 | 
			
		||||
 | 
			
		||||
@ -137,6 +164,24 @@ class family_handler():
 | 
			
		||||
 | 
			
		||||
        return hunyuan_model, pipe
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
 | 
			
		||||
        if settings_version<2.33:
 | 
			
		||||
            if base_model_type in ["hunyuan_custom_edit"]:
 | 
			
		||||
                video_prompt_type=  ui_defaults["video_prompt_type"]
 | 
			
		||||
                if "P" in video_prompt_type and "M" in video_prompt_type: 
 | 
			
		||||
                    video_prompt_type = video_prompt_type.replace("M","")
 | 
			
		||||
                    ui_defaults["video_prompt_type"] = video_prompt_type  
 | 
			
		||||
 | 
			
		||||
        if settings_version < 2.36:
 | 
			
		||||
            if base_model_type in ["hunyuan_avatar", "hunyuan_custom_audio"]:
 | 
			
		||||
                audio_prompt_type=  ui_defaults["audio_prompt_type"]
 | 
			
		||||
                if "A" not in audio_prompt_type:
 | 
			
		||||
                    audio_prompt_type += "A"
 | 
			
		||||
                    ui_defaults["audio_prompt_type"] = audio_prompt_type  
 | 
			
		||||
 | 
			
		||||
        
 | 
			
		||||
    
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def update_default_settings(base_model_type, model_def, ui_defaults):
 | 
			
		||||
        ui_defaults["embedded_guidance_scale"]= 6.0
 | 
			
		||||
@ -158,6 +203,7 @@ class family_handler():
 | 
			
		||||
                "guidance_scale": 7.5,
 | 
			
		||||
                "flow_shift": 13,
 | 
			
		||||
                "video_prompt_type": "I",
 | 
			
		||||
                "audio_prompt_type": "A",
 | 
			
		||||
            })
 | 
			
		||||
        elif base_model_type in ["hunyuan_custom_edit"]:
 | 
			
		||||
            ui_defaults.update({
 | 
			
		||||
@ -174,4 +220,5 @@ class family_handler():
 | 
			
		||||
                "skip_steps_start_step_perc": 25, 
 | 
			
		||||
                "video_length": 129,
 | 
			
		||||
                "video_prompt_type": "KI",
 | 
			
		||||
                "audio_prompt_type": "A",
 | 
			
		||||
            })
 | 
			
		||||
 | 
			
		||||
@ -14,7 +14,7 @@ from torch.nn.attention.flex_attention import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@lru_cache
 | 
			
		||||
# @lru_cache
 | 
			
		||||
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", _compile=False):
 | 
			
		||||
    block_mask = create_block_mask(score_mod, B, H, M, N, device=device, _compile=_compile)
 | 
			
		||||
    return block_mask
 | 
			
		||||
 | 
			
		||||
@ -300,9 +300,6 @@ class LTXV:
 | 
			
		||||
            prefix_size, height, width = input_video.shape[-3:]
 | 
			
		||||
        else:
 | 
			
		||||
            if image_start != None:
 | 
			
		||||
                frame_width, frame_height  = image_start.size
 | 
			
		||||
                if fit_into_canvas != None:
 | 
			
		||||
                    height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas, 32)
 | 
			
		||||
                conditioning_media_paths.append(image_start.unsqueeze(1)) 
 | 
			
		||||
                conditioning_start_frames.append(0)
 | 
			
		||||
                conditioning_control_frames.append(False)
 | 
			
		||||
@ -479,14 +476,14 @@ class LTXV:
 | 
			
		||||
        images = images.sub_(0.5).mul_(2).squeeze(0)
 | 
			
		||||
        return images

    def get_loras_transformer(self, get_model_recursive_prop, video_prompt_type, **kwargs):
    def get_loras_transformer(self, get_model_recursive_prop, model_type, video_prompt_type, **kwargs):
        map = {
            "P" : "pose",
            "D" : "depth",
            "E" : "canny",
        }
        loras = []
        preloadURLs = get_model_recursive_prop(self.model_type,  "preload_URLs")
        preloadURLs = get_model_recursive_prop(model_type,  "preload_URLs")
        lora_file_name = ""
        for letter, signature in map.items():
            if letter in video_prompt_type:

@ -24,7 +24,19 @@ class family_handler():
 | 
			
		||||
        extra_model_def["frames_minimum"] = 17
 | 
			
		||||
        extra_model_def["frames_steps"] = 8
 | 
			
		||||
        extra_model_def["sliding_window"] = True
 | 
			
		||||
        extra_model_def["image_prompt_types_allowed"] = "TSEV"
 | 
			
		||||
 | 
			
		||||
        extra_model_def["guide_preprocessing"] = {
 | 
			
		||||
            "selection": ["", "PV", "DV", "EV", "V"],
 | 
			
		||||
            "labels" : { "V": "Use LTXV raw format"}
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        extra_model_def["mask_preprocessing"] = {
 | 
			
		||||
            "selection": ["", "A", "NA", "XA", "XNA"],
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        extra_model_def["extra_control_frames"] = 1
 | 
			
		||||
        extra_model_def["dont_cat_preguide"]= True
 | 
			
		||||
        return extra_model_def
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
@ -64,7 +76,7 @@ class family_handler():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
 | 
			
		||||
    def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None):
 | 
			
		||||
        from .ltxv import LTXV
 | 
			
		||||
 | 
			
		||||
        ltxv_model = LTXV(
 | 
			
		||||
 | 
			
		||||
@ -11,7 +11,6 @@
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from mmgp import offload
 | 
			
		||||
import inspect
 | 
			
		||||
from typing import Any, Callable, Dict, List, Optional, Union
 | 
			
		||||
@ -28,7 +27,7 @@ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Aut
 | 
			
		||||
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
 | 
			
		||||
from diffusers import FlowMatchEulerDiscreteScheduler
 | 
			
		||||
from PIL import Image
 | 
			
		||||
from shared.utils.utils import calculate_new_dimensions
 | 
			
		||||
from shared.utils.utils import calculate_new_dimensions, convert_image_to_tensor, convert_tensor_to_image
 | 
			
		||||
 | 
			
		||||
XLA_AVAILABLE = False
 | 
			
		||||
 | 
			
		||||
@ -201,7 +200,8 @@ class QwenImagePipeline(): #DiffusionPipeline
 | 
			
		||||
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
 | 
			
		||||
        self.tokenizer_max_length = 1024
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
 | 
			
		||||
            # self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
 | 
			
		||||
            self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
 | 
			
		||||
            self.prompt_template_encode_start_idx = 64
 | 
			
		||||
        else:
 | 
			
		||||
            self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
 | 
			
		||||
@ -233,6 +233,21 @@ class QwenImagePipeline(): #DiffusionPipeline
 | 
			
		||||
        txt = [template.format(e) for e in prompt]
 | 
			
		||||
 | 
			
		||||
        if self.processor is not None and image is not None:
 | 
			
		||||
            img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
 | 
			
		||||
            if isinstance(image, list):
 | 
			
		||||
                base_img_prompt = ""
 | 
			
		||||
                for i, img in enumerate(image):
 | 
			
		||||
                    base_img_prompt += img_prompt_template.format(i + 1)
 | 
			
		||||
            elif image is not None:
 | 
			
		||||
                base_img_prompt = img_prompt_template.format(1)
 | 
			
		||||
            else:
 | 
			
		||||
                base_img_prompt = ""
 | 
			
		||||
 | 
			
		||||
            template = self.prompt_template_encode
 | 
			
		||||
 | 
			
		||||
            drop_idx = self.prompt_template_encode_start_idx
 | 
			
		||||
            txt = [template.format(base_img_prompt + e) for e in prompt]
 | 
			
		||||
 | 
			
		||||
            model_inputs = self.processor(
 | 
			
		||||
                text=txt,
 | 
			
		||||
                images=image,
 | 
			
		||||
@ -387,7 +402,8 @@ class QwenImagePipeline(): #DiffusionPipeline
        return latent_image_ids.to(device=device, dtype=dtype)

    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
    def _pack_latents(latents):
        batch_size, num_channels_latents, _, height, width = latents.shape 
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
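The refactor above just reads the latent shape from the tensor itself; the packing it performs is a 2x2 spatial patchify, turning a (B, C, 1, H, W) latent into (B, (H/2)*(W/2), 4*C) tokens. A small self-contained sketch of that reshape (standalone illustration, not the pipeline method):

```python
import torch

def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    # (B, C, 1, H, W) -> (B, (H//2)*(W//2), C*4): each 2x2 spatial patch becomes one token
    batch_size, channels, _, height, width = latents.shape
    x = latents.view(batch_size, channels, height // 2, 2, width // 2, 2)
    x = x.permute(0, 2, 4, 1, 3, 5)
    return x.reshape(batch_size, (height // 2) * (width // 2), channels * 4)

latents = torch.randn(1, 16, 1, 64, 64)
print(pack_latents(latents).shape)  # torch.Size([1, 1024, 64])
```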
@ -464,7 +480,7 @@ class QwenImagePipeline(): #DiffusionPipeline
 | 
			
		||||
 | 
			
		||||
    def prepare_latents(
 | 
			
		||||
        self,
 | 
			
		||||
        image,
 | 
			
		||||
        images,
 | 
			
		||||
        batch_size,
 | 
			
		||||
        num_channels_latents,
 | 
			
		||||
        height,
 | 
			
		||||
@ -479,30 +495,33 @@ class QwenImagePipeline(): #DiffusionPipeline
 | 
			
		||||
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
 | 
			
		||||
        width = 2 * (int(width) // (self.vae_scale_factor * 2))
 | 
			
		||||
 | 
			
		||||
        shape = (batch_size, 1, num_channels_latents, height, width)
 | 
			
		||||
        shape = (batch_size, num_channels_latents, 1, height, width)
 | 
			
		||||
 | 
			
		||||
        image_latents = None
 | 
			
		||||
        if image is not None:
 | 
			
		||||
            image = image.to(device=device, dtype=dtype)
 | 
			
		||||
            if image.shape[1] != self.latent_channels:
 | 
			
		||||
                image_latents = self._encode_vae_image(image=image, generator=generator)
 | 
			
		||||
            else:
 | 
			
		||||
                image_latents = image
 | 
			
		||||
            if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
 | 
			
		||||
                # expand init_latents for batch_size
 | 
			
		||||
                additional_image_per_prompt = batch_size // image_latents.shape[0]
 | 
			
		||||
                image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
 | 
			
		||||
            elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                image_latents = torch.cat([image_latents], dim=0)
 | 
			
		||||
        if images is not None and len(images ) > 0:
 | 
			
		||||
            if not isinstance(images, list):
 | 
			
		||||
                images = [images]
 | 
			
		||||
            all_image_latents = []
 | 
			
		||||
            for image in images:
 | 
			
		||||
                image = image.to(device=device, dtype=dtype)
 | 
			
		||||
                if image.shape[1] != self.latent_channels:
 | 
			
		||||
                    image_latents = self._encode_vae_image(image=image, generator=generator)
 | 
			
		||||
                else:
 | 
			
		||||
                    image_latents = image
 | 
			
		||||
                if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
 | 
			
		||||
                    # expand init_latents for batch_size
 | 
			
		||||
                    additional_image_per_prompt = batch_size // image_latents.shape[0]
 | 
			
		||||
                    image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
 | 
			
		||||
                elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
 | 
			
		||||
                    raise ValueError(
 | 
			
		||||
                        f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    image_latents = torch.cat([image_latents], dim=0)
 | 
			
		||||
 | 
			
		||||
            image_latent_height, image_latent_width = image_latents.shape[3:]
 | 
			
		||||
            image_latents = self._pack_latents(
 | 
			
		||||
                image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
 | 
			
		||||
            )
 | 
			
		||||
                image_latents = self._pack_latents(image_latents)
 | 
			
		||||
                all_image_latents.append(image_latents)
 | 
			
		||||
            image_latents = torch.cat(all_image_latents, dim=1)
 | 
			
		||||
 | 
			
		||||
        if isinstance(generator, list) and len(generator) != batch_size:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
@ -511,7 +530,7 @@ class QwenImagePipeline(): #DiffusionPipeline
 | 
			
		||||
            )
 | 
			
		||||
        if latents is None:
 | 
			
		||||
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 | 
			
		||||
            latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
 | 
			
		||||
            latents = self._pack_latents(latents)
 | 
			
		||||
        else:
 | 
			
		||||
            latents = latents.to(device=device, dtype=dtype)
 | 
			
		||||
 | 
			
		||||
@ -563,10 +582,15 @@ class QwenImagePipeline(): #DiffusionPipeline
 | 
			
		||||
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
 | 
			
		||||
        max_sequence_length: int = 512,
 | 
			
		||||
        image = None,
 | 
			
		||||
        image_mask = None,
 | 
			
		||||
        denoising_strength = 0,
 | 
			
		||||
        callback=None,
 | 
			
		||||
        pipeline=None,
 | 
			
		||||
        loras_slists=None,
 | 
			
		||||
        joint_pass= True,
 | 
			
		||||
        lora_inpaint = False,
 | 
			
		||||
        outpainting_dims = None,
 | 
			
		||||
        qwen_edit_plus = False,
 | 
			
		||||
    ):
 | 
			
		||||
        r"""
 | 
			
		||||
        Function invoked when calling the pipeline for generation.
 | 
			
		||||
@ -682,33 +706,54 @@ class QwenImagePipeline(): #DiffusionPipeline
 | 
			
		||||
            batch_size = prompt_embeds.shape[0]
 | 
			
		||||
        device = "cuda"
 | 
			
		||||
 | 
			
		||||
        prompt_image = None
 | 
			
		||||
        if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
 | 
			
		||||
            image = image[0] if isinstance(image, list) else image
 | 
			
		||||
            image_height, image_width = self.image_processor.get_default_height_width(image)
 | 
			
		||||
            aspect_ratio = image_width / image_height
 | 
			
		||||
            if False :
 | 
			
		||||
                _, image_width, image_height = min(
 | 
			
		||||
                    (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS
 | 
			
		||||
                )
 | 
			
		||||
            image_width = image_width // multiple_of * multiple_of
 | 
			
		||||
            image_height = image_height // multiple_of * multiple_of
 | 
			
		||||
            ref_height, ref_width = 1568, 672
 | 
			
		||||
            if height * width < ref_height * ref_width: ref_height , ref_width = height , width  
 | 
			
		||||
            if image_height * image_width > ref_height * ref_width:
 | 
			
		||||
                image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
 | 
			
		||||
        condition_images = []
 | 
			
		||||
        vae_image_sizes = []
 | 
			
		||||
        vae_images = []
 | 
			
		||||
        image_mask_latents = None
 | 
			
		||||
        ref_size = 1024
 | 
			
		||||
        ref_text_encoder_size = 384 if qwen_edit_plus else 1024
 | 
			
		||||
        if image is not None:
 | 
			
		||||
            if not isinstance(image, list): image = [image]
 | 
			
		||||
            if height * width < ref_size * ref_size: ref_size =  round(math.sqrt(height * width))  
 | 
			
		||||
            for ref_no, img in enumerate(image):
 | 
			
		||||
                image_width, image_height = img.size
 | 
			
		||||
                any_mask = ref_no == 0 and image_mask is not None
 | 
			
		||||
                if (image_height * image_width > ref_size * ref_size) and not any_mask:
 | 
			
		||||
                    vae_height, vae_width =calculate_new_dimensions(ref_size, ref_size, image_height, image_width, False, block_size=multiple_of)
 | 
			
		||||
                else:
 | 
			
		||||
                    vae_height, vae_width = image_height, image_width 
 | 
			
		||||
                    vae_width = vae_width // multiple_of * multiple_of
 | 
			
		||||
                    vae_height = vae_height // multiple_of * multiple_of
 | 
			
		||||
                vae_image_sizes.append((vae_width, vae_height))
 | 
			
		||||
                condition_height, condition_width =calculate_new_dimensions(ref_text_encoder_size, ref_text_encoder_size, image_height, image_width, False, block_size=multiple_of)
 | 
			
		||||
                condition_images.append(img.resize((condition_width, condition_height), resample=Image.Resampling.LANCZOS) )
 | 
			
		||||
                if img.size != (vae_width, vae_height):
 | 
			
		||||
                    img = img.resize((vae_width, vae_height), resample=Image.Resampling.LANCZOS) 
 | 
			
		||||
                if any_mask :
 | 
			
		||||
                    if lora_inpaint:
 | 
			
		||||
                        image_mask_rebuilt = torch.where(convert_image_to_tensor(image_mask)>-0.5, 1., 0. )[0:1]
 | 
			
		||||
                        img = convert_image_to_tensor(img)
 | 
			
		||||
                        green = torch.tensor([-1.0, 1.0, -1.0]).to(img) 
 | 
			
		||||
                        green_image = green[:, None, None] .expand_as(img)
 | 
			
		||||
                        img = torch.where(image_mask_rebuilt > 0, green_image, img)
 | 
			
		||||
                        img = convert_tensor_to_image(img)
 | 
			
		||||
                    else:
 | 
			
		||||
                        image_mask_latents = convert_image_to_tensor(image_mask.resize((vae_width // 8, vae_height // 8), resample=Image.Resampling.LANCZOS))
 | 
			
		||||
                        image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
 | 
			
		||||
                        image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0)
 | 
			
		||||
                        # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
 | 
			
		||||
                        image_mask_latents = image_mask_latents.to(device).unsqueeze(0).unsqueeze(0).repeat(1,16,1,1,1)
 | 
			
		||||
                        image_mask_latents = self._pack_latents(image_mask_latents)
 | 
			
		||||
                # img.save("nnn.png")
 | 
			
		||||
                vae_images.append( convert_image_to_tensor(img).unsqueeze(0).unsqueeze(2) )
 | 
			
		||||
 | 
			
		||||
            image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS) 
 | 
			
		||||
            prompt_image = image
 | 
			
		||||
            image = self.image_processor.preprocess(image, image_height, image_width)
 | 
			
		||||
            image = image.unsqueeze(2)
 | 
			
		||||
 | 
			
		||||
        has_neg_prompt = negative_prompt is not None or (
 | 
			
		||||
            negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
 | 
			
		||||
        )
 | 
			
		||||
        do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
 | 
			
		||||
        prompt_embeds, prompt_embeds_mask = self.encode_prompt(
 | 
			
		||||
            image=prompt_image,
 | 
			
		||||
            image=condition_images,
 | 
			
		||||
            prompt=prompt,
 | 
			
		||||
            prompt_embeds=prompt_embeds,
 | 
			
		||||
            prompt_embeds_mask=prompt_embeds_mask,
 | 
			
		||||
@ -718,7 +763,7 @@ class QwenImagePipeline(): #DiffusionPipeline
 | 
			
		||||
        )
 | 
			
		||||
        if do_true_cfg:
 | 
			
		||||
            negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
 | 
			
		||||
                image=prompt_image,
 | 
			
		||||
                image=condition_images,
 | 
			
		||||
                prompt=negative_prompt,
 | 
			
		||||
                prompt_embeds=negative_prompt_embeds,
 | 
			
		||||
                prompt_embeds_mask=negative_prompt_embeds_mask,
 | 
			
		||||
@ -734,7 +779,7 @@ class QwenImagePipeline(): #DiffusionPipeline
 | 
			
		||||
        # 4. Prepare latent variables
 | 
			
		||||
        num_channels_latents = self.transformer.in_channels // 4
 | 
			
		||||
        latents, image_latents = self.prepare_latents(
 | 
			
		||||
            image,
 | 
			
		||||
            vae_images,
 | 
			
		||||
            batch_size * num_images_per_prompt,
 | 
			
		||||
            num_channels_latents,
 | 
			
		||||
            height,
 | 
			
		||||
@ -744,11 +789,18 @@ class QwenImagePipeline(): #DiffusionPipeline
            generator,
            latents,
        )
        original_image_latents = None if image_latents is None else image_latents.clone() 

        if image is not None:
            img_shapes = [
                [
                    (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
                    (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2),
                    # (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2),
                    *[
                        (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
                        for vae_width, vae_height in vae_image_sizes
                    ],

                ]
            ] * batch_size
        else:
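As a side note on the `img_shapes` construction above: it gives the transformer the token-grid size of the target latent followed by one entry per reference image, each dimension divided by the VAE scale factor and then by 2 for the patch packing. A standalone sketch of that computation, assuming a VAE scale factor of 8 as used elsewhere in this pipeline:

```python
def build_img_shapes(height, width, vae_image_sizes, batch_size, vae_scale_factor=8):
    # one (frames, h_tokens, w_tokens) tuple for the target, then one per reference image
    shapes = [
        (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2),
        *[
            (1, vae_h // vae_scale_factor // 2, vae_w // vae_scale_factor // 2)
            for vae_w, vae_h in vae_image_sizes
        ],
    ]
    return [shapes] * batch_size

print(build_img_shapes(1024, 1024, [(672, 1568)], batch_size=1))
# [[(1, 64, 64), (1, 98, 42)]]
```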
@ -773,7 +825,7 @@ class QwenImagePipeline(): #DiffusionPipeline
 | 
			
		||||
        )
 | 
			
		||||
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 | 
			
		||||
        self._num_timesteps = len(timesteps)
 | 
			
		||||
 | 
			
		||||
        original_timesteps = timesteps
 | 
			
		||||
        # handle guidance
 | 
			
		||||
        if self.transformer.guidance_embeds:
 | 
			
		||||
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
 | 
			
		||||
@ -788,56 +840,80 @@ class QwenImagePipeline(): #DiffusionPipeline
 | 
			
		||||
        negative_txt_seq_lens = (
 | 
			
		||||
            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        morph, first_step = False, 0
 | 
			
		||||
        lanpaint_proc = None
 | 
			
		||||
        if image_mask_latents is not None:
 | 
			
		||||
            randn = torch.randn_like(original_image_latents)
 | 
			
		||||
            if denoising_strength < 1.:
 | 
			
		||||
                first_step = int(len(timesteps) * (1. - denoising_strength))
 | 
			
		||||
            if not morph:
 | 
			
		||||
                latent_noise_factor = timesteps[first_step]/1000
 | 
			
		||||
                # latents  = original_image_latents  * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor 
 | 
			
		||||
                latents  = original_image_latents  * (1.0 - latent_noise_factor) + randn * latent_noise_factor 
 | 
			
		||||
                timesteps = timesteps[first_step:]
 | 
			
		||||
                self.scheduler.timesteps = timesteps
 | 
			
		||||
                self.scheduler.sigmas= self.scheduler.sigmas[first_step:]
 | 
			
		||||
            # from shared.inpainting.lanpaint import LanPaint
 | 
			
		||||
            # lanpaint_proc = LanPaint()
 | 
			
		||||
        # 6. Denoising loop
 | 
			
		||||
        self.scheduler.set_begin_index(0)
 | 
			
		||||
        updated_num_steps= len(timesteps)
 | 
			
		||||
        if callback != None:
 | 
			
		||||
            from shared.utils.loras_mutipliers import update_loras_slists
 | 
			
		||||
            update_loras_slists(self.transformer, loras_slists, updated_num_steps)
 | 
			
		||||
            update_loras_slists(self.transformer, loras_slists, len(original_timesteps))
 | 
			
		||||
            callback(-1, None, True, override_num_inference_steps = updated_num_steps)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        for i, t in enumerate(timesteps):
 | 
			
		||||
            offload.set_step_no_for_lora(self.transformer, first_step  + i)
 | 
			
		||||
            if self.interrupt:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            self._current_timestep = t
 | 
			
		||||
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
 | 
			
		||||
            timestep = t.expand(latents.shape[0]).to(latents.dtype)
 | 
			
		||||
 | 
			
		||||
            latent_model_input = latents
 | 
			
		||||
            if image_latents is not None:
 | 
			
		||||
                latent_model_input = torch.cat([latents, image_latents], dim=1)
 | 
			
		||||
            if image_mask_latents is not None and denoising_strength <1. and i == first_step and morph:
 | 
			
		||||
                latent_noise_factor = t/1000
 | 
			
		||||
                latents  = original_image_latents  * (1.0 - latent_noise_factor) + latents * latent_noise_factor 
 | 
			
		||||
 | 
			
		||||
            if do_true_cfg and joint_pass:
 | 
			
		||||
                noise_pred, neg_noise_pred = self.transformer(
 | 
			
		||||
                    hidden_states=latent_model_input,
 | 
			
		||||
                    timestep=timestep / 1000,
 | 
			
		||||
                    guidance=guidance,
 | 
			
		||||
                    encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask],
 | 
			
		||||
                    encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds],
 | 
			
		||||
                    img_shapes=img_shapes,
 | 
			
		||||
                    txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens],
 | 
			
		||||
                    attention_kwargs=self.attention_kwargs,
 | 
			
		||||
                    **kwargs
 | 
			
		||||
                )
 | 
			
		||||
                if noise_pred == None: return None
 | 
			
		||||
                noise_pred = noise_pred[:, : latents.size(1)]
 | 
			
		||||
                neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
 | 
			
		||||
            else:
 | 
			
		||||
                noise_pred = self.transformer(
 | 
			
		||||
                    hidden_states=latent_model_input,
 | 
			
		||||
                    timestep=timestep / 1000,
 | 
			
		||||
                    guidance=guidance,
 | 
			
		||||
                    encoder_hidden_states_mask_list=[prompt_embeds_mask],
 | 
			
		||||
                    encoder_hidden_states_list=[prompt_embeds],
 | 
			
		||||
                    img_shapes=img_shapes,
 | 
			
		||||
                    txt_seq_lens_list=[txt_seq_lens],
 | 
			
		||||
                    attention_kwargs=self.attention_kwargs,
 | 
			
		||||
                    **kwargs
 | 
			
		||||
                )[0]
 | 
			
		||||
                if noise_pred == None: return None
 | 
			
		||||
                noise_pred = noise_pred[:, : latents.size(1)]
 | 
			
		||||
 | 
			
		||||
            latents_dtype = latents.dtype
 | 
			
		||||
 | 
			
		||||
            # latent_model_input = latents
 | 
			
		||||
            def denoise(latent_model_input, true_cfg_scale):
 | 
			
		||||
                if image_latents is not None:
 | 
			
		||||
                    latent_model_input = torch.cat([latents, image_latents], dim=1)
 | 
			
		||||
                do_true_cfg = true_cfg_scale > 1
 | 
			
		||||
                if do_true_cfg and joint_pass:
 | 
			
		||||
                    noise_pred, neg_noise_pred = self.transformer(
 | 
			
		||||
                        hidden_states=latent_model_input,
 | 
			
		||||
                        timestep=timestep / 1000,
 | 
			
		||||
                        guidance=guidance, #!!!!
 | 
			
		||||
                        encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask],
 | 
			
		||||
                        encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds],
 | 
			
		||||
                        img_shapes=img_shapes,
 | 
			
		||||
                        txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens],
 | 
			
		||||
                        attention_kwargs=self.attention_kwargs,
 | 
			
		||||
                        **kwargs
 | 
			
		||||
                    )
 | 
			
		||||
                    if noise_pred == None: return None, None
 | 
			
		||||
                    noise_pred = noise_pred[:, : latents.size(1)]
 | 
			
		||||
                    neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
 | 
			
		||||
                else:
 | 
			
		||||
                    neg_noise_pred = None
 | 
			
		||||
                    noise_pred = self.transformer(
 | 
			
		||||
                        hidden_states=latent_model_input,
 | 
			
		||||
                        timestep=timestep / 1000,
 | 
			
		||||
                        guidance=guidance,
 | 
			
		||||
                        encoder_hidden_states_mask_list=[prompt_embeds_mask],
 | 
			
		||||
                        encoder_hidden_states_list=[prompt_embeds],
 | 
			
		||||
                        img_shapes=img_shapes,
 | 
			
		||||
                        txt_seq_lens_list=[txt_seq_lens],
 | 
			
		||||
                        attention_kwargs=self.attention_kwargs,
 | 
			
		||||
                        **kwargs
 | 
			
		||||
                    )[0]
 | 
			
		||||
                    if noise_pred == None: return None, None
 | 
			
		||||
                    noise_pred = noise_pred[:, : latents.size(1)]
 | 
			
		||||
 | 
			
		||||
                if do_true_cfg:
 | 
			
		||||
                    neg_noise_pred = self.transformer(
 | 
			
		||||
@ -851,20 +927,43 @@ class QwenImagePipeline(): #DiffusionPipeline
 | 
			
		||||
                        attention_kwargs=self.attention_kwargs,
 | 
			
		||||
                        **kwargs
 | 
			
		||||
                    )[0]
 | 
			
		||||
                    if neg_noise_pred == None: return None
 | 
			
		||||
                    if neg_noise_pred == None: return None, None
 | 
			
		||||
                    neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
 | 
			
		||||
                return noise_pred, neg_noise_pred
            def cfg_predictions( noise_pred, neg_noise_pred, guidance, t):
                if do_true_cfg:
                    comb_pred = neg_noise_pred + guidance * (noise_pred - neg_noise_pred)
                    if comb_pred == None: return None

            if do_true_cfg:
                comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
                if comb_pred == None: return None
                    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
                    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
                    noise_pred = comb_pred * (cond_norm / noise_norm)

                cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
                noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
                noise_pred = comb_pred * (cond_norm / noise_norm)
                neg_noise_pred = None
                return noise_pred
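For context, the `cfg_predictions` helper above applies true classifier-free guidance and then rescales the combined prediction so its per-token norm matches the conditional branch, which limits the over-saturation that large guidance scales tend to cause. A minimal standalone sketch of the same arithmetic:

```python
import torch

def norm_preserving_cfg(cond_pred, uncond_pred, scale):
    # standard CFG combination ...
    combined = uncond_pred + scale * (cond_pred - uncond_pred)
    # ... rescaled so the result keeps the conditional branch's per-token norm
    cond_norm = torch.norm(cond_pred, dim=-1, keepdim=True)
    comb_norm = torch.norm(combined, dim=-1, keepdim=True)
    return combined * (cond_norm / comb_norm)

cond = torch.randn(1, 1024, 64)
uncond = torch.randn(1, 1024, 64)
print(norm_preserving_cfg(cond, uncond, scale=4.0).shape)  # torch.Size([1, 1024, 64])
```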
 | 
			
		||||
 | 
			
		||||
            if lanpaint_proc is not None and i<=3:
 | 
			
		||||
                latents = lanpaint_proc(denoise, cfg_predictions, true_cfg_scale, 1., latents, original_image_latents, randn, t/1000, image_mask_latents, height=height , width= width, vae_scale_factor= 8)
 | 
			
		||||
                if latents is None: return None
 | 
			
		||||
 | 
			
		||||
            noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
 | 
			
		||||
            if noise_pred == None: return None
 | 
			
		||||
            noise_pred = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
 | 
			
		||||
            neg_noise_pred = None
 | 
			
		||||
            # compute the previous noisy sample x_t -> x_t-1
 | 
			
		||||
            latents_dtype = latents.dtype
 | 
			
		||||
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 | 
			
		||||
            noise_pred = None
 | 
			
		||||
 | 
			
		||||
            if image_mask_latents is not None:
 | 
			
		||||
                if lanpaint_proc is not None:
 | 
			
		||||
                    latents  =  original_image_latents * (1-image_mask_latents)  + image_mask_latents * latents
 | 
			
		||||
                else:
 | 
			
		||||
                    next_t = timesteps[i+1] if i<len(timesteps)-1 else 0
 | 
			
		||||
                    latent_noise_factor = next_t / 1000
 | 
			
		||||
                        # noisy_image  = original_image_latents  * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor 
 | 
			
		||||
                    noisy_image  = original_image_latents  * (1.0 - latent_noise_factor) + randn * latent_noise_factor 
 | 
			
		||||
                    latents  =  noisy_image * (1-image_mask_latents)  + image_mask_latents * latents
 | 
			
		||||
                    noisy_image = None
 | 
			
		||||
 | 
			
		||||
            if latents.dtype != latents_dtype:
 | 
			
		||||
                if torch.backends.mps.is_available():
 | 
			
		||||
@ -872,13 +971,14 @@ class QwenImagePipeline(): #DiffusionPipeline
 | 
			
		||||
                    latents = latents.to(latents_dtype)
 | 
			
		||||
 | 
			
		||||
            if callback is not None:
 | 
			
		||||
                # preview = unpack_latent(img).transpose(0,1)
 | 
			
		||||
                callback(i, None, False)         
 | 
			
		||||
                preview = self._unpack_latents(latents, height, width, self.vae_scale_factor)
 | 
			
		||||
                preview = preview.squeeze(0)
 | 
			
		||||
                callback(i, preview, False)         
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        self._current_timestep = None
 | 
			
		||||
        if output_type == "latent":
 | 
			
		||||
            image = latents
 | 
			
		||||
            output_image = latents
 | 
			
		||||
        else:
 | 
			
		||||
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
 | 
			
		||||
            latents = latents.to(self.vae.dtype)
 | 
			
		||||
@ -891,7 +991,9 @@ class QwenImagePipeline(): #DiffusionPipeline
 | 
			
		||||
                latents.device, latents.dtype
 | 
			
		||||
            )
 | 
			
		||||
            latents = latents / latents_std + latents_mean
 | 
			
		||||
            image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
 | 
			
		||||
            output_image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
 | 
			
		||||
            if image_mask is not None and not lora_inpaint :  #not (lora_inpaint and outpainting_dims is not None):
 | 
			
		||||
                output_image = vae_images[0].squeeze(2) * (1 - image_mask_rebuilt) + output_image.to(vae_images[0]  ) * image_mask_rebuilt 
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        return image
 | 
			
		||||
        return output_image
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,6 @@
 | 
			
		||||
import torch
 | 
			
		||||
import gradio as gr
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_qwen_text_encoder_filename(text_encoder_quantization):
 | 
			
		||||
    text_encoder_filename = "ckpts/Qwen2.5-VL-7B-Instruct/Qwen2.5-VL-7B-Instruct_bf16.safetensors"
 | 
			
		||||
@ -9,20 +11,51 @@ def get_qwen_text_encoder_filename(text_encoder_quantization):
 | 
			
		||||
class family_handler():
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def query_model_def(base_model_type, model_def):
 | 
			
		||||
        model_def_output = {
 | 
			
		||||
        extra_model_def = {
 | 
			
		||||
            "image_outputs" : True,
 | 
			
		||||
            "sample_solvers":[
 | 
			
		||||
                            ("Default", "default"),
 | 
			
		||||
                            ("Lightning", "lightning")],
 | 
			
		||||
            "guidance_max_phases" : 1,
 | 
			
		||||
            "fit_into_canvas_image_refs": 0,
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]: 
 | 
			
		||||
            extra_model_def["inpaint_support"] = True
 | 
			
		||||
            extra_model_def["image_ref_choices"] = {
 | 
			
		||||
            "choices": [
 | 
			
		||||
                ("None", ""),
 | 
			
		||||
                ("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"),
 | 
			
		||||
                ("Conditional Images are People / Objects", "I"),
 | 
			
		||||
                ],
 | 
			
		||||
            "letters_filter": "KI",
 | 
			
		||||
            }
 | 
			
		||||
            extra_model_def["background_removal_label"]= "Remove Backgrounds only behind People / Objects except main Subject / Landscape" 
 | 
			
		||||
            extra_model_def["video_guide_outpainting"] = [2]
 | 
			
		||||
            extra_model_def["model_modes"] = {
 | 
			
		||||
                        "choices": [
                            ("Lora Inpainting: Inpainted area completely unrelated to occluded content", 1),
                            ("Masked Denoising: Inpainted area may reuse some content that has been occluded", 0),
                            ],
 | 
			
		||||
                        "default": 1,
 | 
			
		||||
                        "label" : "Inpainting Method",
 | 
			
		||||
                        "image_modes" : [2],
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        return model_def_output
 | 
			
		||||
        if base_model_type in ["qwen_image_edit_plus_20B"]: 
 | 
			
		||||
            extra_model_def["guide_preprocessing"] = {
 | 
			
		||||
                    "selection": ["", "PV", "SV", "CV"],
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
            extra_model_def["mask_preprocessing"] = {
 | 
			
		||||
                    "selection": ["", "A"],
 | 
			
		||||
                    "visible": False,
 | 
			
		||||
                }
 | 
			
		||||
        return extra_model_def
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def query_supported_types():
 | 
			
		||||
        return ["qwen_image_20B", "qwen_image_edit_20B"]
 | 
			
		||||
        return ["qwen_image_20B", "qwen_image_edit_20B", "qwen_image_edit_plus_20B"]
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def query_family_maps():
 | 
			
		||||
@ -46,7 +79,7 @@ class family_handler():
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
 | 
			
		||||
    def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None):
 | 
			
		||||
        from .qwen_main import model_factory
 | 
			
		||||
        from mmgp import offload
 | 
			
		||||
 | 
			
		||||
@ -74,14 +107,44 @@ class family_handler():
 | 
			
		||||
        if ui_defaults.get("sample_solver", "") == "": 
 | 
			
		||||
            ui_defaults["sample_solver"] = "default"
 | 
			
		||||
 | 
			
		||||
        if settings_version < 2.32:
 | 
			
		||||
            ui_defaults["denoising_strength"] = 1.
 | 
			
		||||
                            
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def update_default_settings(base_model_type, model_def, ui_defaults):
 | 
			
		||||
        ui_defaults.update({
 | 
			
		||||
            "guidance_scale":  4,
 | 
			
		||||
            "sample_solver": "default",
 | 
			
		||||
        })            
 | 
			
		||||
        if model_def.get("reference_image", False):
 | 
			
		||||
        if base_model_type in ["qwen_image_edit_20B"]: 
 | 
			
		||||
            ui_defaults.update({
 | 
			
		||||
                "video_prompt_type": "KI",
 | 
			
		||||
                "denoising_strength" : 1.,
 | 
			
		||||
                "model_mode" : 0,
 | 
			
		||||
            })
 | 
			
		||||
        elif base_model_type in ["qwen_image_edit_plus_20B"]: 
 | 
			
		||||
            ui_defaults.update({
 | 
			
		||||
                "video_prompt_type": "I",
 | 
			
		||||
                "denoising_strength" : 1.,
 | 
			
		||||
                "model_mode" : 0,
 | 
			
		||||
            })
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def validate_generative_settings(base_model_type, model_def, inputs):
 | 
			
		||||
        if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]:
 | 
			
		||||
            model_mode = inputs["model_mode"]
 | 
			
		||||
            denoising_strength= inputs["denoising_strength"]
 | 
			
		||||
            video_guide_outpainting= inputs["video_guide_outpainting"]
 | 
			
		||||
            from wgp import get_outpainting_dims
 | 
			
		||||
            outpainting_dims = get_outpainting_dims(video_guide_outpainting)
 | 
			
		||||
 | 
			
		||||
            if denoising_strength < 1 and model_mode == 1:
 | 
			
		||||
                gr.Info("Denoising Strength will be ignored while using Lora Inpainting")
 | 
			
		||||
            if outpainting_dims is not None and model_mode == 0 :
 | 
			
		||||
                return "Outpainting is not supported with Masked Denoising  "
 | 
			
		||||
            
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_rgb_factors(base_model_type ):
 | 
			
		||||
        from shared.RGB_factors import get_rgb_factors
 | 
			
		||||
        latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("qwen")
 | 
			
		||||
        return latent_rgb_factors, latent_rgb_factors_bias
 | 
			
		||||
 | 
			
		||||
@ -17,7 +17,7 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
 | 
			
		||||
from diffusers import FlowMatchEulerDiscreteScheduler
 | 
			
		||||
from .pipeline_qwenimage import QwenImagePipeline
 | 
			
		||||
from PIL import Image
 | 
			
		||||
from shared.utils.utils import calculate_new_dimensions
 | 
			
		||||
from shared.utils.utils import calculate_new_dimensions, convert_tensor_to_image
 | 
			
		||||
 | 
			
		||||
def stitch_images(img1, img2):
 | 
			
		||||
    # Resize img2 to match img1's height
 | 
			
		||||
@ -44,17 +44,17 @@ class model_factory():
 | 
			
		||||
        save_quantized = False,
 | 
			
		||||
        dtype = torch.bfloat16,
 | 
			
		||||
        VAE_dtype = torch.float32,
 | 
			
		||||
        mixed_precision_transformer = False
 | 
			
		||||
        mixed_precision_transformer = False,
 | 
			
		||||
    ):
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
        transformer_filename = model_filename[0]
 | 
			
		||||
        processor = None
 | 
			
		||||
        tokenizer = None
 | 
			
		||||
        if base_model_type == "qwen_image_edit_20B":
 | 
			
		||||
        if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]:
 | 
			
		||||
            processor = Qwen2VLProcessor.from_pretrained(os.path.join(checkpoint_dir,"Qwen2.5-VL-7B-Instruct"))
 | 
			
		||||
        tokenizer = AutoTokenizer.from_pretrained(os.path.join(checkpoint_dir,"Qwen2.5-VL-7B-Instruct"))
 | 
			
		||||
 | 
			
		||||
        self.base_model_type = base_model_type
 | 
			
		||||
 | 
			
		||||
        base_config_file = "configs/qwen_image_20B.json" 
 | 
			
		||||
        with open(base_config_file, 'r', encoding='utf-8') as f:
 | 
			
		||||
@ -103,6 +103,8 @@ class model_factory():
 | 
			
		||||
        n_prompt = None,
 | 
			
		||||
        sampling_steps: int = 20,
 | 
			
		||||
        input_ref_images = None,
 | 
			
		||||
        input_frames= None,
 | 
			
		||||
        input_masks= None,
 | 
			
		||||
        width= 832,
 | 
			
		||||
        height=480,
 | 
			
		||||
        guide_scale: float = 4,
 | 
			
		||||
@ -114,6 +116,9 @@ class model_factory():
 | 
			
		||||
        VAE_tile_size = None, 
 | 
			
		||||
        joint_pass = True,
 | 
			
		||||
        sample_solver='default',
 | 
			
		||||
        denoising_strength = 1.,
 | 
			
		||||
        model_mode = 0,
 | 
			
		||||
        outpainting_dims = None,
 | 
			
		||||
        **bbargs
 | 
			
		||||
    ):
 | 
			
		||||
        # Generate with different aspect ratios
 | 
			
		||||
@ -168,13 +173,17 @@ class model_factory():
 | 
			
		||||
            self.vae.tile_latent_min_height  = VAE_tile_size[1] 
 | 
			
		||||
            self.vae.tile_latent_min_width  = VAE_tile_size[1]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        qwen_edit_plus = self.base_model_type in ["qwen_image_edit_plus_20B"]
 | 
			
		||||
        self.vae.enable_slicing()
 | 
			
		||||
        # width, height = aspect_ratios["16:9"]
 | 
			
		||||
 | 
			
		||||
        if n_prompt is None or len(n_prompt) == 0:
 | 
			
		||||
            n_prompt=  "text, watermark, copyright, blurry, low resolution"
 | 
			
		||||
 | 
			
		||||
        image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True) 
 | 
			
		||||
        if input_frames is not None:
 | 
			
		||||
            input_ref_images = [convert_tensor_to_image(input_frames) ] +  ([] if input_ref_images  is None else input_ref_images )
 | 
			
		||||
 | 
			
		||||
        if input_ref_images is not None:
            # image stitching method
            stiched = input_ref_images[0]
 | 
			
		||||
@ -182,14 +191,16 @@ class model_factory():
 | 
			
		||||
                w, h = input_ref_images[0].size
 | 
			
		||||
                height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
 | 
			
		||||
 | 
			
		||||
            for new_img in input_ref_images[1:]:
 | 
			
		||||
                stiched = stitch_images(stiched, new_img)
 | 
			
		||||
            input_ref_images  = [stiched]
 | 
			
		||||
            if not qwen_edit_plus:
 | 
			
		||||
                for new_img in input_ref_images[1:]:
 | 
			
		||||
                    stiched = stitch_images(stiched, new_img)
 | 
			
		||||
                input_ref_images  = [stiched]
 | 
			
		||||
 | 
			
		||||
        image = self.pipeline(
 | 
			
		||||
            prompt=input_prompt,
 | 
			
		||||
            negative_prompt=n_prompt,
 | 
			
		||||
            image = input_ref_images,
 | 
			
		||||
            image_mask = image_mask,
 | 
			
		||||
            width=width,
 | 
			
		||||
            height=height,
 | 
			
		||||
            num_inference_steps=sampling_steps,
 | 
			
		||||
@ -199,8 +210,19 @@ class model_factory():
 | 
			
		||||
            pipeline=self,
 | 
			
		||||
            loras_slists=loras_slists,
 | 
			
		||||
            joint_pass = joint_pass,
 | 
			
		||||
            generator=torch.Generator(device="cuda").manual_seed(seed)
 | 
			
		||||
        )        
 | 
			
		||||
            denoising_strength=denoising_strength,
 | 
			
		||||
            generator=torch.Generator(device="cuda").manual_seed(seed),
 | 
			
		||||
            lora_inpaint = image_mask is not None and model_mode == 1,
 | 
			
		||||
            outpainting_dims = outpainting_dims,
 | 
			
		||||
            qwen_edit_plus = qwen_edit_plus,
 | 
			
		||||
        )      
 | 
			
		||||
        if image is None: return None
 | 
			
		||||
        return image.transpose(0, 1)
 | 
			
		||||
 | 
			
		||||
    def get_loras_transformer(self, get_model_recursive_prop, model_type, model_mode, **kwargs):
 | 
			
		||||
        if model_mode == 0: return [], []
 | 
			
		||||
        preloadURLs = get_model_recursive_prop(model_type,  "preload_URLs")
 | 
			
		||||
        if len(preloadURLs) == 0: return [], []
 | 
			
		||||
        return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))] , [1]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -204,7 +204,7 @@ class QwenEmbedRope(nn.Module):
 | 
			
		||||
            frame, height, width = fhw
 | 
			
		||||
            rope_key = f"{idx}_{height}_{width}"
 | 
			
		||||
 | 
			
		||||
            if not torch.compiler.is_compiling():
 | 
			
		||||
            if not torch.compiler.is_compiling() and False:
 | 
			
		||||
                if rope_key not in self.rope_cache:
 | 
			
		||||
                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
 | 
			
		||||
                video_freq = self.rope_cache[rope_key]
 | 
			
		||||
@ -224,7 +224,6 @@ class QwenEmbedRope(nn.Module):
 | 
			
		||||
 | 
			
		||||
        return vid_freqs, txt_freqs
 | 
			
		||||
 | 
			
		||||
    @functools.lru_cache(maxsize=None)
 | 
			
		||||
    def _compute_video_freqs(self, frame, height, width, idx=0):
 | 
			
		||||
        seq_lens = frame * height * width
 | 
			
		||||
        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
143  models/wan/animate/animate_utils.py  Normal file

@ -0,0 +1,143 @@
 | 
			
		||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 | 
			
		||||
import torch
 | 
			
		||||
import numbers
 | 
			
		||||
from peft import LoraConfig
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"):
 | 
			
		||||
    target_modules = []
 | 
			
		||||
    for name, module in transformer.named_modules():
 | 
			
		||||
        if "blocks" in name and "face" not in name and "modulation" not in name and isinstance(module, torch.nn.Linear):
 | 
			
		||||
            target_modules.append(name)
 | 
			
		||||
 | 
			
		||||
    transformer_lora_config = LoraConfig(
 | 
			
		||||
        r=rank,
 | 
			
		||||
        lora_alpha=alpha,
 | 
			
		||||
        init_lora_weights=init_lora_weights,
 | 
			
		||||
        target_modules=target_modules,
 | 
			
		||||
    )
 | 
			
		||||
    return transformer_lora_config
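# Illustrative usage sketch (added; not part of the original file). This is how the config built
# above is typically consumed; the `transformer` object and the peft injection call are
# assumptions that depend on the caller's integration:
#
#   from peft import inject_adapter_in_model
#   lora_config = get_loraconfig(transformer, rank=128, alpha=128)
#   transformer = inject_adapter_in_model(lora_config, transformer)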
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TensorList(object):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, tensors):
 | 
			
		||||
        """
 | 
			
		||||
        tensors: a list of torch.Tensor objects. No need to have uniform shape.
 | 
			
		||||
        """
 | 
			
		||||
        assert isinstance(tensors, (list, tuple))
 | 
			
		||||
        assert all(isinstance(u, torch.Tensor) for u in tensors)
 | 
			
		||||
        assert len(set([u.ndim for u in tensors])) == 1
 | 
			
		||||
        assert len(set([u.dtype for u in tensors])) == 1
 | 
			
		||||
        assert len(set([u.device for u in tensors])) == 1
 | 
			
		||||
        self.tensors = tensors
 | 
			
		||||
    
 | 
			
		||||
    def to(self, *args, **kwargs):
 | 
			
		||||
        return TensorList([u.to(*args, **kwargs) for u in self.tensors])
 | 
			
		||||
    
 | 
			
		||||
    def size(self, dim):
 | 
			
		||||
        assert dim == 0, 'only support get the 0th size'
 | 
			
		||||
        return len(self.tensors)
 | 
			
		||||
    
 | 
			
		||||
    def pow(self, *args, **kwargs):
 | 
			
		||||
        return TensorList([u.pow(*args, **kwargs) for u in self.tensors])
 | 
			
		||||
    
 | 
			
		||||
    def squeeze(self, dim):
 | 
			
		||||
        assert dim != 0
 | 
			
		||||
        if dim > 0:
 | 
			
		||||
            dim -= 1
 | 
			
		||||
        return TensorList([u.squeeze(dim) for u in self.tensors])
 | 
			
		||||
    
 | 
			
		||||
    def type(self, *args, **kwargs):
 | 
			
		||||
        return TensorList([u.type(*args, **kwargs) for u in self.tensors])
 | 
			
		||||
    
 | 
			
		||||
    def type_as(self, other):
 | 
			
		||||
        assert isinstance(other, (torch.Tensor, TensorList))
 | 
			
		||||
        if isinstance(other, torch.Tensor):
 | 
			
		||||
            return TensorList([u.type_as(other) for u in self.tensors])
 | 
			
		||||
        else:
 | 
			
		||||
            return TensorList([u.type(other.dtype) for u in self.tensors])
 | 
			
		||||
    
 | 
			
		||||
    @property
 | 
			
		||||
    def dtype(self):
 | 
			
		||||
        return self.tensors[0].dtype
 | 
			
		||||
    
 | 
			
		||||
    @property
 | 
			
		||||
    def device(self):
 | 
			
		||||
        return self.tensors[0].device
 | 
			
		||||
    
 | 
			
		||||
    @property
 | 
			
		||||
    def ndim(self):
 | 
			
		||||
        return 1 + self.tensors[0].ndim
 | 
			
		||||
    
 | 
			
		||||
    def __getitem__(self, index):
 | 
			
		||||
        return self.tensors[index]
 | 
			
		||||
    
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return len(self.tensors)
 | 
			
		||||
    
 | 
			
		||||
    def __add__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: u + v)
 | 
			
		||||
    
 | 
			
		||||
    def __radd__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: v + u)
 | 
			
		||||
    
 | 
			
		||||
    def __sub__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: u - v)
 | 
			
		||||
    
 | 
			
		||||
    def __rsub__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: v - u)
 | 
			
		||||
    
 | 
			
		||||
    def __mul__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: u * v)
 | 
			
		||||
    
 | 
			
		||||
    def __rmul__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: v * u)
 | 
			
		||||
    
 | 
			
		||||
    def __floordiv__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: u // v)
 | 
			
		||||
    
 | 
			
		||||
    def __truediv__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: u / v)
 | 
			
		||||
    
 | 
			
		||||
    def __rfloordiv__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: v // u)
 | 
			
		||||
    
 | 
			
		||||
    def __rtruediv__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: v / u)
 | 
			
		||||
    
 | 
			
		||||
    def __pow__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: u ** v)
 | 
			
		||||
    
 | 
			
		||||
    def __rpow__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: v ** u)
 | 
			
		||||
    
 | 
			
		||||
    def __neg__(self):
 | 
			
		||||
        return TensorList([-u for u in self.tensors])
 | 
			
		||||
    
 | 
			
		||||
    def __iter__(self):
 | 
			
		||||
        for tensor in self.tensors:
 | 
			
		||||
            yield tensor
 | 
			
		||||
    
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return 'TensorList: \n' + repr(self.tensors)
 | 
			
		||||
 | 
			
		||||
    def _apply(self, other, op):
 | 
			
		||||
        if isinstance(other, (list, tuple, TensorList)) or (
 | 
			
		||||
            isinstance(other, torch.Tensor) and (
 | 
			
		||||
                other.numel() > 1 or other.ndim > 1
 | 
			
		||||
            )
 | 
			
		||||
        ):
 | 
			
		||||
            assert len(other) == len(self.tensors)
 | 
			
		||||
            return TensorList([op(u, v) for u, v in zip(self.tensors, other)])
 | 
			
		||||
        elif isinstance(other, numbers.Number) or (
 | 
			
		||||
            isinstance(other, torch.Tensor) and (
 | 
			
		||||
                other.numel() == 1 and other.ndim <= 1
 | 
			
		||||
            )
 | 
			
		||||
        ):
 | 
			
		||||
            return TensorList([op(u, other) for u in self.tensors])
 | 
			
		||||
        else:
 | 
			
		||||
            raise TypeError(
 | 
			
		||||
                f'unsupported operand for *: "TensorList" and "{type(other)}"'
 | 
			
		||||
            )
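# Example (added, illustrative): TensorList lets elementwise arithmetic broadcast over a ragged
# batch of tensors that share dtype/device/ndim but not shape.
#
#   a = TensorList([torch.ones(3, 4), torch.ones(5, 4)])
#   b = (a * 2 + 1).pow(2)            # each member tensor is transformed independently
#   [t.shape for t in b]              # [torch.Size([3, 4]), torch.Size([5, 4])]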
 | 
			
		||||
							
								
								
									
382  models/wan/animate/face_blocks.py  Normal file

@ -0,0 +1,382 @@
 | 
			
		||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 | 
			
		||||
from torch import nn
 | 
			
		||||
import torch
 | 
			
		||||
from typing import Tuple, Optional
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import math
 | 
			
		||||
from shared.attention import pay_attention
try:
    # optional dependency; only needed when attention(..., mode="flash") is used below
    from flash_attn import flash_attn_func
except ImportError:
    flash_attn_func = None
 | 
			
		||||
 | 
			
		||||
MEMORY_LAYOUT = {
 | 
			
		||||
    "flash": (
 | 
			
		||||
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
 | 
			
		||||
        lambda x: x,
 | 
			
		||||
    ),
 | 
			
		||||
    "torch": (
 | 
			
		||||
        lambda x: x.transpose(1, 2),
 | 
			
		||||
        lambda x: x.transpose(1, 2),
 | 
			
		||||
    ),
 | 
			
		||||
    "vanilla": (
 | 
			
		||||
        lambda x: x.transpose(1, 2),
 | 
			
		||||
        lambda x: x.transpose(1, 2),
 | 
			
		||||
    ),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def attention(
 | 
			
		||||
    q,
 | 
			
		||||
    k,
 | 
			
		||||
    v,
 | 
			
		||||
    mode="torch",
 | 
			
		||||
    drop_rate=0,
 | 
			
		||||
    attn_mask=None,
 | 
			
		||||
    causal=False,
 | 
			
		||||
    max_seqlen_q=None,
 | 
			
		||||
    batch_size=1,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Perform QKV self attention.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
 | 
			
		||||
        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
 | 
			
		||||
        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
 | 
			
		||||
        mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
 | 
			
		||||
        drop_rate (float): Dropout rate in attention map. (default: 0)
 | 
			
		||||
        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
 | 
			
		||||
            (default: None)
 | 
			
		||||
        causal (bool): Whether to use causal attention. (default: False)
 | 
			
		||||
        max_seqlen_q (int): The maximum sequence length in the batch of q; used in 'flash' mode to
            reshape the output back to [b, s, a, d].
        batch_size (int): Batch size used for that reshape in 'flash' mode. (default: 1)
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        torch.Tensor: Output tensor after self attention with shape [b, s, ad]
 | 
			
		||||
    """
 | 
			
		||||
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
 | 
			
		||||
 | 
			
		||||
    if mode == "torch":
 | 
			
		||||
        if attn_mask is not None and attn_mask.dtype != torch.bool:
 | 
			
		||||
            attn_mask = attn_mask.to(q.dtype)
 | 
			
		||||
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
 | 
			
		||||
 | 
			
		||||
    elif mode == "flash":
 | 
			
		||||
        x = flash_attn_func(
 | 
			
		||||
            q,
 | 
			
		||||
            k,
 | 
			
		||||
            v,
 | 
			
		||||
        )
 | 
			
		||||
        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # reshape x to [b, s, a, d]
 | 
			
		||||
    elif mode == "vanilla":
 | 
			
		||||
        scale_factor = 1 / math.sqrt(q.size(-1))
 | 
			
		||||
 | 
			
		||||
        b, a, s, _ = q.shape
 | 
			
		||||
        s1 = k.size(2)
 | 
			
		||||
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
 | 
			
		||||
        if causal:
 | 
			
		||||
            # Only applied to self attention
 | 
			
		||||
            assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
 | 
			
		||||
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
 | 
			
		||||
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
 | 
			
		||||
            attn_bias = attn_bias.to(q.dtype)
 | 
			
		||||
 | 
			
		||||
        if attn_mask is not None:
 | 
			
		||||
            if attn_mask.dtype == torch.bool:
 | 
			
		||||
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
 | 
			
		||||
            else:
 | 
			
		||||
                attn_bias += attn_mask
 | 
			
		||||
 | 
			
		||||
        attn = (q @ k.transpose(-2, -1)) * scale_factor
 | 
			
		||||
        attn += attn_bias
 | 
			
		||||
        attn = attn.softmax(dim=-1)
 | 
			
		||||
        attn = torch.dropout(attn, p=drop_rate, train=True)
 | 
			
		||||
        x = attn @ v
 | 
			
		||||
    else:
 | 
			
		||||
        raise NotImplementedError(f"Unsupported attention mode: {mode}")
 | 
			
		||||
 | 
			
		||||
    x = post_attn_layout(x)
 | 
			
		||||
    b, s, a, d = x.shape
 | 
			
		||||
    out = x.reshape(b, s, -1)
 | 
			
		||||
    return out
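# Shape sketch (added, illustrative) for the "torch" mode above: heads are folded back into the
# channel dimension on output.
#
#   q = k = v = torch.randn(1, 16, 8, 64)       # [b, s, heads, head_dim]
#   attention(q, k, v, mode="torch").shape      # torch.Size([1, 16, 512])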
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CausalConv1d(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.pad_mode = pad_mode
 | 
			
		||||
        padding = (kernel_size - 1, 0)  # T
 | 
			
		||||
        self.time_causal_padding = padding
 | 
			
		||||
 | 
			
		||||
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
 | 
			
		||||
        return self.conv(x)
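    # Added note: left-only padding of (kernel_size - 1) frames keeps the convolution causal,
    # i.e. output frame t never sees frames later than t, and with stride=1 the temporal length
    # is preserved:
    #
    #   conv = CausalConv1d(64, 64, kernel_size=3)
    #   conv(torch.randn(2, 64, 17)).shape      # torch.Size([2, 64, 17])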
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FaceEncoder(nn.Module):
 | 
			
		||||
    def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):
 | 
			
		||||
        factory_kwargs = {"dtype": dtype, "device": device}
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.num_heads = num_heads
 | 
			
		||||
        self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
 | 
			
		||||
        self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
 | 
			
		||||
        self.act = nn.SiLU()
 | 
			
		||||
        self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
 | 
			
		||||
        self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)
 | 
			
		||||
 | 
			
		||||
        self.out_proj = nn.Linear(1024, hidden_dim)
 | 
			
		||||
        self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
 | 
			
		||||
 | 
			
		||||
        self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
 | 
			
		||||
 | 
			
		||||
        self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
 | 
			
		||||
 | 
			
		||||
        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        
 | 
			
		||||
        x = rearrange(x, "b t c -> b c t")
 | 
			
		||||
        b, c, t = x.shape
 | 
			
		||||
 | 
			
		||||
        x = self.conv1_local(x)
 | 
			
		||||
        x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
 | 
			
		||||
        
 | 
			
		||||
        x = self.norm1(x)
 | 
			
		||||
        x = self.act(x)
 | 
			
		||||
        x = rearrange(x, "b t c -> b c t")
 | 
			
		||||
        x = self.conv2(x)
 | 
			
		||||
        x = rearrange(x, "b c t -> b t c")
 | 
			
		||||
        x = self.norm2(x)
 | 
			
		||||
        x = self.act(x)
 | 
			
		||||
        x = rearrange(x, "b t c -> b c t")
 | 
			
		||||
        x = self.conv3(x)
 | 
			
		||||
        x = rearrange(x, "b c t -> b t c")
 | 
			
		||||
        x = self.norm3(x)
 | 
			
		||||
        x = self.act(x)
 | 
			
		||||
        x = self.out_proj(x)
 | 
			
		||||
        x = rearrange(x, "(b n) t c -> b t n c", b=b)
 | 
			
		||||
        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
 | 
			
		||||
        x = torch.cat([x, padding], dim=-2)
 | 
			
		||||
        x_local = x.clone()
 | 
			
		||||
 | 
			
		||||
        return x_local
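    # Shape walk-through (added note; dimensions are approximate): for input [b, t, in_dim] the
    # per-head path yields roughly t // 4 tokens after the two stride-2 causal convs, so the
    # returned tensor is [b, ~t // 4, num_heads + 1, hidden_dim] once the learned padding token
    # has been appended along the head axis.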
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RMSNorm(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        dim: int,
 | 
			
		||||
        elementwise_affine=True,
 | 
			
		||||
        eps: float = 1e-6,
 | 
			
		||||
        device=None,
 | 
			
		||||
        dtype=None,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Initialize the RMSNorm normalization layer.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            dim (int): The dimension of the input tensor.
 | 
			
		||||
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
 | 
			
		||||
 | 
			
		||||
        Attributes:
 | 
			
		||||
            eps (float): A small value added to the denominator for numerical stability.
 | 
			
		||||
            weight (nn.Parameter): Learnable scaling parameter.
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        factory_kwargs = {"device": device, "dtype": dtype}
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
        if elementwise_affine:
 | 
			
		||||
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
 | 
			
		||||
 | 
			
		||||
    def _norm(self, x):
 | 
			
		||||
        """
 | 
			
		||||
        Apply the RMSNorm normalization to the input tensor.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            x (torch.Tensor): The input tensor.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            torch.Tensor: The normalized tensor.
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        """
 | 
			
		||||
        Forward pass through the RMSNorm layer.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            x (torch.Tensor): The input tensor.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            torch.Tensor: The output tensor after applying RMSNorm.
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        output = self._norm(x.float()).type_as(x)
 | 
			
		||||
        if hasattr(self, "weight"):
 | 
			
		||||
            output = output * self.weight
 | 
			
		||||
        return output
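    # Added note: this is standard RMS normalization,
    #   y = x / sqrt(mean(x**2, dim=-1) + eps) * weight
    # computed in float32 and cast back to the input dtype; `weight` is only applied when
    # elementwise_affine=True.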
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_norm_layer(norm_layer):
 | 
			
		||||
    """
 | 
			
		||||
    Get the normalization layer.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        norm_layer (str): The type of normalization layer.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        norm_layer (nn.Module): The normalization layer.
 | 
			
		||||
    """
 | 
			
		||||
    if norm_layer == "layer":
 | 
			
		||||
        return nn.LayerNorm
 | 
			
		||||
    elif norm_layer == "rms":
 | 
			
		||||
        return RMSNorm
 | 
			
		||||
    else:
 | 
			
		||||
        raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FaceAdapter(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_dim: int,
 | 
			
		||||
        heads_num: int,
 | 
			
		||||
        qk_norm: bool = True,
 | 
			
		||||
        qk_norm_type: str = "rms",
 | 
			
		||||
        num_adapter_layers: int = 1,
 | 
			
		||||
        dtype=None,
 | 
			
		||||
        device=None,
 | 
			
		||||
    ):
 | 
			
		||||
 | 
			
		||||
        factory_kwargs = {"dtype": dtype, "device": device}
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.hidden_size = hidden_dim
 | 
			
		||||
        self.heads_num = heads_num
 | 
			
		||||
        self.fuser_blocks = nn.ModuleList(
 | 
			
		||||
            [
 | 
			
		||||
                FaceBlock(
 | 
			
		||||
                    self.hidden_size,
 | 
			
		||||
                    self.heads_num,
 | 
			
		||||
                    qk_norm=qk_norm,
 | 
			
		||||
                    qk_norm_type=qk_norm_type,
 | 
			
		||||
                    **factory_kwargs,
 | 
			
		||||
                )
 | 
			
		||||
                for _ in range(num_adapter_layers)
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        x: torch.Tensor,
 | 
			
		||||
        motion_embed: torch.Tensor,
 | 
			
		||||
        idx: int,
 | 
			
		||||
        freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
 | 
			
		||||
        freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
 | 
			
		||||
        return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FaceBlock(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size: int,
 | 
			
		||||
        heads_num: int,
 | 
			
		||||
        qk_norm: bool = True,
 | 
			
		||||
        qk_norm_type: str = "rms",
 | 
			
		||||
        qk_scale: float = None,
 | 
			
		||||
        dtype: Optional[torch.dtype] = None,
 | 
			
		||||
        device: Optional[torch.device] = None,
 | 
			
		||||
    ):
 | 
			
		||||
        factory_kwargs = {"device": device, "dtype": dtype}
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.deterministic = False
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.heads_num = heads_num
 | 
			
		||||
        head_dim = hidden_size // heads_num
 | 
			
		||||
        self.scale = qk_scale or head_dim**-0.5
 | 
			
		||||
       
 | 
			
		||||
        self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
 | 
			
		||||
        self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
 | 
			
		||||
 | 
			
		||||
        self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
 | 
			
		||||
 | 
			
		||||
        qk_norm_layer = get_norm_layer(qk_norm_type)
 | 
			
		||||
        self.q_norm = (
 | 
			
		||||
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
 | 
			
		||||
        )
 | 
			
		||||
        self.k_norm = (
 | 
			
		||||
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
 | 
			
		||||
 | 
			
		||||
        self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        x: torch.Tensor,
 | 
			
		||||
        motion_vec: torch.Tensor,
 | 
			
		||||
        motion_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        use_context_parallel=False,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        
 | 
			
		||||
        B, T, N, C = motion_vec.shape
 | 
			
		||||
        T_comp = T
 | 
			
		||||
 | 
			
		||||
        x_motion = self.pre_norm_motion(motion_vec)
 | 
			
		||||
        x_feat = self.pre_norm_feat(x)
 | 
			
		||||
 | 
			
		||||
        kv = self.linear1_kv(x_motion)
 | 
			
		||||
        q = self.linear1_q(x_feat)
 | 
			
		||||
 | 
			
		||||
        k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
 | 
			
		||||
        q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
 | 
			
		||||
 | 
			
		||||
        # Apply QK-Norm if needed.
 | 
			
		||||
        q = self.q_norm(q).to(v)
 | 
			
		||||
        k = self.k_norm(k).to(v)
 | 
			
		||||
 | 
			
		||||
        k = rearrange(k, "B L N H D -> (B L) N H D")  
 | 
			
		||||
        v = rearrange(v, "B L N H D -> (B L) N H D") 
 | 
			
		||||
 | 
			
		||||
        if use_context_parallel:
 | 
			
		||||
            q = gather_forward(q, dim=1)
 | 
			
		||||
 | 
			
		||||
        q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp)  
 | 
			
		||||
        # Compute attention.
 | 
			
		||||
        # Size([batches, tokens, heads, head_features])
 | 
			
		||||
        qkv_list = [q, k, v]
 | 
			
		||||
        del q,k,v
 | 
			
		||||
        attn = pay_attention(qkv_list)
 | 
			
		||||
        # attn = attention(
 | 
			
		||||
        #     q,
 | 
			
		||||
        #     k,
 | 
			
		||||
        #     v,
 | 
			
		||||
        #     max_seqlen_q=q.shape[1],
 | 
			
		||||
        #     batch_size=q.shape[0],
 | 
			
		||||
        # )
 | 
			
		||||
 | 
			
		||||
        attn = attn.reshape(*attn.shape[:2], -1)
 | 
			
		||||
        attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
 | 
			
		||||
        # if use_context_parallel:
 | 
			
		||||
        #     attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()]
 | 
			
		||||
 | 
			
		||||
        output = self.linear2(attn)
 | 
			
		||||
 | 
			
		||||
        if motion_mask is not None:
 | 
			
		||||
            output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
 | 
			
		||||
 | 
			
		||||
        return output
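    # Added note: FaceBlock is a cross-attention between the video tokens (queries, reshaped so
    # that each latent frame attends independently) and the face-motion tokens produced by
    # FaceEncoder (keys/values); the optional motion_mask zeroes the contribution outside the
    # masked region.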
 | 
			
		||||
							
								
								
									
31  models/wan/animate/model_animate.py  Normal file

@ -0,0 +1,31 @@
 | 
			
		||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 | 
			
		||||
import math
 | 
			
		||||
import types
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
from einops import  rearrange
 | 
			
		||||
from typing import List
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
import torch.cuda.amp as amp
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
 | 
			
		||||
def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
 | 
			
		||||
    pose_latents = self.pose_patch_embedding(pose_latents)
 | 
			
		||||
    x[:, :, 1:] += pose_latents
 | 
			
		||||
    
 | 
			
		||||
    b,c,T,h,w = face_pixel_values.shape
 | 
			
		||||
    face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
 | 
			
		||||
    encode_bs = 8
 | 
			
		||||
    face_pixel_values_tmp = []
 | 
			
		||||
    for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):
 | 
			
		||||
        face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))
 | 
			
		||||
 | 
			
		||||
    motion_vec = torch.cat(face_pixel_values_tmp)
 | 
			
		||||
    
 | 
			
		||||
    motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
 | 
			
		||||
    motion_vec = self.face_encoder(motion_vec)
 | 
			
		||||
 | 
			
		||||
    B, L, H, C = motion_vec.shape
 | 
			
		||||
    pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
 | 
			
		||||
    motion_vec = torch.cat([pad_face, motion_vec], dim=1)
 | 
			
		||||
    return x, motion_vec
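# Added note: this hook runs right after the transformer's patch embedding. It (1) adds the
# pose latents to every latent frame except the first and (2) encodes the cropped face frames in
# chunks of 8 through the motion encoder and face encoder, prepending one zero "pad face" token,
# presumably so the motion sequence lines up with the latent frames.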
 | 
			
		||||
							
								
								
									
308  models/wan/animate/motion_encoder.py  Normal file

@ -0,0 +1,308 @@
 | 
			
		||||
# Modified from ``https://github.com/wyhsirius/LIA``
 | 
			
		||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
from torch.nn import functional as F
 | 
			
		||||
import math
 | 
			
		||||
 | 
			
		||||
def custom_qr(input_tensor):
 | 
			
		||||
    original_dtype = input_tensor.dtype
 | 
			
		||||
    if original_dtype in [torch.bfloat16, torch.float16]:
 | 
			
		||||
        q, r = torch.linalg.qr(input_tensor.to(torch.float32))
 | 
			
		||||
        return q.to(original_dtype), r.to(original_dtype)
 | 
			
		||||
    return torch.linalg.qr(input_tensor)
 | 
			
		||||
 | 
			
		||||
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
 | 
			
		||||
	return F.leaky_relu(input + bias, negative_slope) * scale
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
 | 
			
		||||
	_, minor, in_h, in_w = input.shape
 | 
			
		||||
	kernel_h, kernel_w = kernel.shape
 | 
			
		||||
 | 
			
		||||
	out = input.view(-1, minor, in_h, 1, in_w, 1)
 | 
			
		||||
	out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
 | 
			
		||||
	out = out.view(-1, minor, in_h * up_y, in_w * up_x)
 | 
			
		||||
 | 
			
		||||
	out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
 | 
			
		||||
	out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
 | 
			
		||||
		  max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
 | 
			
		||||
 | 
			
		||||
	out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
 | 
			
		||||
	w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
 | 
			
		||||
	out = F.conv2d(out, w)
 | 
			
		||||
	out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
 | 
			
		||||
					  in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
 | 
			
		||||
	return out[:, :, ::down_y, ::down_x]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
 | 
			
		||||
	return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
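# Added note: upfirdn2d = upsample by zero-insertion, pad, convolve with the separable FIR kernel
# built by make_kernel below, then downsample by striding; with the default up=down=1 it reduces
# to a plain blur, which is how the Blur module uses it.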
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def make_kernel(k):
 | 
			
		||||
	k = torch.tensor(k, dtype=torch.float32)
 | 
			
		||||
	if k.ndim == 1:
 | 
			
		||||
		k = k[None, :] * k[:, None]
 | 
			
		||||
	k /= k.sum()
 | 
			
		||||
	return k
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FusedLeakyReLU(nn.Module):
 | 
			
		||||
	def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
 | 
			
		||||
		super().__init__()
 | 
			
		||||
		self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
 | 
			
		||||
		self.negative_slope = negative_slope
 | 
			
		||||
		self.scale = scale
 | 
			
		||||
 | 
			
		||||
	def forward(self, input):
 | 
			
		||||
		out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
 | 
			
		||||
		return out
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Blur(nn.Module):
 | 
			
		||||
	def __init__(self, kernel, pad, upsample_factor=1):
 | 
			
		||||
		super().__init__()
 | 
			
		||||
 | 
			
		||||
		kernel = make_kernel(kernel)
 | 
			
		||||
 | 
			
		||||
		if upsample_factor > 1:
 | 
			
		||||
			kernel = kernel * (upsample_factor ** 2)
 | 
			
		||||
 | 
			
		||||
		self.register_buffer('kernel', kernel)
 | 
			
		||||
 | 
			
		||||
		self.pad = pad
 | 
			
		||||
 | 
			
		||||
	def forward(self, input):
 | 
			
		||||
		return upfirdn2d(input, self.kernel, pad=self.pad)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ScaledLeakyReLU(nn.Module):
 | 
			
		||||
	def __init__(self, negative_slope=0.2):
 | 
			
		||||
		super().__init__()
 | 
			
		||||
 | 
			
		||||
		self.negative_slope = negative_slope
 | 
			
		||||
 | 
			
		||||
	def forward(self, input):
 | 
			
		||||
		return F.leaky_relu(input, negative_slope=self.negative_slope)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EqualConv2d(nn.Module):
 | 
			
		||||
	def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
 | 
			
		||||
		super().__init__()
 | 
			
		||||
 | 
			
		||||
		self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
 | 
			
		||||
		self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
 | 
			
		||||
 | 
			
		||||
		self.stride = stride
 | 
			
		||||
		self.padding = padding
 | 
			
		||||
 | 
			
		||||
		if bias:
 | 
			
		||||
			self.bias = nn.Parameter(torch.zeros(out_channel))
 | 
			
		||||
		else:
 | 
			
		||||
			self.bias = None
 | 
			
		||||
 | 
			
		||||
	def forward(self, input):
 | 
			
		||||
 | 
			
		||||
		return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
 | 
			
		||||
 | 
			
		||||
	def __repr__(self):
 | 
			
		||||
		return (
 | 
			
		||||
			f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
 | 
			
		||||
			f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EqualLinear(nn.Module):
 | 
			
		||||
	def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
 | 
			
		||||
		super().__init__()
 | 
			
		||||
 | 
			
		||||
		self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
 | 
			
		||||
 | 
			
		||||
		if bias:
 | 
			
		||||
			self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
 | 
			
		||||
		else:
 | 
			
		||||
			self.bias = None
 | 
			
		||||
 | 
			
		||||
		self.activation = activation
 | 
			
		||||
 | 
			
		||||
		self.scale = (1 / math.sqrt(in_dim)) * lr_mul
 | 
			
		||||
		self.lr_mul = lr_mul
 | 
			
		||||
 | 
			
		||||
	def forward(self, input):
 | 
			
		||||
 | 
			
		||||
		if self.activation:
 | 
			
		||||
			out = F.linear(input, self.weight * self.scale)
 | 
			
		||||
			out = fused_leaky_relu(out, self.bias * self.lr_mul)
 | 
			
		||||
		else:
 | 
			
		||||
			out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
 | 
			
		||||
 | 
			
		||||
		return out
 | 
			
		||||
 | 
			
		||||
	def __repr__(self):
 | 
			
		||||
		return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ConvLayer(nn.Sequential):
 | 
			
		||||
	def __init__(
 | 
			
		||||
			self,
 | 
			
		||||
			in_channel,
 | 
			
		||||
			out_channel,
 | 
			
		||||
			kernel_size,
 | 
			
		||||
			downsample=False,
 | 
			
		||||
			blur_kernel=[1, 3, 3, 1],
 | 
			
		||||
			bias=True,
 | 
			
		||||
			activate=True,
 | 
			
		||||
	):
 | 
			
		||||
		layers = []
 | 
			
		||||
 | 
			
		||||
		if downsample:
 | 
			
		||||
			factor = 2
 | 
			
		||||
			p = (len(blur_kernel) - factor) + (kernel_size - 1)
 | 
			
		||||
			pad0 = (p + 1) // 2
 | 
			
		||||
			pad1 = p // 2
 | 
			
		||||
 | 
			
		||||
			layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
 | 
			
		||||
 | 
			
		||||
			stride = 2
 | 
			
		||||
			self.padding = 0
 | 
			
		||||
 | 
			
		||||
		else:
 | 
			
		||||
			stride = 1
 | 
			
		||||
			self.padding = kernel_size // 2
 | 
			
		||||
 | 
			
		||||
		layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
 | 
			
		||||
								  bias=bias and not activate))
 | 
			
		||||
 | 
			
		||||
		if activate:
 | 
			
		||||
			if bias:
 | 
			
		||||
				layers.append(FusedLeakyReLU(out_channel))
 | 
			
		||||
			else:
 | 
			
		||||
				layers.append(ScaledLeakyReLU(0.2))
 | 
			
		||||
 | 
			
		||||
		super().__init__(*layers)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ResBlock(nn.Module):
 | 
			
		||||
	def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
 | 
			
		||||
		super().__init__()
 | 
			
		||||
 | 
			
		||||
		self.conv1 = ConvLayer(in_channel, in_channel, 3)
 | 
			
		||||
		self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
 | 
			
		||||
 | 
			
		||||
		self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
 | 
			
		||||
 | 
			
		||||
	def forward(self, input):
 | 
			
		||||
		out = self.conv1(input)
 | 
			
		||||
		out = self.conv2(out)
 | 
			
		||||
 | 
			
		||||
		skip = self.skip(input)
 | 
			
		||||
		out = (out + skip) / math.sqrt(2)
 | 
			
		||||
 | 
			
		||||
		return out
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EncoderApp(nn.Module):
 | 
			
		||||
	def __init__(self, size, w_dim=512):
 | 
			
		||||
		super(EncoderApp, self).__init__()
 | 
			
		||||
 | 
			
		||||
		channels = {
 | 
			
		||||
			4: 512,
 | 
			
		||||
			8: 512,
 | 
			
		||||
			16: 512,
 | 
			
		||||
			32: 512,
 | 
			
		||||
			64: 256,
 | 
			
		||||
			128: 128,
 | 
			
		||||
			256: 64,
 | 
			
		||||
			512: 32,
 | 
			
		||||
			1024: 16
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		self.w_dim = w_dim
 | 
			
		||||
		log_size = int(math.log(size, 2))
 | 
			
		||||
 | 
			
		||||
		self.convs = nn.ModuleList()
 | 
			
		||||
		self.convs.append(ConvLayer(3, channels[size], 1))
 | 
			
		||||
 | 
			
		||||
		in_channel = channels[size]
 | 
			
		||||
		for i in range(log_size, 2, -1):
 | 
			
		||||
			out_channel = channels[2 ** (i - 1)]
 | 
			
		||||
			self.convs.append(ResBlock(in_channel, out_channel))
 | 
			
		||||
			in_channel = out_channel
 | 
			
		||||
 | 
			
		||||
		self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))
 | 
			
		||||
 | 
			
		||||
	def forward(self, x):
 | 
			
		||||
 | 
			
		||||
		res = []
 | 
			
		||||
		h = x
 | 
			
		||||
		for conv in self.convs:
 | 
			
		||||
			h = conv(h)
 | 
			
		||||
			res.append(h)
 | 
			
		||||
 | 
			
		||||
		return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Encoder(nn.Module):
 | 
			
		||||
	def __init__(self, size, dim=512, dim_motion=20):
 | 
			
		||||
		super(Encoder, self).__init__()
 | 
			
		||||
 | 
			
		||||
		# appearance network
 | 
			
		||||
		self.net_app = EncoderApp(size, dim)
 | 
			
		||||
 | 
			
		||||
		# motion network
 | 
			
		||||
		fc = [EqualLinear(dim, dim)]
 | 
			
		||||
		for i in range(3):
 | 
			
		||||
			fc.append(EqualLinear(dim, dim))
 | 
			
		||||
 | 
			
		||||
		fc.append(EqualLinear(dim, dim_motion))
 | 
			
		||||
		self.fc = nn.Sequential(*fc)
 | 
			
		||||
 | 
			
		||||
	def enc_app(self, x):
 | 
			
		||||
		h_source = self.net_app(x)
 | 
			
		||||
		return h_source
 | 
			
		||||
 | 
			
		||||
	def enc_motion(self, x):
 | 
			
		||||
		h, _ = self.net_app(x)
 | 
			
		||||
		h_motion = self.fc(h)
 | 
			
		||||
		return h_motion
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Direction(nn.Module):
 | 
			
		||||
    def __init__(self, motion_dim):
 | 
			
		||||
        super(Direction, self).__init__()
 | 
			
		||||
        self.weight = nn.Parameter(torch.randn(512, motion_dim))
 | 
			
		||||
 | 
			
		||||
    def forward(self, input):
 | 
			
		||||
 | 
			
		||||
        weight = self.weight + 1e-8
 | 
			
		||||
        Q, R = custom_qr(weight)
 | 
			
		||||
        if input is None:
 | 
			
		||||
            return Q
 | 
			
		||||
        else:
 | 
			
		||||
            input_diag = torch.diag_embed(input)  # alpha, diagonal matrix
 | 
			
		||||
            out = torch.matmul(input_diag, Q.T)
 | 
			
		||||
            out = torch.sum(out, dim=1)
 | 
			
		||||
            return out
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Synthesis(nn.Module):
 | 
			
		||||
    def __init__(self, motion_dim):
 | 
			
		||||
        super(Synthesis, self).__init__()
 | 
			
		||||
        self.direction = Direction(motion_dim)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Generator(nn.Module):
 | 
			
		||||
	def __init__(self, size, style_dim=512, motion_dim=20):
 | 
			
		||||
		super().__init__()
 | 
			
		||||
 | 
			
		||||
		self.enc = Encoder(size, style_dim, motion_dim)
 | 
			
		||||
		self.dec = Synthesis(motion_dim)
 | 
			
		||||
 | 
			
		||||
	def get_motion(self, img):
 | 
			
		||||
		#motion_feat = self.enc.enc_motion(img)
 | 
			
		||||
		# motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
 | 
			
		||||
		with torch.cuda.amp.autocast(dtype=torch.float32):
 | 
			
		||||
			motion_feat = self.enc.enc_motion(img)
 | 
			
		||||
			motion = self.dec.direction(motion_feat)
 | 
			
		||||
		return motion
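	# Illustrative usage (added; the concrete sizes are assumptions): the Animate pipeline only
	# calls get_motion(), which projects each face crop onto motion_dim learned orthogonal
	# directions and returns the combined 512-d motion vector.
	#
	#   enc = Generator(size=512, style_dim=512, motion_dim=20)
	#   enc.get_motion(torch.randn(8, 3, 512, 512)).shape     # torch.Size([8, 512])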
 | 
			
		||||
@ -19,7 +19,8 @@ from PIL import Image
 | 
			
		||||
import torchvision.transforms.functional as TF
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from .distributed.fsdp import shard_model
 | 
			
		||||
from .modules.model import WanModel, clear_caches
 | 
			
		||||
from .modules.model import WanModel
 | 
			
		||||
from mmgp.offload import get_cache, clear_caches
 | 
			
		||||
from .modules.t5 import T5EncoderModel
 | 
			
		||||
from .modules.vae import WanVAE
 | 
			
		||||
from .modules.vae2_2 import Wan2_2_VAE
 | 
			
		||||
@ -31,9 +32,11 @@ from shared.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 | 
			
		||||
from .modules.posemb_layers import get_rotary_pos_embed, get_nd_rotary_pos_embed
 | 
			
		||||
from shared.utils.vace_preprocessor import VaceVideoProcessor
 | 
			
		||||
from shared.utils.basic_flowmatch import FlowMatchScheduler
 | 
			
		||||
from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor
 | 
			
		||||
from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor, fit_image_into_canvas
 | 
			
		||||
from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance, match_and_blend_colors, match_and_blend_colors_with_mask
 | 
			
		||||
from shared.utils.audio_video import save_video
 | 
			
		||||
from mmgp import safetensors2
 | 
			
		||||
from shared.utils.audio_video import save_video
 | 
			
		||||
 | 
			
		||||
def optimized_scale(positive_flat, negative_flat):
 | 
			
		||||
 | 
			
		||||
@ -63,6 +66,7 @@ class WanAny2V:
 | 
			
		||||
        config,
 | 
			
		||||
        checkpoint_dir,
 | 
			
		||||
        model_filename = None,
 | 
			
		||||
        submodel_no_list = None,
 | 
			
		||||
        model_type = None, 
 | 
			
		||||
        model_def = None,
 | 
			
		||||
        base_model_type = None,
 | 
			
		||||
@ -91,7 +95,7 @@ class WanAny2V:
 | 
			
		||||
            shard_fn= None)
 | 
			
		||||
 | 
			
		||||
        # base_model_type = "i2v2_2"
 | 
			
		||||
        if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"]:
 | 
			
		||||
        if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"] or base_model_type in ["animate"]:
 | 
			
		||||
            self.clip = CLIPModel(
 | 
			
		||||
                dtype=config.clip_dtype,
 | 
			
		||||
                device=self.device,
 | 
			
		||||
@ -100,7 +104,7 @@ class WanAny2V:
 | 
			
		||||
                tokenizer_path=os.path.join(checkpoint_dir , "xlm-roberta-large"))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["ti2v_2_2"]:
 | 
			
		||||
        if base_model_type in ["ti2v_2_2", "lucy_edit"]:
 | 
			
		||||
            self.vae_stride = (4, 16, 16)
 | 
			
		||||
            vae_checkpoint = "Wan2.2_VAE.safetensors"
 | 
			
		||||
            vae = Wan2_2_VAE
 | 
			
		||||
@ -125,75 +129,82 @@ class WanAny2V:
 | 
			
		||||
        forcedConfigPath = base_config_file if len(model_filename) > 1 else None
 | 
			
		||||
        # forcedConfigPath = base_config_file = f"configs/flf2v_720p.json"
 | 
			
		||||
        # model_filename[1] = xmodel_filename
 | 
			
		||||
 | 
			
		||||
        self.model = self.model2 = None
 | 
			
		||||
        source =  model_def.get("source", None)
 | 
			
		||||
        source2 = model_def.get("source2", None)
 | 
			
		||||
        module_source =  model_def.get("module_source", None)
 | 
			
		||||
        module_source2 =  model_def.get("module_source2", None)
 | 
			
		||||
        if module_source is not None:
 | 
			
		||||
            model_filename = [] + model_filename
 | 
			
		||||
            model_filename[1] = module_source
 | 
			
		||||
            self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
 | 
			
		||||
        elif source is not None:
 | 
			
		||||
            self.model = offload.fast_load_transformers_model(model_filename[:1] + [module_source], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
 | 
			
		||||
        if module_source2 is not None:
 | 
			
		||||
            self.model2 = offload.fast_load_transformers_model(model_filename[1:2] + [module_source2], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
 | 
			
		||||
        if source is not None:
 | 
			
		||||
            self.model = offload.fast_load_transformers_model(source, modelClass=WanModel, writable_tensors= False, forcedConfigPath= base_config_file)
 | 
			
		||||
        elif self.transformer_switch:
 | 
			
		||||
            shared_modules= {}
 | 
			
		||||
            self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath,  return_shared_modules= shared_modules)
 | 
			
		||||
            self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
 | 
			
		||||
            shared_modules = None
 | 
			
		||||
        else:
 | 
			
		||||
            self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
 | 
			
		||||
        
 | 
			
		||||
        # self.model = offload.load_model_data(self.model, xmodel_filename )
 | 
			
		||||
        # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
 | 
			
		||||
        if source2 is not None:
 | 
			
		||||
            self.model2 = offload.fast_load_transformers_model(source2, modelClass=WanModel, writable_tensors= False, forcedConfigPath= base_config_file)
 | 
			
		||||
 | 
			
		||||
        self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
 | 
			
		||||
        offload.change_dtype(self.model, dtype, True)
 | 
			
		||||
        if self.model is not None or self.model2 is not None:
 | 
			
		||||
            from wgp import save_model
 | 
			
		||||
            from mmgp.safetensors2 import torch_load_file
 | 
			
		||||
        else:
 | 
			
		||||
            if self.transformer_switch:
 | 
			
		||||
                if 0 in submodel_no_list[2:] and 1 in submodel_no_list[2:]:
 | 
			
		||||
                    raise Exception("Shared and non shared modules at the same time across multipe models is not supported")
 | 
			
		||||
                
 | 
			
		||||
                if 0 in submodel_no_list[2:]:
 | 
			
		||||
                    shared_modules= {}
 | 
			
		||||
                    self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath,  return_shared_modules= shared_modules)
 | 
			
		||||
                    self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
 | 
			
		||||
                    shared_modules = None
 | 
			
		||||
                else:
 | 
			
		||||
                    modules_for_1 =[ file_name for file_name, submodel_no in zip(model_filename[2:],submodel_no_list[2:] ) if submodel_no ==1 ]
 | 
			
		||||
                    modules_for_2 =[ file_name for file_name, submodel_no in zip(model_filename[2:],submodel_no_list[2:] ) if submodel_no ==2 ]
 | 
			
		||||
                    self.model = offload.fast_load_transformers_model(model_filename[:1], modules = modules_for_1, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
 | 
			
		||||
                    self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = modules_for_2, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
 | 
			
		||||
 | 
			
		||||
            else:
 | 
			
		||||
                self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        if self.model is not None:
 | 
			
		||||
            self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
 | 
			
		||||
            offload.change_dtype(self.model, dtype, True)
 | 
			
		||||
            self.model.eval().requires_grad_(False)
 | 
			
		||||
        if self.model2 is not None:
 | 
			
		||||
            self.model2.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
 | 
			
		||||
            offload.change_dtype(self.model2, dtype, True)
 | 
			
		||||
 | 
			
		||||
        # offload.save_model(self.model, "wan2.1_text2video_1.3B_mbf16.safetensors", do_quantize= False, config_file_path=base_config_file, filter_sd=sd)
 | 
			
		||||
        # offload.save_model(self.model, "wan2.2_image2video_14B_low_mbf16.safetensors",  config_file_path=base_config_file)
 | 
			
		||||
        # offload.save_model(self.model, "wan2.2_image2video_14B_low_quanto_mbf16_int8.safetensors", do_quantize=True, config_file_path=base_config_file)
 | 
			
		||||
        self.model.eval().requires_grad_(False)
 | 
			
		||||
        if self.model2 is not None:
 | 
			
		||||
            self.model2.eval().requires_grad_(False)
 | 
			
		||||
 | 
			
		||||
        if module_source is not None:
 | 
			
		||||
            from wgp import save_model
 | 
			
		||||
            from mmgp.safetensors2 import torch_load_file
 | 
			
		||||
            filter = list(torch_load_file(module_source))
 | 
			
		||||
            save_model(self.model, model_type, dtype, None, is_module=True, filter=filter)
 | 
			
		||||
        elif not source is None:
 | 
			
		||||
            from wgp import save_model
 | 
			
		||||
            save_model(self.model, model_type, dtype, None)
 | 
			
		||||
            save_model(self.model, model_type, dtype, None, is_module=True, filter=list(torch_load_file(module_source)), module_source_no=1)
 | 
			
		||||
        if module_source2 is not None:
 | 
			
		||||
            save_model(self.model2, model_type, dtype, None, is_module=True, filter=list(torch_load_file(module_source2)), module_source_no=2)
 | 
			
		||||
        if not source is None:
 | 
			
		||||
            save_model(self.model, model_type, dtype, None, submodel_no= 1)
 | 
			
		||||
        if not source2 is None:
 | 
			
		||||
            save_model(self.model2, model_type, dtype, None, submodel_no= 2)
 | 
			
		||||
 | 
			
		||||
        if save_quantized:
 | 
			
		||||
            from wgp import save_quantized_model
 | 
			
		||||
            save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file)
 | 
			
		||||
            if self.model is not None:
 | 
			
		||||
                save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file)
 | 
			
		||||
            if self.model2 is not None:
 | 
			
		||||
                save_quantized_model(self.model2, model_type, model_filename[1], dtype, base_config_file, submodel_no=2)
 | 
			
		||||
        self.sample_neg_prompt = config.sample_neg_prompt
 | 
			
		||||
 | 
			
		||||
        if self.model.config.get("vace_in_dim", None) != None:
 | 
			
		||||
            self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
 | 
			
		||||
                                            min_area=480*832,
 | 
			
		||||
                                            max_area=480*832,
 | 
			
		||||
                                            min_fps=config.sample_fps,
 | 
			
		||||
                                            max_fps=config.sample_fps,
 | 
			
		||||
                                            zero_start=True,
 | 
			
		||||
                                            seq_len=32760,
 | 
			
		||||
                                            keep_last=True)
 | 
			
		||||
 | 
			
		||||
        if hasattr(self.model, "vace_blocks"):
 | 
			
		||||
            self.adapt_vace_model(self.model)
 | 
			
		||||
            if self.model2 is not None: self.adapt_vace_model(self.model2)
 | 
			
		||||
 | 
			
		||||
        if hasattr(self.model, "face_adapter"):
 | 
			
		||||
            self.adapt_animate_model(self.model)
 | 
			
		||||
            if self.model2 is not None: self.adapt_animate_model(self.model2)
 | 
			
		||||
        
 | 
			
		||||
        self.num_timesteps = 1000 
 | 
			
		||||
        self.use_timestep_transform = True 
 | 
			
		||||
 | 
			
		||||
    def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = None):
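        # Encode the control-video clips (and their optional reference images) into VAE latents.
        # The same reference-image list is broadcast to every clip; when no masks are given the
        # clips are simply VAE-encoded as-is.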
        if ref_images is None:
 | 
			
		||||
            ref_images = [None] * len(frames)
 | 
			
		||||
        else:
 | 
			
		||||
            assert len(frames) == len(ref_images)
 | 
			
		||||
        ref_images = [ref_images] * len(frames)
 | 
			
		||||
 | 
			
		||||
        if masks is None:
 | 
			
		||||
            latents = self.vae.encode(frames, tile_size = tile_size)
 | 
			
		||||
@ -225,11 +236,7 @@ class WanAny2V:
 | 
			
		||||
        return cat_latents
 | 
			
		||||
 | 
			
		||||
    def vace_encode_masks(self, masks, ref_images=None):
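        # Build the per-clip mask list that accompanies the encoded latents in the Vace context;
        # reference images, when present, are paired with their masks so both stay aligned.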
        if ref_images is None:
 | 
			
		||||
            ref_images = [None] * len(masks)
 | 
			
		||||
        else:
 | 
			
		||||
            assert len(masks) == len(ref_images)
 | 
			
		||||
 | 
			
		||||
        ref_images = [ref_images] * len(masks)
 | 
			
		||||
        result_masks = []
 | 
			
		||||
        for mask, refs in zip(masks, ref_images):
 | 
			
		||||
            c, depth, height, width = mask.shape
 | 
			
		||||
@ -257,119 +264,6 @@ class WanAny2V:
 | 
			
		||||
            result_masks.append(mask)
 | 
			
		||||
        return result_masks
 | 
			
		||||
 | 
			
		||||
    def vace_latent(self, z, m):
 | 
			
		||||
        return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
 | 
			
		||||
 | 
			
		||||
    def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None, return_mask = False):
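        # Letterbox a reference image into the target canvas: scale to fit, center on a constant
        # background (canvas_tf_bg), optionally enlarge the canvas for outpainting, and return the
        # [-1, 1] tensor plus, if return_mask is set, a mask that is 1 over the padded area.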
        from shared.utils.utils import save_image
 | 
			
		||||
        ref_width, ref_height = ref_img.size
 | 
			
		||||
        if (ref_height, ref_width) == image_size and outpainting_dims  == None:
 | 
			
		||||
            ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
 | 
			
		||||
            canvas = torch.zeros_like(ref_img) if return_mask else None
 | 
			
		||||
        else:
 | 
			
		||||
            if outpainting_dims != None:
 | 
			
		||||
                final_height, final_width = image_size
 | 
			
		||||
                canvas_height, canvas_width, margin_top, margin_left =   get_outpainting_frame_location(final_height, final_width,  outpainting_dims, 8)        
 | 
			
		||||
            else:
 | 
			
		||||
                canvas_height, canvas_width = image_size
 | 
			
		||||
            scale = min(canvas_height / ref_height, canvas_width / ref_width)
 | 
			
		||||
            new_height = int(ref_height * scale)
 | 
			
		||||
            new_width = int(ref_width * scale)
 | 
			
		||||
            if fill_max  and (canvas_height - new_height) < 16:
 | 
			
		||||
                new_height = canvas_height
 | 
			
		||||
            if fill_max  and (canvas_width - new_width) < 16:
 | 
			
		||||
                new_width = canvas_width
 | 
			
		||||
            top = (canvas_height - new_height) // 2
 | 
			
		||||
            left = (canvas_width - new_width) // 2
 | 
			
		||||
            ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) 
 | 
			
		||||
            ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
 | 
			
		||||
            if outpainting_dims != None:
 | 
			
		||||
                canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
 | 
			
		||||
                canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img 
 | 
			
		||||
            else:
 | 
			
		||||
                canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
 | 
			
		||||
                canvas[:, :, top:top + new_height, left:left + new_width] = ref_img 
 | 
			
		||||
            ref_img = canvas
 | 
			
		||||
            canvas = None
 | 
			
		||||
            if return_mask:
 | 
			
		||||
                if outpainting_dims != None:
 | 
			
		||||
                    canvas = torch.ones((3, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1]
 | 
			
		||||
                    canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0
 | 
			
		||||
                else:
 | 
			
		||||
                    canvas = torch.ones((3, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1]
 | 
			
		||||
                    canvas[:, :, top:top + new_height, left:left + new_width] = 0
 | 
			
		||||
                canvas = canvas.to(device)
 | 
			
		||||
        return ref_img.to(device), canvas
 | 
			
		||||
 | 
			
		||||
    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size,  device, keep_video_guide_frames= [], start_frame = 0,  fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
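        # Normalize each control video / mask pair to [-1, 1], prepend any pre-roll frames, pad or
        # trim to total_frames, blank out frames that are not kept, inject reference frames at the
        # requested positions and letterbox the reference images (building a background mask when a
        # background reference is used).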
        image_sizes = []
 | 
			
		||||
        trim_video_guide = len(keep_video_guide_frames)
 | 
			
		||||
        def conv_tensor(t, device):
 | 
			
		||||
            return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device)
 | 
			
		||||
 | 
			
		||||
        for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
 | 
			
		||||
            prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
 | 
			
		||||
            num_frames = total_frames - prepend_count            
 | 
			
		||||
            num_frames = min(num_frames, trim_video_guide) if trim_video_guide > 0 and sub_src_video != None else num_frames
 | 
			
		||||
            if sub_src_mask is not None and sub_src_video is not None:
 | 
			
		||||
                src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
 | 
			
		||||
                src_mask[i] = conv_tensor(sub_src_mask[:num_frames], device)
 | 
			
		||||
                # src_video is [-1, 1] (at this function output), 0 = inpainting area (in fact 127  in [0, 255])
 | 
			
		||||
                # src_mask is [-1, 1] (at this function output), 0 = preserve original video (in fact 127  in [0, 255]) and 1 = Inpainting (in fact 255  in [0, 255])
 | 
			
		||||
                if prepend_count > 0:
 | 
			
		||||
                    src_video[i] =  torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
 | 
			
		||||
                    src_mask[i] =  torch.cat( [torch.full_like(sub_pre_src_video, -1.0), src_mask[i]] ,1)
 | 
			
		||||
                src_video_shape = src_video[i].shape
 | 
			
		||||
                if src_video_shape[1] != total_frames:
 | 
			
		||||
                    src_video[i] =  torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
 | 
			
		||||
                    src_mask[i] =  torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
 | 
			
		||||
                src_mask[i] = torch.clamp((src_mask[i][:, :, :, :] + 1) / 2, min=0, max=1)
 | 
			
		||||
                image_sizes.append(src_video[i].shape[2:])
 | 
			
		||||
            elif sub_src_video is None:
 | 
			
		||||
                if prepend_count > 0:
 | 
			
		||||
                    src_video[i] =  torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1)
 | 
			
		||||
                    src_mask[i] =  torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1)
 | 
			
		||||
                else:
 | 
			
		||||
                    src_video[i] = torch.zeros((3, total_frames, image_size[0], image_size[1]), device=device)
 | 
			
		||||
                    src_mask[i] = torch.ones_like(src_video[i], device=device)
 | 
			
		||||
                image_sizes.append(image_size)
 | 
			
		||||
            else:
 | 
			
		||||
                src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
 | 
			
		||||
                src_mask[i] = torch.ones_like(src_video[i], device=device)
 | 
			
		||||
                if prepend_count > 0:
 | 
			
		||||
                    src_video[i] =  torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
 | 
			
		||||
                    src_mask[i] =  torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
 | 
			
		||||
                src_video_shape = src_video[i].shape
 | 
			
		||||
                if src_video_shape[1] != total_frames:
 | 
			
		||||
                    src_video[i] =  torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
 | 
			
		||||
                    src_mask[i] =  torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
 | 
			
		||||
                image_sizes.append(src_video[i].shape[2:])
 | 
			
		||||
            for k, keep in enumerate(keep_video_guide_frames):
 | 
			
		||||
                if not keep:
 | 
			
		||||
                    pos = prepend_count + k
 | 
			
		||||
                    src_video[i][:, pos:pos+1] = 0
 | 
			
		||||
                    src_mask[i][:, pos:pos+1] = 1
 | 
			
		||||
 | 
			
		||||
            for k, frame in enumerate(inject_frames):
 | 
			
		||||
                if frame != None:
 | 
			
		||||
                    pos = prepend_count + k
 | 
			
		||||
                    src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        self.background_mask = None
 | 
			
		||||
        for i, ref_images in enumerate(src_ref_images):
 | 
			
		||||
            if ref_images is not None:
 | 
			
		||||
                image_size = image_sizes[i]
 | 
			
		||||
                for j, ref_img in enumerate(ref_images):
 | 
			
		||||
                    if ref_img is not None and not torch.is_tensor(ref_img):
 | 
			
		||||
                        if j==0 and any_background_ref:
 | 
			
		||||
                            if self.background_mask == None: self.background_mask = [None] * len(src_ref_images) 
 | 
			
		||||
                            src_ref_images[i][j], self.background_mask[i] = self.fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True)
 | 
			
		||||
                        else:
 | 
			
		||||
                            src_ref_images[i][j], _ = self.fit_image_into_canvas(ref_img, image_size, 1, device)
 | 
			
		||||
        if self.background_mask != None:
            self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # duplicate the background mask with a double control net, since the first controlnet image ref was modified by the ref
 | 
			
		||||
        return src_video, src_mask, src_ref_images
 | 
			
		||||
 | 
			
		||||
    def get_vae_latents(self, ref_images, device, tile_size= 0):
 | 
			
		||||
        ref_vae_latents = []
 | 
			
		||||
@ -380,12 +274,28 @@ class WanAny2V:
 | 
			
		||||
                    
 | 
			
		||||
        return torch.cat(ref_vae_latents, dim=1)
 | 
			
		||||
 | 
			
		||||
    def get_i2v_mask(self, lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=None, lat_t =0,  device="cuda"):
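        # Build the latent-space keep/replace mask: an optional pixel mask is resized to (lat_h, lat_w),
        # the first nb_frames_unchanged frames are forced to 1 (keep), frame 0 is repeated 4x to match
        # the VAE temporal stride, and the result is reshaped to (4, lat_t, lat_h, lat_w).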
        if mask_pixel_values is None:
 | 
			
		||||
            msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
 | 
			
		||||
        else:
 | 
			
		||||
            msk = F.interpolate(mask_pixel_values.to(device), size=(lat_h, lat_w), mode='nearest')
 | 
			
		||||
 | 
			
		||||
        if nb_frames_unchanged >0:
 | 
			
		||||
            msk[:, :nb_frames_unchanged] = 1
 | 
			
		||||
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
 | 
			
		||||
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
 | 
			
		||||
        msk = msk.transpose(1,2)[0]
 | 
			
		||||
        return msk
 | 
			
		||||
    
 | 
			
		||||
    def generate(self,
 | 
			
		||||
        input_prompt,
 | 
			
		||||
        input_frames= None,
 | 
			
		||||
        input_frames2= None,
 | 
			
		||||
        input_masks = None,
 | 
			
		||||
        input_ref_images = None,      
 | 
			
		||||
        input_masks2 = None,
 | 
			
		||||
        input_ref_images = None,
 | 
			
		||||
        input_ref_masks = None,
 | 
			
		||||
        input_faces = None,
 | 
			
		||||
        input_video = None,
 | 
			
		||||
        image_start = None,
 | 
			
		||||
        image_end = None,
 | 
			
		||||
@ -453,7 +363,8 @@ class WanAny2V:
 | 
			
		||||
            timesteps.append(0.)
 | 
			
		||||
            timesteps = [torch.tensor([t], device=self.device) for t in timesteps]
 | 
			
		||||
            if self.use_timestep_transform:
 | 
			
		||||
                timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps][:-1]    
 | 
			
		||||
                timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps][:-1]
 | 
			
		||||
            timesteps = torch.tensor(timesteps)
 | 
			
		||||
            sample_scheduler = None                  
 | 
			
		||||
        elif sample_solver == 'causvid':
 | 
			
		||||
            sample_scheduler = FlowMatchScheduler(num_inference_steps=sampling_steps, shift=shift, sigma_min=0, extra_one_step=True)
 | 
			
		||||
@ -496,6 +407,8 @@ class WanAny2V:
 | 
			
		||||
        text_len = self.model.text_len
 | 
			
		||||
        context = torch.cat([context, context.new_zeros(text_len -context.size(0), context.size(1)) ]).unsqueeze(0) 
 | 
			
		||||
        context_null = torch.cat([context_null, context_null.new_zeros(text_len -context_null.size(0), context_null.size(1)) ]).unsqueeze(0) 
 | 
			
		||||
        if input_video is not None: height, width = input_video.shape[-2:]
 | 
			
		||||
 | 
			
		||||
        # NAG_prompt =  "static, low resolution, blurry"
 | 
			
		||||
        # context_NAG = self.text_encoder([NAG_prompt], self.device)[0]
 | 
			
		||||
        # context_NAG = context_NAG.to(self.dtype)
 | 
			
		||||
@ -516,109 +429,76 @@ class WanAny2V:
 | 
			
		||||
        infinitetalk = model_type in ["infinitetalk"]
 | 
			
		||||
        standin = model_type in ["standin", "vace_standin_14B"]
 | 
			
		||||
        recam = model_type in ["recam_1.3B"]
 | 
			
		||||
        ti2v = model_type in ["ti2v_2_2"]
 | 
			
		||||
        ti2v = model_type in ["ti2v_2_2", "lucy_edit"]
 | 
			
		||||
        lucy_edit=  model_type in ["lucy_edit"]
 | 
			
		||||
        animate=  model_type in ["animate"]
 | 
			
		||||
        start_step_no = 0
 | 
			
		||||
        ref_images_count = 0
 | 
			
		||||
        trim_frames = 0
 | 
			
		||||
        extended_overlapped_latents = None
 | 
			
		||||
        extended_overlapped_latents = clip_image_start = clip_image_end = image_mask_latents = None
 | 
			
		||||
        no_noise_latents_injection = infinitetalk
 | 
			
		||||
        timestep_injection = False
 | 
			
		||||
        lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1
 | 
			
		||||
        extended_input_dim = 0
 | 
			
		||||
        ref_images_before = False
 | 
			
		||||
        # image2video 
 | 
			
		||||
        if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "flf2v_720p"]:
 | 
			
		||||
            any_end_frame = False
 | 
			
		||||
            if image_start is None:
 | 
			
		||||
                if infinitetalk:
 | 
			
		||||
                    if input_frames is not None:
 | 
			
		||||
                        image_ref = input_frames[:, -1]
 | 
			
		||||
                        if input_video is None: input_video = input_frames[:, -1:] 
 | 
			
		||||
                        new_shot = "Q" in video_prompt_type
 | 
			
		||||
                    else:
 | 
			
		||||
                        if pre_video_frame is None:
 | 
			
		||||
                            new_shot = True
 | 
			
		||||
                        else:
 | 
			
		||||
                            if input_ref_images is None:
 | 
			
		||||
                                input_ref_images, new_shot = [pre_video_frame], False
 | 
			
		||||
                            else:
 | 
			
		||||
                                input_ref_images, new_shot = [img.resize(pre_video_frame.size, resample=Image.Resampling.LANCZOS) for img in input_ref_images], "Q" in video_prompt_type
 | 
			
		||||
                        if input_ref_images is None: raise Exception("Missing Reference Image")
 | 
			
		||||
                        new_shot = new_shot and window_no <= len(input_ref_images)
 | 
			
		||||
                        image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ])
 | 
			
		||||
                    if new_shot:  
 | 
			
		||||
                        input_video = image_ref.unsqueeze(1)
 | 
			
		||||
                    else:
 | 
			
		||||
                        color_correction_strength = 0 # disable color correction as transition frames between shots may have a completely different color level than the colors of the new shot
 | 
			
		||||
                _ , preframes_count, height, width = input_video.shape
 | 
			
		||||
                input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype)
 | 
			
		||||
                if infinitetalk:
 | 
			
		||||
                    image_for_clip = image_ref.to(input_video)
 | 
			
		||||
                    control_pre_frames_count = 1 
 | 
			
		||||
                    control_video = image_for_clip.unsqueeze(1)
 | 
			
		||||
            if infinitetalk:
 | 
			
		||||
                new_shot = "Q" in video_prompt_type
 | 
			
		||||
                if input_frames is not None:
 | 
			
		||||
                    image_ref = input_frames[:, 0]
 | 
			
		||||
                else:
 | 
			
		||||
                    image_for_clip = input_video[:, -1]
 | 
			
		||||
                    control_pre_frames_count = preframes_count
 | 
			
		||||
                    control_video = input_video
 | 
			
		||||
                lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
 | 
			
		||||
                if hasattr(self, "clip"):
 | 
			
		||||
                    clip_image_size = self.clip.model.image_size
 | 
			
		||||
                    clip_image = resize_lanczos(image_for_clip, clip_image_size, clip_image_size)[:, None, :, :]
 | 
			
		||||
                    clip_context = self.clip.visual([clip_image]) if model_type != "flf2v_720p" else self.clip.visual([clip_image , clip_image ])
 | 
			
		||||
                    clip_image = None
 | 
			
		||||
                    if input_ref_images is None:                        
 | 
			
		||||
                        if pre_video_frame is None: raise Exception("Missing Reference Image")
 | 
			
		||||
                        input_ref_images, new_shot = [pre_video_frame], False
 | 
			
		||||
                    new_shot = new_shot and window_no <= len(input_ref_images)
 | 
			
		||||
                    image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ])
 | 
			
		||||
                if new_shot or input_video is None:  
 | 
			
		||||
                    input_video = image_ref.unsqueeze(1)
 | 
			
		||||
                else:
 | 
			
		||||
                    clip_context = None
 | 
			
		||||
                enc =  torch.concat( [control_video, torch.zeros( (3, frame_num-control_pre_frames_count, height, width), 
 | 
			
		||||
                                    device=self.device, dtype= self.VAE_dtype)], 
 | 
			
		||||
                                    dim = 1).to(self.device)
 | 
			
		||||
                color_reference_frame = image_for_clip.unsqueeze(1).clone()
 | 
			
		||||
                    color_correction_strength = 0 # disable color correction as transition frames between shots may have a completely different color level than the colors of the new shot
 | 
			
		||||
            _ , preframes_count, height, width = input_video.shape
 | 
			
		||||
            input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype)
 | 
			
		||||
            if infinitetalk:
 | 
			
		||||
                image_start = image_ref.to(input_video)
 | 
			
		||||
                control_pre_frames_count = 1 
 | 
			
		||||
                control_video = image_start.unsqueeze(1)
 | 
			
		||||
            else:
 | 
			
		||||
                preframes_count = control_pre_frames_count = 1
 | 
			
		||||
                any_end_frame = image_end is not None 
 | 
			
		||||
                add_frames_for_end_image = any_end_frame and model_type == "i2v"
 | 
			
		||||
                if any_end_frame:
 | 
			
		||||
                    if add_frames_for_end_image:
 | 
			
		||||
                        frame_num +=1
 | 
			
		||||
                        lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
 | 
			
		||||
                        trim_frames = 1
 | 
			
		||||
                
 | 
			
		||||
                height, width = image_start.shape[1:]
 | 
			
		||||
                image_start = input_video[:, -1]
 | 
			
		||||
                control_pre_frames_count = preframes_count
 | 
			
		||||
                control_video = input_video
 | 
			
		||||
 | 
			
		||||
                lat_h = round(
 | 
			
		||||
                    height // self.vae_stride[1] //
 | 
			
		||||
                    self.patch_size[1] * self.patch_size[1])
 | 
			
		||||
                lat_w = round(
 | 
			
		||||
                    width // self.vae_stride[2] //
 | 
			
		||||
                    self.patch_size[2] * self.patch_size[2])
 | 
			
		||||
                height = lat_h * self.vae_stride[1]
 | 
			
		||||
                width = lat_w * self.vae_stride[2]
 | 
			
		||||
                image_start_frame = image_start.unsqueeze(1).to(self.device)
 | 
			
		||||
                color_reference_frame = image_start_frame.clone()
 | 
			
		||||
                if image_end is not None:
 | 
			
		||||
                    img_end_frame = image_end.unsqueeze(1).to(self.device)
 | 
			
		||||
            color_reference_frame = image_start.unsqueeze(1).clone()
 | 
			
		||||
 | 
			
		||||
                if hasattr(self, "clip"):                                   
 | 
			
		||||
                    clip_image_size = self.clip.model.image_size
 | 
			
		||||
                    image_start = resize_lanczos(image_start, clip_image_size, clip_image_size)
 | 
			
		||||
                    if image_end is not None: image_end = resize_lanczos(image_end, clip_image_size, clip_image_size)
 | 
			
		||||
                    if model_type == "flf2v_720p":                    
 | 
			
		||||
                        clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end is not None else image_start[:, None, :, :]])
 | 
			
		||||
                    else:
 | 
			
		||||
                        clip_context = self.clip.visual([image_start[:, None, :, :]])
 | 
			
		||||
                else:
 | 
			
		||||
                    clip_context = None
 | 
			
		||||
            any_end_frame = image_end is not None 
 | 
			
		||||
            add_frames_for_end_image = any_end_frame and model_type == "i2v"
 | 
			
		||||
            if any_end_frame:
 | 
			
		||||
                color_correction_strength = 0 # disable color correction as transition frames between shots may have a completely different color level than the colors of the new shot
 | 
			
		||||
                if add_frames_for_end_image:
 | 
			
		||||
                    frame_num +=1
 | 
			
		||||
                    lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
 | 
			
		||||
                    trim_frames = 1
 | 
			
		||||
 | 
			
		||||
                if any_end_frame:
 | 
			
		||||
                    enc= torch.concat([
 | 
			
		||||
                            image_start_frame,
 | 
			
		||||
                            torch.zeros( (3, frame_num-2,  height, width), device=self.device, dtype= self.VAE_dtype),
 | 
			
		||||
                            img_end_frame,
 | 
			
		||||
                    ], dim=1).to(self.device)
 | 
			
		||||
                else:
 | 
			
		||||
                    enc= torch.concat([
 | 
			
		||||
                            image_start_frame,
 | 
			
		||||
                            torch.zeros( (3, frame_num-1, height, width), device=self.device, dtype= self.VAE_dtype)
 | 
			
		||||
                    ], dim=1).to(self.device)
 | 
			
		||||
            lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
 | 
			
		||||
 | 
			
		||||
                image_start = image_end = image_start_frame = img_end_frame = image_for_clip = image_ref = None
 | 
			
		||||
            if image_end is not None:
 | 
			
		||||
                img_end_frame = image_end.unsqueeze(1).to(self.device)
 | 
			
		||||
            clip_image_start, clip_image_end = image_start, image_end
 | 
			
		||||
 | 
			
		||||
            if any_end_frame:
 | 
			
		||||
                enc= torch.concat([
 | 
			
		||||
                        control_video,
 | 
			
		||||
                        torch.zeros( (3, frame_num-control_pre_frames_count-1,  height, width), device=self.device, dtype= self.VAE_dtype),
 | 
			
		||||
                        img_end_frame,
 | 
			
		||||
                ], dim=1).to(self.device)
 | 
			
		||||
            else:
 | 
			
		||||
                enc= torch.concat([
 | 
			
		||||
                        control_video,
 | 
			
		||||
                        torch.zeros( (3, frame_num-control_pre_frames_count, height, width), device=self.device, dtype= self.VAE_dtype)
 | 
			
		||||
                ], dim=1).to(self.device)
 | 
			
		||||
 | 
			
		||||
            image_start = image_end = img_end_frame = image_ref = control_video = None
 | 
			
		||||
 | 
			
		||||
            msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device)
 | 
			
		||||
            if any_end_frame:
 | 
			
		||||
@ -643,42 +523,85 @@ class WanAny2V:
 | 
			
		||||
                if infinitetalk:
 | 
			
		||||
                    lat_y = self.vae.encode([input_video], VAE_tile_size)[0]
 | 
			
		||||
                extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone().unsqueeze(0)
 | 
			
		||||
            # if control_pre_frames_count != pre_frames_count:
 | 
			
		||||
 | 
			
		||||
            lat_y = input_video = None
 | 
			
		||||
            kwargs.update({ 'y': y})
 | 
			
		||||
            if not clip_context is None:
 | 
			
		||||
                kwargs.update({'clip_fea': clip_context})
 | 
			
		||||
 | 
			
		||||
        # Recam Master
 | 
			
		||||
        # Animate
 | 
			
		||||
        if animate:
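            # Wan 2.2 Animate: encode the masked-in pose video into pose_latents, blank the person region
            # out of the background frames (painted black unless "X" is set), then pack the reference image,
            # the masked background and the i2v-style keep/replace mask into `y`; the face frames
            # (input_faces) are passed separately as face_pixel_values.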
            pose_pixels = input_frames * input_masks
 | 
			
		||||
            input_masks = 1. - input_masks
 | 
			
		||||
            pose_pixels -= input_masks
 | 
			
		||||
            pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0)
 | 
			
		||||
            input_frames = input_frames * input_masks
 | 
			
		||||
            if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames
 | 
			
		||||
            if prefix_frames_count > 0:
 | 
			
		||||
                 input_frames[:, :prefix_frames_count] = input_video 
 | 
			
		||||
                 input_masks[:, :prefix_frames_count] = 1 
 | 
			
		||||
            # save_video(pose_pixels, "pose.mp4")
 | 
			
		||||
            # save_video(input_frames, "input_frames.mp4")
 | 
			
		||||
            # save_video(input_masks, "input_masks.mp4", value_range=(0,1))
 | 
			
		||||
            lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
 | 
			
		||||
            msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device) 
 | 
			
		||||
            msk_control =  self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device)
 | 
			
		||||
            msk = torch.concat([msk_ref, msk_control], dim=1)
 | 
			
		||||
            image_ref = input_ref_images[0].to(self.device)
 | 
			
		||||
            clip_image_start = image_ref.squeeze(1)
 | 
			
		||||
            lat_y = torch.concat(self.vae.encode([image_ref, input_frames.to(self.device)], VAE_tile_size), dim=1)
 | 
			
		||||
            y = torch.concat([msk, lat_y])
 | 
			
		||||
            kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)})
 | 
			
		||||
            lat_y = msk = msk_control = msk_ref = pose_pixels = None
 | 
			
		||||
            ref_images_before = True
 | 
			
		||||
            ref_images_count = 1
 | 
			
		||||
            lat_frames = int((input_frames.shape[1] - 1) // self.vae_stride[0]) + 1
 | 
			
		||||
                        
 | 
			
		||||
        # Clip image
 | 
			
		||||
        if hasattr(self, "clip") and clip_image_start is not None:                                   
 | 
			
		||||
            clip_image_size = self.clip.model.image_size
 | 
			
		||||
            clip_image_start = resize_lanczos(clip_image_start, clip_image_size, clip_image_size)
 | 
			
		||||
            clip_image_end = resize_lanczos(clip_image_end, clip_image_size, clip_image_size) if clip_image_end is not None else clip_image_start
 | 
			
		||||
            if model_type == "flf2v_720p":                    
 | 
			
		||||
                clip_context = self.clip.visual([clip_image_start[:, None, :, :], clip_image_end[:, None, :, :] if clip_image_end is not None else clip_image_start[:, None, :, :]])
 | 
			
		||||
            else:
 | 
			
		||||
                clip_context = self.clip.visual([clip_image_start[:, None, :, :]])
 | 
			
		||||
            clip_image_start = clip_image_end = None
 | 
			
		||||
            kwargs.update({'clip_fea': clip_context})
 | 
			
		||||
 | 
			
		||||
        # Recam Master & Lucy Edit
 | 
			
		||||
        if recam or lucy_edit:
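            # Encode the control video into extended_latents; at every denoising step it is concatenated
            # to the noisy latents along dim 2 (extra frames, ReCam) or dim 1 (extra channels, Lucy Edit),
            # as selected by extended_input_dim.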
            frame_num, height,width = input_frames.shape[-3:]
 | 
			
		||||
            lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1
 | 
			
		||||
            frame_num = (lat_frames -1) * self.vae_stride[0] + 1
 | 
			
		||||
            input_frames = input_frames[:, :frame_num].to(dtype=self.dtype , device=self.device)
 | 
			
		||||
            extended_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device)
 | 
			
		||||
            extended_input_dim = 2 if recam else 1
 | 
			
		||||
            del input_frames
 | 
			
		||||
 | 
			
		||||
        if recam:
 | 
			
		||||
            # this should in fact be in input_frames, since it is a control video, not a video to be extended
 | 
			
		||||
            target_camera = model_mode
 | 
			
		||||
            height,width = input_video.shape[-2:]
 | 
			
		||||
            input_video = input_video.to(dtype=self.dtype , device=self.device)
 | 
			
		||||
            source_latents = self.vae.encode([input_video])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device)
 | 
			
		||||
            del input_video
 | 
			
		||||
            # Process target camera (recammaster)
 | 
			
		||||
            target_camera = model_mode
 | 
			
		||||
            from shared.utils.cammmaster_tools import get_camera_embedding
 | 
			
		||||
            cam_emb = get_camera_embedding(target_camera)       
 | 
			
		||||
            cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
 | 
			
		||||
            kwargs['cam_emb'] = cam_emb
 | 
			
		||||
 | 
			
		||||
        # Video 2 Video
 | 
			
		||||
        if denoising_strength < 1. and input_frames != None:
 | 
			
		||||
        if "G" in video_prompt_type and input_frames != None:
            height, width = input_frames.shape[-2:]
 | 
			
		||||
            source_latents = self.vae.encode([input_frames])[0].unsqueeze(0)
 | 
			
		||||
            injection_denoising_step = 0
 | 
			
		||||
            inject_from_start = False
 | 
			
		||||
            if input_frames != None and denoising_strength < 1 :
 | 
			
		||||
                color_reference_frame = input_frames[:, -1:].clone()
 | 
			
		||||
                if overlapped_latents != None:
 | 
			
		||||
                    overlapped_latents_frames_num = overlapped_latents.shape[2]
 | 
			
		||||
                    overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1
 | 
			
		||||
                if prefix_frames_count > 0:
 | 
			
		||||
                    overlapped_frames_num = prefix_frames_count
 | 
			
		||||
                    overlapped_latents_frames_num = (overlapped_frames_num - 1) // 4 + 1  # pixel-frame overlap converted to latent frames (temporal stride 4)
 | 
			
		||||
                    # overlapped_latents_frames_num = overlapped_latents.shape[2]
 | 
			
		||||
                    # overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1
 | 
			
		||||
                else: 
 | 
			
		||||
                    overlapped_latents_frames_num = overlapped_frames_num  = 0
 | 
			
		||||
                if len(keep_frames_parsed) == 0 or image_outputs or ((overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed)): keep_frames_parsed = []
 | 
			
		||||
                injection_denoising_step = int(sampling_steps * (1. - denoising_strength) )
 | 
			
		||||
                injection_denoising_step = int( round(sampling_steps * (1. - denoising_strength),4) )
 | 
			
		||||
                latent_keep_frames = []
 | 
			
		||||
                if source_latents.shape[2] < lat_frames or len(keep_frames_parsed) > 0:
 | 
			
		||||
                    inject_from_start = True
 | 
			
		||||
@ -694,14 +617,21 @@ class WanAny2V:
 | 
			
		||||
                    if hasattr(sample_scheduler, "sigmas"): sample_scheduler.sigmas= sample_scheduler.sigmas[injection_denoising_step:]
 | 
			
		||||
                    injection_denoising_step = 0
 | 
			
		||||
 | 
			
		||||
            if input_masks is not None and not "U" in video_prompt_type:
                image_mask_latents = torch.nn.functional.interpolate(input_masks, size= source_latents.shape[-2:], mode="nearest").unsqueeze(0)
 | 
			
		||||
                if image_mask_latents.shape[2] !=1:
 | 
			
		||||
                    image_mask_latents = torch.cat([ image_mask_latents[:,:, :1], torch.nn.functional.interpolate(image_mask_latents, size= (source_latents.shape[-3]-1, *source_latents.shape[-2:]), mode="nearest") ], dim=2)
 | 
			
		||||
                image_mask_latents = torch.where(image_mask_latents>=0.5, 1., 0. )[:1].to(self.device)
 | 
			
		||||
                # save_video(image_mask_latents.squeeze(0), "mama.mp4", value_range=(0,1) )
 | 
			
		||||
                # image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0)
 | 
			
		||||
 | 
			
		||||
        # Phantom
 | 
			
		||||
        if phantom:
 | 
			
		||||
            input_ref_images_neg = None
 | 
			
		||||
            if input_ref_images != None: # Phantom Ref images
 | 
			
		||||
                input_ref_images = self.get_vae_latents(input_ref_images, self.device)
 | 
			
		||||
                input_ref_images_neg = torch.zeros_like(input_ref_images)
 | 
			
		||||
                ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0
 | 
			
		||||
                trim_frames = input_ref_images.shape[1]
 | 
			
		||||
            lat_input_ref_images_neg = None
 | 
			
		||||
            if input_ref_images is not None: # Phantom Ref images
 | 
			
		||||
                lat_input_ref_images = self.get_vae_latents(input_ref_images, self.device)
 | 
			
		||||
                lat_input_ref_images_neg = torch.zeros_like(lat_input_ref_images)
 | 
			
		||||
                ref_images_count = trim_frames = lat_input_ref_images.shape[1]
 | 
			
		||||
 | 
			
		||||
        if ti2v:
 | 
			
		||||
            if input_video is None:
 | 
			
		||||
@ -710,28 +640,29 @@ class WanAny2V:
 | 
			
		||||
                height, width = input_video.shape[-2:]
 | 
			
		||||
                source_latents = self.vae.encode([input_video], tile_size = VAE_tile_size)[0].unsqueeze(0)
 | 
			
		||||
                timestep_injection = True
 | 
			
		||||
                if extended_input_dim > 0:
 | 
			
		||||
                    extended_latents[:, :, :source_latents.shape[2]] = source_latents
 | 
			
		||||
 | 
			
		||||
        # Vace
 | 
			
		||||
        if vace :
 | 
			
		||||
            # vace context encode
 | 
			
		||||
            input_frames = [u.to(self.device) for u in input_frames]
 | 
			
		||||
            input_ref_images = [ None if u == None else [v.to(self.device) for v in u]  for u in input_ref_images]
 | 
			
		||||
            input_masks = [u.to(self.device) for u in input_masks]
 | 
			
		||||
            if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask]
 | 
			
		||||
            input_frames = [input_frames.to(self.device)] +([] if input_frames2 is None else [input_frames2.to(self.device)])            
 | 
			
		||||
            input_masks = [input_masks.to(self.device)] + ([] if input_masks2 is None else [input_masks2.to(self.device)])         
 | 
			
		||||
            input_ref_images = None if input_ref_images is None else [ u.to(self.device) for u in input_ref_images]
 | 
			
		||||
            input_ref_masks = None if input_ref_masks is None else [ None if u is None else u.to(self.device) for u in input_ref_masks]
 | 
			
		||||
            ref_images_before = True
 | 
			
		||||
            z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
 | 
			
		||||
            m0 = self.vace_encode_masks(input_masks, input_ref_images)
 | 
			
		||||
            if self.background_mask != None:
 | 
			
		||||
                color_reference_frame = input_ref_images[0][0].clone()
 | 
			
		||||
                zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size )
 | 
			
		||||
                mbg = self.vace_encode_masks(self.background_mask, None)
 | 
			
		||||
            if input_ref_masks is not None and len(input_ref_masks) > 0 and input_ref_masks[0] is not None:
 | 
			
		||||
                color_reference_frame = input_ref_images[0].clone()
 | 
			
		||||
                zbg = self.vace_encode_frames( input_ref_images[:1] * len(input_frames), None, masks=input_ref_masks[0], tile_size = VAE_tile_size )
 | 
			
		||||
                mbg = self.vace_encode_masks(input_ref_masks[:1] * len(input_frames), None)
 | 
			
		||||
                for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg):
 | 
			
		||||
                    zz0[:, 0:1] = zzbg
 | 
			
		||||
                    mm0[:, 0:1] = mmbg
 | 
			
		||||
 | 
			
		||||
                self.background_mask = zz0 = mm0 = zzbg = mmbg = None
 | 
			
		||||
            z = self.vace_latent(z0, m0)
 | 
			
		||||
 | 
			
		||||
            ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0
 | 
			
		||||
                zz0 = mm0 = zzbg = mmbg = None
 | 
			
		||||
            z = [torch.cat([zz, mm], dim=0) for zz, mm in zip(z0, m0)]
 | 
			
		||||
            ref_images_count = len(input_ref_images) if input_ref_images is not None else 0
 | 
			
		||||
            context_scale = context_scale if context_scale != None else [1.0] * len(z)
 | 
			
		||||
            kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count })
 | 
			
		||||
            if overlapped_latents != None :
 | 
			
		||||
@ -739,17 +670,12 @@ class WanAny2V:
 | 
			
		||||
                extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0)
 | 
			
		||||
            if prefix_frames_count > 0:
 | 
			
		||||
                color_reference_frame = input_frames[0][:, prefix_frames_count -1:prefix_frames_count].clone()
 | 
			
		||||
        lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
 | 
			
		||||
        target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, lat_h, lat_w)
 | 
			
		||||
 | 
			
		||||
            target_shape = list(z0[0].shape)
 | 
			
		||||
            target_shape[0] = int(target_shape[0] / 2)
 | 
			
		||||
            lat_h, lat_w = target_shape[-2:] 
 | 
			
		||||
            height = self.vae_stride[1] * lat_h
 | 
			
		||||
            width = self.vae_stride[2] * lat_w
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2])
 | 
			
		||||
 | 
			
		||||
        if multitalk and audio_proj != None:
 | 
			
		||||
        if multitalk:
 | 
			
		||||
            if audio_proj is None:
 | 
			
		||||
                audio_proj = [ torch.zeros( (1, 1, 5, 12, 768 ), dtype=self.dtype, device=self.device), torch.zeros( (1, (frame_num - 1) // 4, 8, 12, 768 ), dtype=self.dtype, device=self.device) ] 
 | 
			
		||||
            from .multitalk.multitalk import get_target_masks
 | 
			
		||||
            audio_proj = [audio.to(self.dtype) for audio in audio_proj]
 | 
			
		||||
            human_no = len(audio_proj[0])
 | 
			
		||||
@ -764,9 +690,9 @@ class WanAny2V:
 | 
			
		||||
 | 
			
		||||
        expand_shape = [batch_size] + [-1] * len(target_shape)
 | 
			
		||||
        # Ropes
 | 
			
		||||
        if target_camera != None:
 | 
			
		||||
        if extended_input_dim>=2:
 | 
			
		||||
            shape = list(target_shape[1:])
 | 
			
		||||
            shape[0] *= 2
 | 
			
		||||
            shape[extended_input_dim-2] *= 2
 | 
			
		||||
            freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False) 
 | 
			
		||||
        else:
 | 
			
		||||
            freqs = get_rotary_pos_embed(target_shape[1:], enable_RIFLEx= enable_RIFLEx) 
 | 
			
		||||
@ -777,19 +703,20 @@ class WanAny2V:
 | 
			
		||||
        if standin:
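            # StandIn identity transfer: crop the face from the chosen reference image, VAE-encode it as
            # standin_ref and build a rotary-embedding grid (standin_freqs) large enough to cover both the
            # video tokens and the appended reference tokens.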
            from preprocessing.face_preprocessor  import FaceProcessor 
 | 
			
		||||
            standin_ref_pos = 1 if "K" in video_prompt_type else 0
 | 
			
		||||
            if len(original_input_ref_images) < standin_ref_pos + 1: raise Exception("Missing Standin ref image") 
 | 
			
		||||
            standin_ref_pos = -1
 | 
			
		||||
            image_ref = original_input_ref_images[standin_ref_pos]
 | 
			
		||||
            image_ref.save("si.png")
 | 
			
		||||
            # face_processor = FaceProcessor(antelopv2_path="ckpts/antelopev2")
 | 
			
		||||
            face_processor = FaceProcessor()
 | 
			
		||||
            standin_ref = face_processor.process(image_ref, remove_bg = model_type in ["vace_standin_14B"])
 | 
			
		||||
            face_processor = None
 | 
			
		||||
            gc.collect()
 | 
			
		||||
            torch.cuda.empty_cache()
 | 
			
		||||
            standin_freqs = get_nd_rotary_pos_embed((-1, int(target_shape[-2]/2), int(target_shape[-1]/2) ), (-1, int(target_shape[-2]/2 + standin_ref.height/16), int(target_shape[-1]/2 + standin_ref.width/16) )) 
 | 
			
		||||
            standin_ref = self.vae.encode([ convert_image_to_tensor(standin_ref).unsqueeze(1) ], VAE_tile_size)[0].unsqueeze(0)
 | 
			
		||||
            kwargs.update({ "standin_freqs": standin_freqs, "standin_ref": standin_ref, }) 
 | 
			
		||||
            if len(original_input_ref_images) < standin_ref_pos + 1: 
 | 
			
		||||
                if "I" in video_prompt_type and model_type in ["vace_standin_14B"]:
 | 
			
		||||
                    print("Warning: Missing Standin ref image, make sure 'Inject only People / Objets' is selected or if there is 'Landscape and then People or Objects' there are at least two ref images.")
 | 
			
		||||
            else: 
 | 
			
		||||
                standin_ref_pos = -1
 | 
			
		||||
                image_ref = original_input_ref_images[standin_ref_pos]
 | 
			
		||||
                face_processor = FaceProcessor()
 | 
			
		||||
                standin_ref = face_processor.process(image_ref, remove_bg = model_type in ["vace_standin_14B"])
 | 
			
		||||
                face_processor = None
 | 
			
		||||
                gc.collect()
 | 
			
		||||
                torch.cuda.empty_cache()
 | 
			
		||||
                standin_freqs = get_nd_rotary_pos_embed((-1, int(target_shape[-2]/2), int(target_shape[-1]/2) ), (-1, int(target_shape[-2]/2 + standin_ref.height/16), int(target_shape[-1]/2 + standin_ref.width/16) )) 
 | 
			
		||||
                standin_ref = self.vae.encode([ convert_image_to_tensor(standin_ref).unsqueeze(1) ], VAE_tile_size)[0].unsqueeze(0)
 | 
			
		||||
                kwargs.update({ "standin_freqs": standin_freqs, "standin_ref": standin_ref, }) 
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        # Steps Skipping
 | 
			
		||||
@ -819,7 +746,7 @@ class WanAny2V:
 | 
			
		||||
        denoising_extra = ""
 | 
			
		||||
        from shared.utils.loras_mutipliers import update_loras_slists, get_model_switch_steps
 | 
			
		||||
 | 
			
		||||
        phase_switch_step, phase_switch_step2, phases_description = get_model_switch_steps(timesteps, updated_num_steps, guide_phases, 0 if self.model2 is None else model_switch_phase, switch_threshold, switch2_threshold )
 | 
			
		||||
        phase_switch_step, phase_switch_step2, phases_description = get_model_switch_steps(original_timesteps,guide_phases, 0 if self.model2 is None else model_switch_phase, switch_threshold, switch2_threshold )
 | 
			
		||||
        if len(phases_description) > 0:  set_header_text(phases_description)
 | 
			
		||||
        guidance_switch_done =  guidance_switch2_done = False
 | 
			
		||||
        if guide_phases > 1: denoising_extra = f"Phase 1/{guide_phases} High Noise" if self.model2 is not None else f"Phase 1/{guide_phases}"
 | 
			
		||||
@ -830,8 +757,8 @@ class WanAny2V:
 | 
			
		||||
                denoising_extra = f"Phase {phase_no}/{guide_phases} {'Low Noise' if trans == self.model2 else 'High Noise'}" if self.model2 is not None else f"Phase {phase_no}/{guide_phases}"
 | 
			
		||||
                callback(step_no-1, denoising_extra = denoising_extra)
 | 
			
		||||
            return guide_scale, guidance_switch_done, trans, denoising_extra
 | 
			
		||||
        update_loras_slists(self.model, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2)
 | 
			
		||||
        if self.model2 is not None: update_loras_slists(self.model2, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2)
 | 
			
		||||
        update_loras_slists(self.model, loras_slists, len(original_timesteps), phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2)
 | 
			
		||||
        if self.model2 is not None: update_loras_slists(self.model2, loras_slists, len(original_timesteps), phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2)
 | 
			
		||||
        callback(-1, None, True, override_num_inference_steps = updated_num_steps, denoising_extra = denoising_extra)
 | 
			
		||||
 | 
			
		||||
        def clear():
 | 
			
		||||
@ -844,19 +771,22 @@ class WanAny2V:
 | 
			
		||||
            scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g}
 | 
			
		||||
        # b, c, lat_f, lat_h, lat_w
 | 
			
		||||
        latents = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
 | 
			
		||||
        if "G" in video_prompt_type: randn = latents
 | 
			
		||||
        if apg_switch != 0:  
 | 
			
		||||
            apg_momentum = -0.75
 | 
			
		||||
            apg_norm_threshold = 55
 | 
			
		||||
            text_momentumbuffer  = MomentumBuffer(apg_momentum) 
 | 
			
		||||
            audio_momentumbuffer = MomentumBuffer(apg_momentum) 
 | 
			
		||||
 | 
			
		||||
        input_frames = input_frames2 = input_masks =input_masks2 = input_video = input_ref_images = input_ref_masks = pre_video_frame = None
 | 
			
		||||
        gc.collect()
 | 
			
		||||
        torch.cuda.empty_cache()
 | 
			
		||||
 | 
			
		||||
        # denoising
 | 
			
		||||
        trans = self.model
 | 
			
		||||
        for i, t in enumerate(tqdm(timesteps)):
 | 
			
		||||
            guide_scale, guidance_switch_done, trans, denoising_extra = update_guidance(i, t, guide_scale, guide2_scale, guidance_switch_done, switch_threshold, trans, 2, denoising_extra)
 | 
			
		||||
            guide_scale, guidance_switch2_done, trans, denoising_extra = update_guidance(i, t, guide_scale, guide3_scale, guidance_switch2_done, switch2_threshold, trans, 3, denoising_extra)
 | 
			
		||||
            offload.set_step_no_for_lora(trans, i)
 | 
			
		||||
            offload.set_step_no_for_lora(trans, start_step_no + i)
 | 
			
		||||
            timestep = torch.stack([t])
 | 
			
		||||
 | 
			
		||||
            if timestep_injection:
 | 
			
		||||
@ -864,44 +794,43 @@ class WanAny2V:
 | 
			
		||||
                timestep = torch.full((target_shape[-3],), t, dtype=torch.int64, device=latents.device)
 | 
			
		||||
                timestep[:source_latents.shape[2]] = 0
 | 
			
		||||
                        
 | 
			
		||||
            kwargs.update({"t": timestep, "current_step": start_step_no + i})  
 | 
			
		||||
            kwargs.update({"t": timestep, "current_step_no": i, "real_step_no": start_step_no + i })  
 | 
			
		||||
            kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
 | 
			
		||||
 | 
			
		||||
            if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step:
 | 
			
		||||
            if denoising_strength < 1 and i <= injection_denoising_step:
 | 
			
		||||
                sigma = t / 1000
 | 
			
		||||
                noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
 | 
			
		||||
                if inject_from_start:
 | 
			
		||||
                    new_latents = latents.clone()
 | 
			
		||||
                    new_latents[:,:, :source_latents.shape[2] ] = noise[:, :, :source_latents.shape[2] ] * sigma + (1 - sigma) * source_latents
 | 
			
		||||
                    noisy_image = latents.clone()
 | 
			
		||||
                    noisy_image[:,:, :source_latents.shape[2] ] = randn[:, :, :source_latents.shape[2] ] * sigma + (1 - sigma) * source_latents
 | 
			
		||||
                    for latent_no, keep_latent in enumerate(latent_keep_frames):
 | 
			
		||||
                        if not keep_latent:
 | 
			
		||||
                            new_latents[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1]
 | 
			
		||||
                    latents = new_latents
 | 
			
		||||
                    new_latents = None
 | 
			
		||||
                            noisy_image[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1]
 | 
			
		||||
                    latents = noisy_image
 | 
			
		||||
                    noisy_image = None
 | 
			
		||||
                else:
 | 
			
		||||
                    latents = noise * sigma + (1 - sigma) * source_latents
 | 
			
		||||
                noise = None
 | 
			
		||||
                    latents = randn * sigma + (1 - sigma) * source_latents
 | 
			
		||||
 | 
			
		||||
            if extended_overlapped_latents != None:
 | 
			
		||||
                if no_noise_latents_injection:
 | 
			
		||||
                    latents[:, :, :extended_overlapped_latents.shape[2]]   = extended_overlapped_latents 
 | 
			
		||||
                else:
 | 
			
		||||
                    latent_noise_factor = t / 1000
 | 
			
		||||
                    latents[:, :, :extended_overlapped_latents.shape[2]]   = extended_overlapped_latents  * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor 
 | 
			
		||||
                if vace:
 | 
			
		||||
                    overlap_noise_factor = overlap_noise / 1000 
 | 
			
		||||
                    for zz in z:
 | 
			
		||||
                        zz[0:16, ref_images_count:extended_overlapped_latents.shape[2] ]   = extended_overlapped_latents[0, :, ref_images_count:]  * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[0, :, ref_images_count:] ) * overlap_noise_factor 
 | 
			
		||||
 | 
			
		||||
            if target_camera != None:
 | 
			
		||||
                latent_model_input = torch.cat([latents, source_latents.expand(*expand_shape)], dim=2)
 | 
			
		||||
            if extended_input_dim > 0:
 | 
			
		||||
                latent_model_input = torch.cat([latents, extended_latents.expand(*expand_shape)], dim=extended_input_dim)
 | 
			
		||||
            else:
 | 
			
		||||
                latent_model_input = latents
 | 
			
		||||
 | 
			
		||||
            any_guidance = guide_scale != 1
 | 
			
		||||
            if phantom:
 | 
			
		||||
                gen_args = {
 | 
			
		||||
                    "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 + 
 | 
			
		||||
                        [ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]),
 | 
			
		||||
                    "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 + 
 | 
			
		||||
                        [ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]),
 | 
			
		||||
                    "context": [context, context_null, context_null] ,
 | 
			
		||||
                }
 | 
			
		||||
            elif fantasy:
 | 
			
		||||
@ -1010,8 +939,8 @@ class WanAny2V:
 | 
			
		||||
            
 | 
			
		||||
            if sample_solver == "euler":
                dt = timesteps[i] if i == len(timesteps)-1 else (timesteps[i] - timesteps[i + 1])
 | 
			
		||||
                dt = dt / self.num_timesteps
 | 
			
		||||
                latents = latents - noise_pred * dt[:, None, None, None, None]
 | 
			
		||||
                dt = dt.item() / self.num_timesteps
 | 
			
		||||
                latents = latents - noise_pred * dt
 | 
			
		||||
            else:
 | 
			
		||||
                latents = sample_scheduler.step(
 | 
			
		||||
                    noise_pred[:, :, :target_shape[1]],
 | 
			
		||||
@ -1019,9 +948,16 @@ class WanAny2V:
 | 
			
		||||
                    latents,
 | 
			
		||||
                    **scheduler_kwargs)[0]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            if image_mask_latents is not None:
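                # Re-noise the source latents to the next step's sigma and paste them back outside the
                # mask, so only the masked region keeps being denoised freely.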
                sigma = 0 if i == len(timesteps)-1 else timesteps[i+1]/1000
 | 
			
		||||
                noisy_image = randn * sigma + (1 - sigma) * source_latents
 | 
			
		||||
                latents = noisy_image * (1-image_mask_latents) + image_mask_latents * latents  
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            if callback is not None:
 | 
			
		||||
                latents_preview = latents
 | 
			
		||||
                if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ] 
 | 
			
		||||
                if ref_images_before and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ] 
 | 
			
		||||
                if trim_frames > 0:  latents_preview=  latents_preview[:, :,:-trim_frames]
 | 
			
		||||
                if image_outputs: latents_preview=  latents_preview[:, :,:1]
 | 
			
		||||
                if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2)
 | 
			
		||||
@ -1032,7 +968,7 @@ class WanAny2V:
 | 
			
		||||
        if timestep_injection:
 | 
			
		||||
            latents[:, :, :source_latents.shape[2]] = source_latents
 | 
			
		||||
 | 
			
		||||
        if vace and ref_images_count > 0: latents = latents[:, :, ref_images_count:]
 | 
			
		||||
        if ref_images_before and ref_images_count > 0: latents = latents[:, :, ref_images_count:]
 | 
			
		||||
        if trim_frames > 0:  latents=  latents[:, :,:-trim_frames]
 | 
			
		||||
        if return_latent_slice != None:
 | 
			
		||||
            latent_slice = latents[:, :, return_latent_slice].clone()
 | 
			
		||||
@ -1069,4 +1005,20 @@ class WanAny2V:
 | 
			
		||||
        delattr(model, "vace_blocks")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def adapt_animate_model(self, model):
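        # graft each of the 8 face-adapter fuser blocks onto every 5th transformer block (blocks 0, 5, 10, ...)
        # so they run inline, then drop the standalone face_adapter module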
 | 
			
		||||
        modules_dict= { k: m for k, m in model.named_modules()}
 | 
			
		||||
        for animate_layer in range(8):
 | 
			
		||||
            module = modules_dict[f"face_adapter.fuser_blocks.{animate_layer}"]
 | 
			
		||||
            model_layer = animate_layer * 5
 | 
			
		||||
            target = modules_dict[f"blocks.{model_layer}"]
 | 
			
		||||
            setattr(target, "face_adapter_fuser_blocks", module )
 | 
			
		||||
        delattr(model, "face_adapter")
 | 
			
		||||
 | 
			
		||||
    def get_loras_transformer(self, get_model_recursive_prop, base_model_type, model_type, video_prompt_type, model_mode, **kwargs):
 | 
			
		||||
        if base_model_type == "animate":
 | 
			
		||||
            if "1" in video_prompt_type:
 | 
			
		||||
                preloadURLs = get_model_recursive_prop(model_type,  "preload_URLs")
 | 
			
		||||
                if len(preloadURLs) > 0: 
 | 
			
		||||
                    return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))] , [1]
 | 
			
		||||
        return [], []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -21,11 +21,24 @@ class family_handler():
 | 
			
		||||
        extra_model_def["fps"] =fps
 | 
			
		||||
        extra_model_def["frames_minimum"] = 17
 | 
			
		||||
        extra_model_def["frames_steps"] = 20
 | 
			
		||||
        extra_model_def["latent_size"] = 4
 | 
			
		||||
        extra_model_def["sliding_window"] = True
 | 
			
		||||
        extra_model_def["skip_layer_guidance"] = True
 | 
			
		||||
        extra_model_def["tea_cache"] = True
 | 
			
		||||
        extra_model_def["guidance_max_phases"] = 1
 | 
			
		||||
 | 
			
		||||
        extra_model_def["model_modes"] = {
 | 
			
		||||
                    "choices": [
 | 
			
		||||
                        ("Synchronous", 0),
 | 
			
		||||
                        ("Asynchronous (better quality but around 50% extra steps added)", 5),
 | 
			
		||||
                        ],
 | 
			
		||||
                    "default": 0,
 | 
			
		||||
                    "label" : "Generation Type"
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        extra_model_def["image_prompt_types_allowed"] = "TSV"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        return extra_model_def 
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
@ -54,7 +67,11 @@ class family_handler():
 | 
			
		||||
    def query_family_infos():
 | 
			
		||||
        return {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_rgb_factors(base_model_type ):
 | 
			
		||||
        from shared.RGB_factors import get_rgb_factors
 | 
			
		||||
        latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type)
 | 
			
		||||
        return latent_rgb_factors, latent_rgb_factors_bias
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization):
 | 
			
		||||
@ -62,7 +79,7 @@ class family_handler():
 | 
			
		||||
        return family_handler.query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization)
 | 
			
		||||
    
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False):
 | 
			
		||||
    def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False, submodel_no_list = None):
 | 
			
		||||
        from .configs import WAN_CONFIGS
 | 
			
		||||
        from .wan_handler import family_handler
 | 
			
		||||
        cfg = WAN_CONFIGS['t2v-14B']
 | 
			
		||||
 | 
			
		||||
@ -1,479 +0,0 @@
 | 
			
		||||
import math
 | 
			
		||||
import os
 | 
			
		||||
from typing import List
 | 
			
		||||
from typing import Optional
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
from typing import Union
 | 
			
		||||
import logging
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from diffusers.image_processor import PipelineImageInput
 | 
			
		||||
from diffusers.utils.torch_utils import randn_tensor
 | 
			
		||||
from diffusers.video_processor import VideoProcessor
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
from .modules.model import WanModel
 | 
			
		||||
from .modules.t5 import T5EncoderModel
 | 
			
		||||
from .modules.vae import WanVAE
 | 
			
		||||
from wan.modules.posemb_layers import get_rotary_pos_embed
 | 
			
		||||
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
 | 
			
		||||
                               get_sampling_sigmas, retrieve_timesteps)
 | 
			
		||||
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 | 
			
		||||
 | 
			
		||||
class DTT2V:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        config,
 | 
			
		||||
        checkpoint_dir,
 | 
			
		||||
        rank=0,
 | 
			
		||||
        model_filename = None,
 | 
			
		||||
        text_encoder_filename = None,
 | 
			
		||||
        quantizeTransformer = False,
 | 
			
		||||
        dtype = torch.bfloat16,
 | 
			
		||||
    ):
 | 
			
		||||
        self.device = torch.device(f"cuda")
 | 
			
		||||
        self.config = config
 | 
			
		||||
        self.rank = rank
 | 
			
		||||
        self.dtype = dtype
 | 
			
		||||
        self.num_train_timesteps = config.num_train_timesteps
 | 
			
		||||
        self.param_dtype = config.param_dtype
 | 
			
		||||
 | 
			
		||||
        self.text_encoder = T5EncoderModel(
 | 
			
		||||
            text_len=config.text_len,
 | 
			
		||||
            dtype=config.t5_dtype,
 | 
			
		||||
            device=torch.device('cpu'),
 | 
			
		||||
            checkpoint_path=text_encoder_filename,
 | 
			
		||||
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
 | 
			
		||||
            shard_fn= None)
 | 
			
		||||
 | 
			
		||||
        self.vae_stride = config.vae_stride
 | 
			
		||||
        self.patch_size = config.patch_size 
 | 
			
		||||
 | 
			
		||||
        
 | 
			
		||||
        self.vae = WanVAE(
 | 
			
		||||
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
 | 
			
		||||
            device=self.device)
 | 
			
		||||
 | 
			
		||||
        logging.info(f"Creating WanModel from {model_filename}")
 | 
			
		||||
        from mmgp import offload
 | 
			
		||||
 | 
			
		||||
        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False, forcedConfigPath="config.json")
 | 
			
		||||
        # offload.load_model_data(self.model, "recam.ckpt")
 | 
			
		||||
        # self.model.cpu()
 | 
			
		||||
        # offload.save_model(self.model, "recam.safetensors")
 | 
			
		||||
        if self.dtype == torch.float16 and "fp16" not in model_filename:
 | 
			
		||||
            self.model.to(self.dtype) 
 | 
			
		||||
        # offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True)
 | 
			
		||||
        if self.dtype == torch.float16:
 | 
			
		||||
            self.vae.model.to(self.dtype)
 | 
			
		||||
        self.model.eval().requires_grad_(False)
 | 
			
		||||
 | 
			
		||||
        self.scheduler = FlowUniPCMultistepScheduler()
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def do_classifier_free_guidance(self) -> bool:
 | 
			
		||||
        return self._guidance_scale > 1
 | 
			
		||||
 | 
			
		||||
    def encode_image(
 | 
			
		||||
        self, image: PipelineImageInput, height: int, width: int, num_frames: int, tile_size = 0, causal_block_size = 0
 | 
			
		||||
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 | 
			
		||||
 | 
			
		||||
        # prefix_video
 | 
			
		||||
        prefix_video = np.array(image.resize((width, height))).transpose(2, 0, 1)
 | 
			
		||||
        prefix_video = torch.tensor(prefix_video).unsqueeze(1)  # .to(image_embeds.dtype).unsqueeze(1)
 | 
			
		||||
        if prefix_video.dtype == torch.uint8:
 | 
			
		||||
            prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
 | 
			
		||||
        prefix_video = prefix_video.to(self.device)
 | 
			
		||||
        prefix_video = [self.vae.encode(prefix_video.unsqueeze(0), tile_size = tile_size)[0]]  # [(c, f, h, w)]
 | 
			
		||||
        if prefix_video[0].shape[1] % causal_block_size != 0:
 | 
			
		||||
            truncate_len = prefix_video[0].shape[1] % causal_block_size
 | 
			
		||||
            print("the length of prefix video is truncated for the casual block size alignment.")
 | 
			
		||||
            prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
 | 
			
		||||
        predix_video_latent_length = prefix_video[0].shape[1]
 | 
			
		||||
        return prefix_video, predix_video_latent_length
 | 
			
		||||
 | 
			
		||||
    def prepare_latents(
 | 
			
		||||
        self,
 | 
			
		||||
        shape: Tuple[int],
 | 
			
		||||
        dtype: Optional[torch.dtype] = None,
 | 
			
		||||
        device: Optional[torch.device] = None,
 | 
			
		||||
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        return randn_tensor(shape, generator, device=device, dtype=dtype)
 | 
			
		||||
 | 
			
		||||
    def generate_timestep_matrix(
 | 
			
		||||
        self,
 | 
			
		||||
        num_frames,
 | 
			
		||||
        step_template,
 | 
			
		||||
        base_num_frames,
 | 
			
		||||
        ar_step=5,
 | 
			
		||||
        num_pre_ready=0,
 | 
			
		||||
        casual_block_size=1,
 | 
			
		||||
        shrink_interval_with_mask=False,
 | 
			
		||||
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
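        # Build the per-frame denoising schedule for asynchronous generation: each row of step_matrix holds the timestep
        # applied to every latent frame at one iteration, with later frames lagging earlier ones by ar_step iterations;
        # update_mask flags the frames that actually take a scheduler step, and valid_interval restricts each iteration
        # to a window of at most base_num_frames frames.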
 | 
			
		||||
        step_matrix, step_index = [], []
 | 
			
		||||
        update_mask, valid_interval = [], []
 | 
			
		||||
        num_iterations = len(step_template) + 1
 | 
			
		||||
        num_frames_block = num_frames // casual_block_size
 | 
			
		||||
        base_num_frames_block = base_num_frames // casual_block_size
 | 
			
		||||
        if base_num_frames_block < num_frames_block:
 | 
			
		||||
            infer_step_num = len(step_template)
 | 
			
		||||
            gen_block = base_num_frames_block
 | 
			
		||||
            min_ar_step = infer_step_num / gen_block
 | 
			
		||||
            assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
 | 
			
		||||
        # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block)
 | 
			
		||||
        step_template = torch.cat(
 | 
			
		||||
            [
 | 
			
		||||
                torch.tensor([999], dtype=torch.int64, device=step_template.device),
 | 
			
		||||
                step_template.long(),
 | 
			
		||||
                torch.tensor([0], dtype=torch.int64, device=step_template.device),
 | 
			
		||||
            ]
 | 
			
		||||
        )  # pad with 999 at the front and 0 at the back so row counters starting from 1 index valid timesteps
 | 
			
		||||
        pre_row = torch.zeros(num_frames_block, dtype=torch.long)
 | 
			
		||||
        if num_pre_ready > 0:
 | 
			
		||||
            pre_row[: num_pre_ready // casual_block_size] = num_iterations
 | 
			
		||||
 | 
			
		||||
        while not torch.all(pre_row >= (num_iterations - 1)):
 | 
			
		||||
            new_row = torch.zeros(num_frames_block, dtype=torch.long)
 | 
			
		||||
            for i in range(num_frames_block):
 | 
			
		||||
                if i == 0 or pre_row[i - 1] >= (
 | 
			
		||||
                    num_iterations - 1
 | 
			
		||||
                ):  # the first frame, or the previous frame is already fully denoised
 | 
			
		||||
                    new_row[i] = pre_row[i] + 1
 | 
			
		||||
                else:
 | 
			
		||||
                    new_row[i] = new_row[i - 1] - ar_step
 | 
			
		||||
            new_row = new_row.clamp(0, num_iterations)
 | 
			
		||||
 | 
			
		||||
            update_mask.append(
 | 
			
		||||
                (new_row != pre_row) & (new_row != num_iterations)
 | 
			
		||||
            )  # False: no need to update, True: need to update
 | 
			
		||||
            step_index.append(new_row)
 | 
			
		||||
            step_matrix.append(step_template[new_row])
 | 
			
		||||
            pre_row = new_row
 | 
			
		||||
 | 
			
		||||
        # for long videos we split generation into several sequences; base_num_frames is the model's maximum training length
 | 
			
		||||
        terminal_flag = base_num_frames_block
 | 
			
		||||
        if shrink_interval_with_mask:
 | 
			
		||||
            idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
 | 
			
		||||
            update_mask = update_mask[0]
 | 
			
		||||
            update_mask_idx = idx_sequence[update_mask]
 | 
			
		||||
            last_update_idx = update_mask_idx[-1].item()
 | 
			
		||||
            terminal_flag = last_update_idx + 1
 | 
			
		||||
        # for i in range(0, len(update_mask)):
 | 
			
		||||
        for curr_mask in update_mask:
 | 
			
		||||
            if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
 | 
			
		||||
                terminal_flag += 1
 | 
			
		||||
            valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))
 | 
			
		||||
 | 
			
		||||
        step_update_mask = torch.stack(update_mask, dim=0)
 | 
			
		||||
        step_index = torch.stack(step_index, dim=0)
 | 
			
		||||
        step_matrix = torch.stack(step_matrix, dim=0)
 | 
			
		||||
 | 
			
		||||
        if casual_block_size > 1:
 | 
			
		||||
            step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
 | 
			
		||||
            step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
 | 
			
		||||
            step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
 | 
			
		||||
            valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]
 | 
			
		||||
 | 
			
		||||
        return step_matrix, step_index, step_update_mask, valid_interval
 | 
			
		||||
 | 
			
		||||
    @torch.no_grad()
 | 
			
		||||
    def generate(
 | 
			
		||||
        self,
 | 
			
		||||
        prompt: Union[str, List[str]],
 | 
			
		||||
        negative_prompt: Union[str, List[str]] = "",
 | 
			
		||||
        image: PipelineImageInput = None,
 | 
			
		||||
        height: int = 480,
 | 
			
		||||
        width: int = 832,
 | 
			
		||||
        num_frames: int = 97,
 | 
			
		||||
        num_inference_steps: int = 50,
 | 
			
		||||
        shift: float = 1.0,
 | 
			
		||||
        guidance_scale: float = 5.0,
 | 
			
		||||
        seed: float = 0.0,
 | 
			
		||||
        overlap_history: int = 17,
 | 
			
		||||
        addnoise_condition: int = 0,
 | 
			
		||||
        base_num_frames: int = 97,
 | 
			
		||||
        ar_step: int = 5,
 | 
			
		||||
        causal_block_size: int = 1,
 | 
			
		||||
        causal_attention: bool = False,
 | 
			
		||||
        fps: int = 24,
 | 
			
		||||
        VAE_tile_size = 0,
 | 
			
		||||
        joint_pass = False,
 | 
			
		||||
        callback = None,
 | 
			
		||||
    ):
 | 
			
		||||
        generator = torch.Generator(device=self.device)
 | 
			
		||||
        generator.manual_seed(seed)
 | 
			
		||||
        # if base_num_frames > base_num_frames:
 | 
			
		||||
        #     causal_block_size = 0
 | 
			
		||||
        self._guidance_scale = guidance_scale
 | 
			
		||||
 | 
			
		||||
        i2v_extra_kwrags = {}
 | 
			
		||||
        prefix_video = None
 | 
			
		||||
        predix_video_latent_length = 0
 | 
			
		||||
        if image:
 | 
			
		||||
            frame_width, frame_height  = image.size
 | 
			
		||||
            scale = min(height / frame_height, width /  frame_width)
 | 
			
		||||
            height = (int(frame_height * scale) // 16) * 16
 | 
			
		||||
            width = (int(frame_width * scale) // 16) * 16
 | 
			
		||||
 | 
			
		||||
            prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames, tile_size=VAE_tile_size, causal_block_size=causal_block_size)
 | 
			
		||||
 | 
			
		||||
        latent_length = (num_frames - 1) // 4 + 1
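        # the VAE compresses time by a factor of 4 (plus the first frame) and space by a factor of 8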
 | 
			
		||||
        latent_height = height // 8
 | 
			
		||||
        latent_width = width // 8
 | 
			
		||||
 | 
			
		||||
        prompt_embeds = self.text_encoder([prompt], self.device)
 | 
			
		||||
        prompt_embeds  = [u.to(self.dtype).to(self.device) for u in prompt_embeds]
 | 
			
		||||
        if self.do_classifier_free_guidance:
 | 
			
		||||
            negative_prompt_embeds = self.text_encoder([negative_prompt], self.device)
 | 
			
		||||
            negative_prompt_embeds  = [u.to(self.dtype).to(self.device) for u in negative_prompt_embeds]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
 | 
			
		||||
        init_timesteps = self.scheduler.timesteps
 | 
			
		||||
        fps_embeds = [fps] * prompt_embeds[0].shape[0]
 | 
			
		||||
        fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
 | 
			
		||||
        transformer_dtype = self.dtype
 | 
			
		||||
        # with torch.cuda.amp.autocast(dtype=self.dtype), torch.no_grad():
 | 
			
		||||
        if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames:
 | 
			
		||||
            # short video generation
 | 
			
		||||
            latent_shape = [16, latent_length, latent_height, latent_width]
 | 
			
		||||
            latents = self.prepare_latents(
 | 
			
		||||
                latent_shape, dtype=torch.float32, device=self.device, generator=generator
 | 
			
		||||
            )
 | 
			
		||||
            latents = [latents]
 | 
			
		||||
            if prefix_video is not None:
 | 
			
		||||
                latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32)
 | 
			
		||||
            base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
 | 
			
		||||
            step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
 | 
			
		||||
                latent_length, init_timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size
 | 
			
		||||
            )
 | 
			
		||||
            sample_schedulers = []
 | 
			
		||||
            for _ in range(latent_length):
 | 
			
		||||
                sample_scheduler = FlowUniPCMultistepScheduler(
 | 
			
		||||
                    num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
 | 
			
		||||
                )
 | 
			
		||||
                sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
 | 
			
		||||
                sample_schedulers.append(sample_scheduler)
 | 
			
		||||
            sample_schedulers_counter = [0] * latent_length
 | 
			
		||||
 | 
			
		||||
            if callback != None:
 | 
			
		||||
                callback(-1, None, True)
 | 
			
		||||
 | 
			
		||||
            freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False) 
 | 
			
		||||
            for i, timestep_i in enumerate(tqdm(step_matrix)):
 | 
			
		||||
                update_mask_i = step_update_mask[i]
 | 
			
		||||
                valid_interval_i = valid_interval[i]
 | 
			
		||||
                valid_interval_start, valid_interval_end = valid_interval_i
 | 
			
		||||
                timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
 | 
			
		||||
                latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
 | 
			
		||||
                if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
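                    # lightly re-noise the already-clean prefix latents and give them a small fixed timestep
                    # (both controlled by addnoise_condition) so they act as conditioning rather than frames to denoise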
 | 
			
		||||
                    noise_factor = 0.001 * addnoise_condition
 | 
			
		||||
                    timestep_for_noised_condition = addnoise_condition
 | 
			
		||||
                    latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
 | 
			
		||||
                        latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor)
 | 
			
		||||
                        + torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length])
 | 
			
		||||
                        * noise_factor
 | 
			
		||||
                    )
 | 
			
		||||
                    timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
 | 
			
		||||
                kwrags = {
 | 
			
		||||
                    "x" : torch.stack([latent_model_input[0]]),
 | 
			
		||||
                    "t" : timestep,
 | 
			
		||||
                    "freqs" :freqs,
 | 
			
		||||
                    "fps" : fps_embeds,
 | 
			
		||||
                    # "causal_block_size" : causal_block_size,
 | 
			
		||||
                    "callback" : callback,
 | 
			
		||||
                    "pipeline" : self
 | 
			
		||||
                }
 | 
			
		||||
                kwrags.update(i2v_extra_kwrags)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
                if not self.do_classifier_free_guidance:
 | 
			
		||||
                    noise_pred = self.model(
 | 
			
		||||
                        context=prompt_embeds,
 | 
			
		||||
                        **kwrags,
 | 
			
		||||
                    )[0]
 | 
			
		||||
                    if self._interrupt:
 | 
			
		||||
                        return None                
 | 
			
		||||
                    noise_pred= noise_pred.to(torch.float32)                                          
 | 
			
		||||
                else:
 | 
			
		||||
                    if joint_pass:
 | 
			
		||||
                        noise_pred_cond, noise_pred_uncond = self.model(
 | 
			
		||||
                            context=prompt_embeds,
 | 
			
		||||
                            context2=negative_prompt_embeds,
 | 
			
		||||
                            **kwrags,
 | 
			
		||||
                        )
 | 
			
		||||
                        if self._interrupt:
 | 
			
		||||
                            return None
 | 
			
		||||
                    else:
 | 
			
		||||
                        noise_pred_cond = self.model(
 | 
			
		||||
                            context=prompt_embeds,
 | 
			
		||||
                            **kwrags,
 | 
			
		||||
                        )[0]
 | 
			
		||||
                        if self._interrupt:
 | 
			
		||||
                            return None                
 | 
			
		||||
                        noise_pred_uncond = self.model(
 | 
			
		||||
                            context=negative_prompt_embeds,
 | 
			
		||||
                            **kwrags,
 | 
			
		||||
                        )[0]
 | 
			
		||||
                        if self._interrupt:
 | 
			
		||||
                            return None
 | 
			
		||||
                    noise_pred_cond= noise_pred_cond.to(torch.float32)                                                                                 
 | 
			
		||||
                    noise_pred_uncond= noise_pred_uncond.to(torch.float32)                                                                                 
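                    # classifier-free guidance: push the unconditional prediction toward the conditional one by guidance_scale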
 | 
			
		||||
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
 | 
			
		||||
                    del noise_pred_cond, noise_pred_uncond
 | 
			
		||||
                for idx in range(valid_interval_start, valid_interval_end):
 | 
			
		||||
                    if update_mask_i[idx].item():
 | 
			
		||||
                        latents[0][:, idx] = sample_schedulers[idx].step(
 | 
			
		||||
                            noise_pred[:, idx - valid_interval_start],
 | 
			
		||||
                            timestep_i[idx],
 | 
			
		||||
                            latents[0][:, idx],
 | 
			
		||||
                            return_dict=False,
 | 
			
		||||
                            generator=generator,
 | 
			
		||||
                        )[0]
 | 
			
		||||
                        sample_schedulers_counter[idx] += 1
 | 
			
		||||
                if callback is not None:
 | 
			
		||||
                    callback(i, latents[0], False)         
 | 
			
		||||
 | 
			
		||||
            x0 = latents[0].unsqueeze(0)
 | 
			
		||||
            videos = self.vae.decode(x0, tile_size= VAE_tile_size)
 | 
			
		||||
            videos = (videos / 2 + 0.5).clamp(0, 1)
 | 
			
		||||
            videos = [video for video in videos]
 | 
			
		||||
            videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
 | 
			
		||||
            videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
 | 
			
		||||
            return videos
 | 
			
		||||
        else:
 | 
			
		||||
            # long video generation
 | 
			
		||||
            base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
 | 
			
		||||
            overlap_history_frames = (overlap_history - 1) // 4 + 1
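            # sliding-window schedule: the first window generates base_num_frames latent frames; each later window
            # re-encodes the last overlap_history pixel frames as its prefix and generates the remainder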
 | 
			
		||||
            n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1
 | 
			
		||||
            print(f"n_iter:{n_iter}")
 | 
			
		||||
            output_video = None
 | 
			
		||||
            for i in range(n_iter):
 | 
			
		||||
                if output_video is not None:  # i !=0
 | 
			
		||||
                    prefix_video = output_video[:, -overlap_history:].to(self.device)
 | 
			
		||||
                    prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]]  # [(c, f, h, w)]
 | 
			
		||||
                    if prefix_video[0].shape[1] % causal_block_size != 0:
 | 
			
		||||
                        truncate_len = prefix_video[0].shape[1] % causal_block_size
 | 
			
		||||
                        print("the length of prefix video is truncated for the casual block size alignment.")
 | 
			
		||||
                        prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
 | 
			
		||||
                    predix_video_latent_length = prefix_video[0].shape[1]
 | 
			
		||||
                    finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
 | 
			
		||||
                    left_frame_num = latent_length - finished_frame_num
 | 
			
		||||
                    base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
 | 
			
		||||
                else:  # i == 0
 | 
			
		||||
                    base_num_frames_iter = base_num_frames
 | 
			
		||||
                latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
 | 
			
		||||
                latents = self.prepare_latents(
 | 
			
		||||
                    latent_shape, dtype=torch.float32, device=self.device, generator=generator
 | 
			
		||||
                )
 | 
			
		||||
                latents = [latents]
 | 
			
		||||
                if prefix_video is not None:
 | 
			
		||||
                    latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32)
 | 
			
		||||
                step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
 | 
			
		||||
                    base_num_frames_iter,
 | 
			
		||||
                    init_timesteps,
 | 
			
		||||
                    base_num_frames_iter,
 | 
			
		||||
                    ar_step,
 | 
			
		||||
                    predix_video_latent_length,
 | 
			
		||||
                    causal_block_size,
 | 
			
		||||
                )
 | 
			
		||||
                sample_schedulers = []
 | 
			
		||||
                for _ in range(base_num_frames_iter):
 | 
			
		||||
                    sample_scheduler = FlowUniPCMultistepScheduler(
 | 
			
		||||
                        num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
 | 
			
		||||
                    )
 | 
			
		||||
                    sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
 | 
			
		||||
                    sample_schedulers.append(sample_scheduler)
 | 
			
		||||
                sample_schedulers_counter = [0] * base_num_frames_iter
 | 
			
		||||
                if callback != None:
 | 
			
		||||
                    callback(-1, None, True)
 | 
			
		||||
 | 
			
		||||
                freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False) 
 | 
			
		||||
                for i, timestep_i in enumerate(tqdm(step_matrix)):
 | 
			
		||||
                    update_mask_i = step_update_mask[i]
 | 
			
		||||
                    valid_interval_i = valid_interval[i]
 | 
			
		||||
                    valid_interval_start, valid_interval_end = valid_interval_i
 | 
			
		||||
                    timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
 | 
			
		||||
                    latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
 | 
			
		||||
                    if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
 | 
			
		||||
                        noise_factor = 0.001 * addnoise_condition
 | 
			
		||||
                        timestep_for_noised_condition = addnoise_condition
 | 
			
		||||
                        latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
 | 
			
		||||
                            latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
 | 
			
		||||
                            * (1.0 - noise_factor)
 | 
			
		||||
                            + torch.randn_like(
 | 
			
		||||
                                latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
 | 
			
		||||
                            )
 | 
			
		||||
                            * noise_factor
 | 
			
		||||
                        )
 | 
			
		||||
                        timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
 | 
			
		||||
                    kwrags = {
 | 
			
		||||
                        "x" : torch.stack([latent_model_input[0]]),
 | 
			
		||||
                        "t" : timestep,
 | 
			
		||||
                        "freqs" :freqs,
 | 
			
		||||
                        "fps" : fps_embeds,
 | 
			
		||||
                        "causal_block_size" : causal_block_size,
 | 
			
		||||
                        "causal_attention" : causal_attention,
 | 
			
		||||
                        "callback" : callback,
 | 
			
		||||
                        "pipeline" : self
 | 
			
		||||
                    }
 | 
			
		||||
                    kwrags.update(i2v_extra_kwrags)
 | 
			
		||||
                        
 | 
			
		||||
                    if not self.do_classifier_free_guidance:
 | 
			
		||||
                        noise_pred = self.model(
 | 
			
		||||
                            context=prompt_embeds,
 | 
			
		||||
                            **kwrags,
 | 
			
		||||
                        )[0]
 | 
			
		||||
                        if self._interrupt:
 | 
			
		||||
                            return None
 | 
			
		||||
                        noise_pred= noise_pred.to(torch.float32)                                                                  
 | 
			
		||||
                    else:
 | 
			
		||||
                        if joint_pass:
 | 
			
		||||
                            noise_pred_cond, noise_pred_uncond = self.model(
 | 
			
		||||
                                context=prompt_embeds,
 | 
			
		||||
                                context2=negative_prompt_embeds,
 | 
			
		||||
                                **kwrags,
 | 
			
		||||
                            )
 | 
			
		||||
                            if self._interrupt:
 | 
			
		||||
                                return None                
 | 
			
		||||
                        else:
 | 
			
		||||
                            noise_pred_cond = self.model(
 | 
			
		||||
                                context=prompt_embeds,
 | 
			
		||||
                                **kwrags,
 | 
			
		||||
                            )[0]
 | 
			
		||||
                            if self._interrupt:
 | 
			
		||||
                                return None                
 | 
			
		||||
                            noise_pred_uncond = self.model(
 | 
			
		||||
                                context=negative_prompt_embeds,
                                **kwrags,  # same shared arguments as the conditional pass above
 | 
			
		||||
                            )[0]
 | 
			
		||||
                            if self._interrupt:
 | 
			
		||||
                                return None
 | 
			
		||||
                        noise_pred_cond= noise_pred_cond.to(torch.float32)                                          
 | 
			
		||||
                        noise_pred_uncond= noise_pred_uncond.to(torch.float32)                                          
 | 
			
		||||
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
 | 
			
		||||
                        del noise_pred_cond, noise_pred_uncond
 | 
			
		||||
                    for idx in range(valid_interval_start, valid_interval_end):
 | 
			
		||||
                        if update_mask_i[idx].item():
 | 
			
		||||
                            latents[0][:, idx] = sample_schedulers[idx].step(
 | 
			
		||||
                                noise_pred[:, idx - valid_interval_start],
 | 
			
		||||
                                timestep_i[idx],
 | 
			
		||||
                                latents[0][:, idx],
 | 
			
		||||
                                return_dict=False,
 | 
			
		||||
                                generator=generator,
 | 
			
		||||
                            )[0]
 | 
			
		||||
                            sample_schedulers_counter[idx] += 1
 | 
			
		||||
                    if callback is not None:
 | 
			
		||||
                        callback(i, latents[0].squeeze(0), False)         
 | 
			
		||||
 | 
			
		||||
                x0 = latents[0].unsqueeze(0)
 | 
			
		||||
                videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]]
 | 
			
		||||
                if output_video is None:
 | 
			
		||||
                    output_video = videos[0].clamp(-1, 1).cpu()  # c, f, h, w
 | 
			
		||||
                else:
 | 
			
		||||
                    output_video = torch.cat(
 | 
			
		||||
                        [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1
 | 
			
		||||
                    )  # c, f, h, w
 | 
			
		||||
            return output_video
 | 
			
		||||
@ -12,29 +12,17 @@ from diffusers.models.modeling_utils import ModelMixin
 | 
			
		||||
import numpy as np
 | 
			
		||||
from typing import Union,Optional
 | 
			
		||||
from mmgp import offload
 | 
			
		||||
from mmgp.offload import get_cache, clear_caches
 | 
			
		||||
from shared.attention import pay_attention
 | 
			
		||||
from torch.backends.cuda import sdp_kernel
 | 
			
		||||
from ..multitalk.multitalk_utils import get_attn_map_with_target
 | 
			
		||||
from ..animate.motion_encoder import Generator
 | 
			
		||||
from ..animate.face_blocks import FaceAdapter, FaceEncoder 
 | 
			
		||||
from ..animate.model_animate import after_patch_embedding
 | 
			
		||||
 | 
			
		||||
__all__ = ['WanModel']
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_cache(cache_name):
 | 
			
		||||
    all_cache = offload.shared_state.get("_cache",  None)
 | 
			
		||||
    if all_cache is None:
 | 
			
		||||
        all_cache = {}
 | 
			
		||||
        offload.shared_state["_cache"]=  all_cache
 | 
			
		||||
    cache = offload.shared_state.get(cache_name, None)
 | 
			
		||||
    if cache is None:
 | 
			
		||||
        cache = {}
 | 
			
		||||
        offload.shared_state[cache_name] = cache
 | 
			
		||||
    return cache
 | 
			
		||||
 | 
			
		||||
def clear_caches():
 | 
			
		||||
    all_cache = offload.shared_state.get("_cache",  None)
 | 
			
		||||
    if all_cache is not None:
 | 
			
		||||
        all_cache.clear()
 | 
			
		||||
 | 
			
		||||
def sinusoidal_embedding_1d(dim, position):
 | 
			
		||||
    # preprocess
 | 
			
		||||
    assert dim % 2 == 0
 | 
			
		||||
@ -514,6 +502,7 @@ class WanAttentionBlock(nn.Module):
 | 
			
		||||
        multitalk_masks=None,
 | 
			
		||||
        ref_images_count=0,
 | 
			
		||||
        standin_phase=-1,
 | 
			
		||||
        motion_vec = None,
 | 
			
		||||
    ):
 | 
			
		||||
        r"""
 | 
			
		||||
        Args:
 | 
			
		||||
@ -579,19 +568,23 @@ class WanAttentionBlock(nn.Module):
 | 
			
		||||
            y = self.norm_x(x)
 | 
			
		||||
            y = y.to(attention_dtype)
 | 
			
		||||
            if ref_images_count == 0:
 | 
			
		||||
                x += self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map)
 | 
			
		||||
                ylist= [y]
 | 
			
		||||
                del y
 | 
			
		||||
                x += self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map)
 | 
			
		||||
            else:
 | 
			
		||||
                y_shape = y.shape
 | 
			
		||||
                y = y.reshape(y_shape[0], grid_sizes[0], -1)
 | 
			
		||||
                y = y[:, ref_images_count:]
 | 
			
		||||
                y = y.reshape(y_shape[0], -1, y_shape[-1])
 | 
			
		||||
                grid_sizes_alt = [grid_sizes[0]-ref_images_count, *grid_sizes[1:]]
 | 
			
		||||
                y = self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map)
 | 
			
		||||
                ylist= [y]
 | 
			
		||||
                y = None
 | 
			
		||||
                y = self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map)
 | 
			
		||||
                y = y.reshape(y_shape[0], grid_sizes[0]-ref_images_count, -1)
 | 
			
		||||
                x = x.reshape(y_shape[0], grid_sizes[0], -1)
 | 
			
		||||
                x[:, ref_images_count:] += y
 | 
			
		||||
                x = x.reshape(y_shape[0], -1, y_shape[-1])
 | 
			
		||||
            del y
 | 
			
		||||
                del y
 | 
			
		||||
 | 
			
		||||
        y = self.norm2(x)
 | 
			
		||||
 | 
			
		||||
@ -627,6 +620,10 @@ class WanAttentionBlock(nn.Module):
 | 
			
		||||
                        x.add_(hint)
 | 
			
		||||
                    else:
 | 
			
		||||
                        x.add_(hint, alpha= scale)
 | 
			
		||||
 | 
			
		||||
        if motion_vec is not None and self.block_no % 5 == 0:
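            # Animate path: every 5th block fuses the face motion embedding through the adapter grafted by adapt_animate_model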
 | 
			
		||||
            x += self.face_adapter_fuser_blocks(x.to(self.face_adapter_fuser_blocks.linear1_kv.weight.dtype), motion_vec, None, False)
 | 
			
		||||
 | 
			
		||||
        return x 
 | 
			
		||||
 | 
			
		||||
class AudioProjModel(ModelMixin, ConfigMixin):
 | 
			
		||||
@ -909,6 +906,7 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
                 norm_input_visual=True,
 | 
			
		||||
                 norm_output_audio=True,
 | 
			
		||||
                 standin= False,
 | 
			
		||||
                 motion_encoder_dim=0,
 | 
			
		||||
                 ):
 | 
			
		||||
 | 
			
		||||
        super().__init__()
 | 
			
		||||
@ -933,14 +931,15 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
        self.flag_causal_attention = False
 | 
			
		||||
        self.block_mask = None
 | 
			
		||||
        self.inject_sample_info = inject_sample_info
 | 
			
		||||
 | 
			
		||||
        self.motion_encoder_dim = motion_encoder_dim
 | 
			
		||||
        self.norm_output_audio = norm_output_audio
 | 
			
		||||
        self.audio_window = audio_window
 | 
			
		||||
        self.intermediate_dim = intermediate_dim
 | 
			
		||||
        self.vae_scale = vae_scale
 | 
			
		||||
 | 
			
		||||
        multitalk = multitalk_output_dim > 0
 | 
			
		||||
        self.multitalk = multitalk
 | 
			
		||||
        self.multitalk = multitalk 
 | 
			
		||||
        animate = motion_encoder_dim > 0
 | 
			
		||||
 | 
			
		||||
        # embeddings
 | 
			
		||||
        self.patch_embedding = nn.Conv3d(
 | 
			
		||||
@ -1038,6 +1037,25 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
                block.self_attn.k_loras = LoRALinearLayer(dim, dim, rank=128)
 | 
			
		||||
                block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128)
 | 
			
		||||
 | 
			
		||||
        if animate:
 | 
			
		||||
            self.pose_patch_embedding = nn.Conv3d(
 | 
			
		||||
                16, dim, kernel_size=patch_size, stride=patch_size
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
 | 
			
		||||
            self.face_adapter = FaceAdapter(
 | 
			
		||||
                heads_num=self.num_heads,
 | 
			
		||||
                hidden_dim=self.dim,
 | 
			
		||||
                num_adapter_layers=self.num_layers // 5,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            self.face_encoder = FaceEncoder(
 | 
			
		||||
                in_dim=motion_encoder_dim,
 | 
			
		||||
                hidden_dim=self.dim,
 | 
			
		||||
                num_heads=4,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32):
 | 
			
		||||
        layer_list = [self.head, self.head.head, self.patch_embedding]
 | 
			
		||||
        target_dype= dtype
 | 
			
		||||
@ -1202,7 +1220,8 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
        y=None,
 | 
			
		||||
        freqs = None,
 | 
			
		||||
        pipeline = None,
 | 
			
		||||
        current_step = 0,
 | 
			
		||||
        current_step_no = 0,
 | 
			
		||||
        real_step_no = 0,
 | 
			
		||||
        x_id= 0,
 | 
			
		||||
        max_steps = 0, 
 | 
			
		||||
        slg_layers=None,
 | 
			
		||||
@ -1219,6 +1238,9 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
        ref_images_count = 0,
 | 
			
		||||
        standin_freqs = None,
 | 
			
		||||
        standin_ref = None,
 | 
			
		||||
        pose_latents=None, 
 | 
			
		||||
        face_pixel_values=None,
 | 
			
		||||
 | 
			
		||||
    ):
 | 
			
		||||
        # patch_dtype =  self.patch_embedding.weight.dtype
 | 
			
		||||
        modulation_dtype = self.time_projection[1].weight.dtype
 | 
			
		||||
@ -1251,9 +1273,18 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
                    if bz > 1: y = y.expand(bz, -1, -1, -1, -1)
 | 
			
		||||
                    x = torch.cat([x, y], dim=1)
 | 
			
		||||
                # embeddings
 | 
			
		||||
                # x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
 | 
			
		||||
                x = self.patch_embedding(x).to(modulation_dtype)
 | 
			
		||||
                grid_sizes = x.shape[2:]
 | 
			
		||||
                x_list[i] = x
 | 
			
		||||
        y = None
 | 
			
		||||
        
 | 
			
		||||
        motion_vec_list = []
 | 
			
		||||
        for i, x in enumerate(x_list):
 | 
			
		||||
                # animate embeddings
 | 
			
		||||
                motion_vec = None
 | 
			
		||||
                if pose_latents is not None: 
 | 
			
		||||
                    x, motion_vec = after_patch_embedding(self, x, pose_latents, face_pixel_values)
 | 
			
		||||
                motion_vec_list.append(motion_vec)
 | 
			
		||||
                if chipmunk:
 | 
			
		||||
                    x = x.unsqueeze(-1)
 | 
			
		||||
                    x_og_shape = x.shape
 | 
			
		||||
@ -1261,7 +1292,7 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
                else:
 | 
			
		||||
                    x = x.flatten(2).transpose(1, 2)
 | 
			
		||||
                x_list[i] = x
 | 
			
		||||
        x, y = None, None
 | 
			
		||||
        x = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        block_mask = None
 | 
			
		||||
@ -1280,9 +1311,9 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
            del causal_mask
 | 
			
		||||
 | 
			
		||||
        offload.shared_state["embed_sizes"] = grid_sizes 
 | 
			
		||||
        offload.shared_state["step_no"] = current_step 
 | 
			
		||||
        offload.shared_state["step_no"] = real_step_no 
 | 
			
		||||
        offload.shared_state["max_steps"] = max_steps
 | 
			
		||||
        if current_step == 0 and x_id == 0: clear_caches()
 | 
			
		||||
        if current_step_no == 0 and x_id == 0: clear_caches()
 | 
			
		||||
        # arguments
 | 
			
		||||
 | 
			
		||||
        kwargs = dict(
 | 
			
		||||
@ -1306,7 +1337,7 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
        if standin_ref is not None:
 | 
			
		||||
            standin_cache_enabled = False
 | 
			
		||||
            kwargs["standin_phase"] = 2
 | 
			
		||||
            if (current_step == 0 or not standin_cache_enabled) and x_id == 0:
 | 
			
		||||
            if current_step_no == 0 or not standin_cache_enabled :
 | 
			
		||||
                standin_x = self.patch_embedding(standin_ref).to(modulation_dtype).flatten(2).transpose(1, 2)
 | 
			
		||||
                standin_e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, torch.zeros_like(t)).to(modulation_dtype) )
 | 
			
		||||
                standin_e0 = self.time_projection(standin_e).unflatten(1, (6, self.dim)).to(e.dtype)
 | 
			
		||||
@ -1371,7 +1402,7 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
        skips_steps_cache = self.cache
 | 
			
		||||
        if skips_steps_cache != None: 
 | 
			
		||||
            if skips_steps_cache.cache_type == "mag":
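                # magnitude-based cache: accumulate the expected magnitude ratio and the skipped-step count since the last
                # full pass; their drift from 1 is then used to decide whether the transformer pass for this step can be skipped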
 | 
			
		||||
                if current_step <= skips_steps_cache.start_step:
 | 
			
		||||
                if real_step_no <= skips_steps_cache.start_step:
 | 
			
		||||
                    should_calc = True
 | 
			
		||||
                elif skips_steps_cache.one_for_all and x_id != 0: # not joint pass, not main pas, one for all
 | 
			
		||||
                    assert len(x_list) == 1
 | 
			
		||||
@ -1380,7 +1411,7 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
                    x_should_calc = []
 | 
			
		||||
                    for i in range(1 if skips_steps_cache.one_for_all else len(x_list)):
 | 
			
		||||
                        cur_x_id = i if joint_pass else x_id  
 | 
			
		||||
                        cur_mag_ratio = skips_steps_cache.mag_ratios[current_step * 2 + cur_x_id] # conditional and unconditional in one list
 | 
			
		||||
                        cur_mag_ratio = skips_steps_cache.mag_ratios[real_step_no * 2 + cur_x_id] # conditional and unconditional in one list
 | 
			
		||||
                        skips_steps_cache.accumulated_ratio[cur_x_id] *= cur_mag_ratio # magnitude ratio between current step and the cached step
 | 
			
		||||
                        skips_steps_cache.accumulated_steps[cur_x_id] += 1 # skip steps plus 1
 | 
			
		||||
                        cur_skip_err = np.abs(1-skips_steps_cache.accumulated_ratio[cur_x_id]) # skip error of current steps
 | 
			
		||||
@ -1400,7 +1431,7 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
                if x_id != 0:
 | 
			
		||||
                    should_calc = skips_steps_cache.should_calc
 | 
			
		||||
                else:
 | 
			
		||||
                    if current_step <= skips_steps_cache.start_step or current_step == skips_steps_cache.num_steps-1:
 | 
			
		||||
                    if real_step_no <= skips_steps_cache.start_step or real_step_no == skips_steps_cache.num_steps-1:
 | 
			
		||||
                        should_calc = True
 | 
			
		||||
                        skips_steps_cache.accumulated_rel_l1_distance = 0
 | 
			
		||||
                    else:
 | 
			
		||||
@ -1453,7 +1484,7 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
                    return [None] * len(x_list)
 | 
			
		||||
 | 
			
		||||
                if standin_x is not None:
 | 
			
		||||
                    if not standin_cache_enabled and x_id ==0 : get_cache("standin").clear()
 | 
			
		||||
                    if not standin_cache_enabled: get_cache("standin").clear()
 | 
			
		||||
                    standin_x = block(standin_x, context = None, grid_sizes = None, e= standin_e0, freqs = standin_freqs, standin_phase = 1)
 | 
			
		||||
 | 
			
		||||
                if slg_layers is not None and block_idx in slg_layers:
 | 
			
		||||
@ -1461,9 +1492,9 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
                        continue
 | 
			
		||||
                    x_list[0] = block(x_list[0], context = context_list[0], audio_scale= audio_scale_list[0], e= e0, **kwargs)
 | 
			
		||||
                else:
 | 
			
		||||
                    for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc)):
 | 
			
		||||
                    for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list)):
 | 
			
		||||
                        if should_calc:
 | 
			
		||||
                            x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, **kwargs)
 | 
			
		||||
                            x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0,  motion_vec = motion_vec,**kwargs)
 | 
			
		||||
                            del x
 | 
			
		||||
                    context = hints = audio_embedding  = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -221,13 +221,16 @@ class SingleStreamAttention(nn.Module):
 | 
			
		||||
        self.add_q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
 | 
			
		||||
        self.add_k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor:
 | 
			
		||||
    def forward(self, xlist: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor:
 | 
			
		||||
        N_t, N_h, N_w = shape
 | 
			
		||||
 | 
			
		||||
        x = xlist[0]
 | 
			
		||||
        xlist.clear()
 | 
			
		||||
        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
 | 
			
		||||
        # get q for hidden_state
 | 
			
		||||
        B, N, C = x.shape
 | 
			
		||||
        q = self.q_linear(x)
 | 
			
		||||
        del x
 | 
			
		||||
        q_shape = (B, N, self.num_heads, self.head_dim)
 | 
			
		||||
        q = q.view(q_shape).permute((0, 2, 1, 3))
 | 
			
		||||
 | 
			
		||||
@ -247,9 +250,6 @@ class SingleStreamAttention(nn.Module):
 | 
			
		||||
        q = rearrange(q, "B H M K -> B M H K")
 | 
			
		||||
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
 | 
			
		||||
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
 | 
			
        attn_bias = None
        # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,)
        qkv_list = [q, encoder_k, encoder_v]
        q = encoder_k = encoder_v = None
        x = pay_attention(qkv_list)

@ -302,7 +302,7 @@ class SingleStreamMutiAttention(SingleStreamAttention):
        self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)

    def forward(self,
                x: torch.Tensor,
                xlist: torch.Tensor,
                encoder_hidden_states: torch.Tensor,
                shape=None,
                x_ref_attn_map=None,
@ -310,14 +310,17 @@ class SingleStreamMutiAttention(SingleStreamAttention):

        encoder_hidden_states = encoder_hidden_states.squeeze(0)
        if x_ref_attn_map == None:
            return super().forward(x, encoder_hidden_states, shape)
            return super().forward(xlist, encoder_hidden_states, shape)

        N_t, _, _ = shape
        x = xlist[0]
        xlist.clear()
        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)

        # get q for hidden_state
        B, N, C = x.shape
        q = self.q_linear(x)
        del x
        q_shape = (B, N, self.num_heads, self.head_dim)
        q = q.view(q_shape).permute((0, 2, 1, 3))

@ -339,7 +342,9 @@ class SingleStreamMutiAttention(SingleStreamAttention):
        normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N

        q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        q = self.rope_1d(q, normalized_pos)
        qlist = [q]
        del q
        q = self.rope_1d(qlist, normalized_pos, "q")
        q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        _, N_a, _ = encoder_hidden_states.shape
@ -347,7 +352,7 @@ class SingleStreamMutiAttention(SingleStreamAttention):
        encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
        encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
        encoder_k, encoder_v = encoder_kv.unbind(0)

        del encoder_kv
        if self.qk_norm:
            encoder_k = self.add_k_norm(encoder_k)

@ -356,13 +361,14 @@ class SingleStreamMutiAttention(SingleStreamAttention):
        per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2
        encoder_pos = torch.concat([per_frame]*N_t, dim=0)
        encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        encoder_k = self.rope_1d(encoder_k, encoder_pos)
        enclist = [encoder_k]
        del encoder_k
        encoder_k = self.rope_1d(enclist, encoder_pos, "encoder_k")
        encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        q = rearrange(q, "B H M K -> B M H K")
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
        # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=None, op=None,)
        qkv_list = [q, encoder_k, encoder_v]
        q = encoder_k = encoder_v = None
        x = pay_attention(qkv_list)

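# Illustrative sketch, not part of the diff: both attention paths above hand q/k/v to
# pay_attention inside a list and immediately drop their own references, so the only
# live references are inside the callee and each tensor can be freed as soon as the
# kernel has consumed it. A minimal stand-in with the same calling convention (the
# B M H K layout is assumed from the rearranges above; the real pay_attention backend
# is not shown in this hunk):
import torch
import torch.nn.functional as F

def pay_attention_sketch(qkv_list):
    q, k, v = qkv_list
    qkv_list.clear()                                  # release the caller's references
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # B M H K -> B H M K
    x = F.scaled_dot_product_attention(q, k, v)       # fused, memory-efficient kernel
    return x.transpose(1, 2)                          # back to B M H K
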
@ -59,7 +59,30 @@ def get_embedding(speech_array, wav2vec_feature_extractor, audio_encoder, sr=160

    audio_emb = audio_emb.cpu().detach()
    return audio_emb


def extract_audio_from_video(filename, sample_rate):
    raw_audio_path = filename.split('/')[-1].split('.')[0]+'.wav'
    ffmpeg_command = [
        "ffmpeg",
        "-y",
        "-i",
        str(filename),
        "-vn",
        "-acodec",
        "pcm_s16le",
        "-ar",
        "16000",
        "-ac",
        "2",
        str(raw_audio_path),
    ]
    subprocess.run(ffmpeg_command, check=True)
    human_speech_array, sr = librosa.load(raw_audio_path, sr=sample_rate)
    human_speech_array = loudness_norm(human_speech_array, sr)
    os.remove(raw_audio_path)

    return human_speech_array

def audio_prepare_single(audio_path, sample_rate=16000, duration = 0):
    ext = os.path.splitext(audio_path)[1].lower()
    if ext in ['.mp4', '.mov', '.avi', '.mkv']:
@ -191,18 +214,20 @@ def process_tts_multi(text, save_dir, voice1, voice2):
    return s1, s2, save_path_sum


def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames =  0, fps = 25, sr = 16000, padded_frames_for_embeddings = 0, min_audio_duration = 0):
def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames =  0, fps = 25, sr = 16000, padded_frames_for_embeddings = 0, min_audio_duration = 0, return_sum_only = False):
    wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/chinese-wav2vec2-base")
    # wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/wav2vec")
    pad = int(padded_frames_for_embeddings/ fps * sr)
    new_human_speech1, new_human_speech2, sum_human_speechs, duration_changed = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps, pad = pad, min_audio_duration = min_audio_duration )
    audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
    audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
    full_audio_embs = []
    if audio_guide1 != None: full_audio_embs.append(audio_embedding_1)
    # if audio_guide1 != None: full_audio_embs.append(audio_embedding_1)
    if audio_guide2 != None: full_audio_embs.append(audio_embedding_2)
    if audio_guide2 == None and not duration_changed: sum_human_speechs = None
    if return_sum_only:
        full_audio_embs = None
    else:
        audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
        audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
        full_audio_embs = []
        if audio_guide1 != None: full_audio_embs.append(audio_embedding_1)
        if audio_guide2 != None: full_audio_embs.append(audio_embedding_2)
        if audio_guide2 == None and not duration_changed: sum_human_speechs = None
    return full_audio_embs, sum_human_speechs

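# Usage sketch, not part of the diff: with the new return_sum_only flag the wav2vec
# embedding step is skipped and only the combined speech track is returned, e.g. when a
# caller just needs the soundtrack to mux back into a generated clip (file names below
# are placeholders):
_, sum_speech = get_full_audio_embeddings(
    audio_guide1="speaker1.wav", audio_guide2="speaker2.wav",
    combination_type="add", num_frames=81, fps=25, return_sum_only=True)
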
@ -16,7 +16,7 @@ import torchvision
import binascii
import os.path as osp
from skimage import color

from mmgp.offload import get_cache, clear_caches

VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
ASPECT_RATIO_627 = {
@ -73,42 +73,70 @@ def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):


# @torch.compile
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count, mode='mean', attn_bias=None):

    ref_k = ref_k.to(visual_q.dtype).to(visual_q.device)
def calculate_x_ref_attn_map_per_head(visual_q, ref_k, ref_target_masks, ref_images_count, attn_bias=None):
    dtype = visual_q.dtype
    ref_k = ref_k.to(dtype).to(visual_q.device)
    scale = 1.0 / visual_q.shape[-1] ** 0.5
    visual_q = visual_q * scale
    visual_q = visual_q.transpose(1, 2)
    ref_k = ref_k.transpose(1, 2)
    visual_q_shape = visual_q.shape
    visual_q = visual_q.view(-1, visual_q_shape[-1] )
    number_chunks = visual_q_shape[-2]*ref_k.shape[-2] /  53090100 * 2
    chunk_size =  int(visual_q_shape[-2] / number_chunks)
    chunks =torch.split(visual_q, chunk_size)
    maps_lists = [ [] for _ in ref_target_masks]
    for q_chunk  in chunks:
        attn = q_chunk @ ref_k.transpose(-2, -1)
        x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens
        del attn
        ref_target_masks = ref_target_masks.to(dtype)
        x_ref_attn_map_source = x_ref_attn_map_source.to(dtype)

        for class_idx, ref_target_mask in enumerate(ref_target_masks):
            ref_target_mask = ref_target_mask[None, None, None, ...]
            x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
            x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
            maps_lists[class_idx].append(x_ref_attnmap)

        del x_ref_attn_map_source

    x_ref_attn_maps = []
    for class_idx, maps_list in enumerate(maps_lists):
        attn_map_fuse = torch.concat(maps_list, dim= -1)
        attn_map_fuse = attn_map_fuse.view(1, visual_q_shape[1], -1).squeeze(1)
        x_ref_attn_maps.append( attn_map_fuse )

    return torch.concat(x_ref_attn_maps, dim=0)

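# Illustrative sketch, not part of the diff: the per-head variant above keeps peak memory
# bounded by splitting the query rows into chunks before the softmax, so only a
# chunk x ref_seqlens slice of the attention matrix ever exists at once. The idea in
# isolation (the 53090100-element budget is taken from the code above; shapes are
# simplified to a single head):
import torch

def chunked_ref_attn_map(q, k, mask, budget=53090100 // 2):
    # q: (S, C) already-scaled queries, k: (R, C) reference keys, mask: (R,) target mask
    rows_per_chunk = max(1, budget // k.shape[0])
    outs = []
    for q_chunk in torch.split(q, rows_per_chunk):
        probs = (q_chunk @ k.T).softmax(-1)               # (rows, R), freed each iteration
        outs.append((probs * mask).sum(-1) / mask.sum())  # attention mass on the target
    return torch.cat(outs)                                # (S,)
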
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count):
    dtype = visual_q.dtype
    ref_k = ref_k.to(dtype).to(visual_q.device)
    scale = 1.0 / visual_q.shape[-1] ** 0.5
    visual_q = visual_q * scale
    visual_q = visual_q.transpose(1, 2)
    ref_k = ref_k.transpose(1, 2)
    attn = visual_q @ ref_k.transpose(-2, -1)

    if attn_bias is not None: attn += attn_bias

    x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens

    del attn
    x_ref_attn_maps = []
    ref_target_masks = ref_target_masks.to(visual_q.dtype)
    x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype)
    ref_target_masks = ref_target_masks.to(dtype)
    x_ref_attn_map_source = x_ref_attn_map_source.to(dtype)

    for class_idx, ref_target_mask in enumerate(ref_target_masks):
        ref_target_mask = ref_target_mask[None, None, None, ...]
        x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
        x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
        x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1) # B, x_seqlens, H

        if mode == 'mean':
            x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens
        elif mode == 'max':
            x_ref_attnmap = x_ref_attnmap.max(-1) # B, x_seqlens

        x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens       (mean of heads)
        x_ref_attn_maps.append(x_ref_attnmap)

    del attn
    del x_ref_attn_map_source

    return torch.concat(x_ref_attn_maps, dim=0)

def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count = 0):
    """Args:
        query (torch.tensor): B M H K
@ -120,6 +148,11 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli
    N_t, N_h, N_w = shape

    x_seqlens = N_h * N_w
    if x_seqlens <= 1508:
        split_num = 10 # 540p
    else:
        split_num = 20 if x_seqlens <= 3600 else 40 # 720p / 1080p

    ref_k     = ref_k[:, :x_seqlens]
    if ref_images_count > 0 :
        visual_q_shape = visual_q.shape
@ -133,9 +166,14 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli

    split_chunk = heads // split_num

    for i in range(split_num):
        x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count)
        x_ref_attn_maps += x_ref_attn_maps_perhead
    if split_chunk == 1:
        for i in range(split_num):
            x_ref_attn_maps_perhead = calculate_x_ref_attn_map_per_head(visual_q[:, :, i:(i+1), :], ref_k[:, :, i:(i+1), :], ref_target_masks, ref_images_count)
            x_ref_attn_maps += x_ref_attn_maps_perhead
    else:
        for i in range(split_num):
            x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count)
            x_ref_attn_maps += x_ref_attn_maps_perhead

    x_ref_attn_maps /= split_num
    return x_ref_attn_maps

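# Illustrative note, not part of the diff: split_num is now picked from the visual
# sequence length (x_seqlens = N_h * N_w latent tokens) instead of being fixed, so higher
# resolutions split the heads into more groups and each call above works on a smaller
# slice. The selection rule in isolation, using the thresholds from the code above:
def pick_split_num(n_h: int, n_w: int) -> int:
    x_seqlens = n_h * n_w
    if x_seqlens <= 1508:
        return 10                              # 540p
    return 20 if x_seqlens <= 3600 else 40     # 720p / 1080p
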
@ -158,7 +196,6 @@ class RotaryPositionalEmbedding1D(nn.Module):
        self.base = 10000


    @lru_cache(maxsize=32)
    def precompute_freqs_cis_1d(self, pos_indices):

        freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
@ -167,7 +204,7 @@ class RotaryPositionalEmbedding1D(nn.Module):
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
        return freqs

    def forward(self, x, pos_indices):
    def forward(self, qlist, pos_indices, cache_entry = None):
        """1D RoPE.

        Args:
@ -176,16 +213,26 @@ class RotaryPositionalEmbedding1D(nn.Module):
        Returns:
            query with the same shape as input.
        """
        freqs_cis = self.precompute_freqs_cis_1d(pos_indices)

        x_ = x.float()

        freqs_cis = freqs_cis.float().to(x.device)
        cos, sin = freqs_cis.cos(), freqs_cis.sin()
        cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
        x_ = (x_ * cos) + (rotate_half(x_) * sin)

        return x_.type_as(x)
        xq= qlist[0]
        qlist.clear()
        cache = get_cache("multitalk_rope")
        freqs_cis= cache.get(cache_entry, None)
        if freqs_cis is None:
            freqs_cis = cache[cache_entry] = self.precompute_freqs_cis_1d(pos_indices)
        cos, sin = freqs_cis.cos().unsqueeze(0).unsqueeze(0), freqs_cis.sin().unsqueeze(0).unsqueeze(0)
        # cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
        # real * cos - imag * sin
        # imag * cos + real * sin
        xq_dtype = xq.dtype
        xq_out = xq.to(torch.float)
        xq = None
        xq_rot = rotate_half(xq_out)
        xq_out *= cos
        xq_rot *= sin
        xq_out += xq_rot
        del xq_rot
        xq_out = xq_out.to(xq_dtype)
        return xq_out

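# Illustrative sketch, not part of the diff: the rewritten forward caches freqs_cis under
# an explicit entry name (presumably because lru_cache keyed on a tensor argument rarely
# hits) and rotates in place to avoid a second full-size temporary. The rotation itself
# is the usual rotate-half form, out = x * cos(theta) + rotate_half(x) * sin(theta),
# i.e. real*cos - imag*sin and imag*cos + real*sin per (even, odd) pair. rotate_half
# below is a minimal re-implementation assumed for the interleaved "(n r), r=2" layout
# produced by precompute_freqs_cis_1d:
import torch

def rotate_half(x):
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_rope_1d(x, freqs):                      # x: (B, H, N, D), freqs: (N, D)
    cos, sin = freqs.cos()[None, None], freqs.sin()[None, None]
    return x * cos + rotate_half(x) * sin
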
@ -1,698 +0,0 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
from mmgp import offload
import torch
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm
from PIL import Image
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                               get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.modules.posemb_layers import get_rotary_pos_embed
from .utils.vace_preprocessor import VaceVideoProcessor


def optimized_scale(positive_flat, negative_flat):

    # Calculate dot production
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)

    # Squared norm of uncondition
    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8

    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    st_star = dot_product / squared_norm

    return st_star

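# Usage sketch, not part of the diff: optimized_scale computes the CFG-Zero* projection
# coefficient st* = <v_cond, v_uncond> / ||v_uncond||^2 per sample, i.e. how much of the
# conditional prediction already lies along the unconditional one; the (commented-out)
# sampling loop further below rescales noise_pred_uncond by it. For flattened predictions:
import torch
cond = torch.randn(2, 1024)             # (batch, C*T*H*W) flattened conditional prediction
uncond = torch.randn(2, 1024)           # flattened unconditional prediction
alpha = optimized_scale(cond, uncond)   # shape (2, 1)
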
class WanT2V:
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        config,
 | 
			
		||||
        checkpoint_dir,
 | 
			
		||||
        rank=0,
 | 
			
		||||
        model_filename = None,
 | 
			
		||||
        text_encoder_filename = None,
 | 
			
		||||
        quantizeTransformer = False,
 | 
			
		||||
        dtype = torch.bfloat16
 | 
			
		||||
    ):
 | 
			
		||||
        self.device = torch.device(f"cuda")
 | 
			
		||||
        self.config = config
 | 
			
		||||
        self.rank = rank
 | 
			
		||||
        self.dtype = dtype
 | 
			
		||||
        self.num_train_timesteps = config.num_train_timesteps
 | 
			
		||||
        self.param_dtype = config.param_dtype
 | 
			
		||||
 | 
			
		||||
        self.text_encoder = T5EncoderModel(
 | 
			
		||||
            text_len=config.text_len,
 | 
			
		||||
            dtype=config.t5_dtype,
 | 
			
		||||
            device=torch.device('cpu'),
 | 
			
		||||
            checkpoint_path=text_encoder_filename,
 | 
			
		||||
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
 | 
			
		||||
            shard_fn= None)
 | 
			
		||||
 | 
			
		||||
        self.vae_stride = config.vae_stride
 | 
			
		||||
        self.patch_size = config.patch_size 
 | 
			
		||||
 | 
			
		||||
        
 | 
			
		||||
        self.vae = WanVAE(
 | 
			
		||||
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
 | 
			
		||||
            device=self.device)
 | 
			
		||||
 | 
			
		||||
        logging.info(f"Creating WanModel from {model_filename}")
 | 
			
		||||
        from mmgp import offload
 | 
			
		||||
 | 
			
		||||
        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
 | 
			
		||||
        # offload.load_model_data(self.model, "recam.ckpt")
 | 
			
		||||
        # self.model.cpu()
 | 
			
		||||
        # offload.save_model(self.model, "recam.safetensors")
 | 
			
		||||
        if self.dtype == torch.float16 and not "fp16" in model_filename:
 | 
			
		||||
            self.model.to(self.dtype) 
 | 
			
		||||
        # offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True)
 | 
			
		||||
        if self.dtype == torch.float16:
 | 
			
		||||
            self.vae.model.to(self.dtype)
 | 
			
		||||
        self.model.eval().requires_grad_(False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        self.sample_neg_prompt = config.sample_neg_prompt
 | 
			
		||||
 | 
			
		||||
        if "Vace" in model_filename:
 | 
			
		||||
            self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
 | 
			
		||||
                                            min_area=480*832,
 | 
			
		||||
                                            max_area=480*832,
 | 
			
		||||
                                            min_fps=config.sample_fps,
 | 
			
		||||
                                            max_fps=config.sample_fps,
 | 
			
		||||
                                            zero_start=True,
 | 
			
		||||
                                            seq_len=32760,
 | 
			
		||||
                                            keep_last=True)
 | 
			
		||||
 | 
			
		||||
            self.adapt_vace_model()
 | 
			
		||||
 | 
			
		||||
        self.scheduler = FlowUniPCMultistepScheduler()
 | 
			
		||||
 | 
			
		||||
    def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
 | 
			
		||||
        if ref_images is None:
 | 
			
		||||
            ref_images = [None] * len(frames)
 | 
			
		||||
        else:
 | 
			
		||||
            assert len(frames) == len(ref_images)
 | 
			
		||||
 | 
			
		||||
        if masks is None:
 | 
			
		||||
            latents = self.vae.encode(frames, tile_size = tile_size)
 | 
			
		||||
        else:
 | 
			
		||||
            inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
 | 
			
		||||
            reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
 | 
			
		||||
            inactive = self.vae.encode(inactive, tile_size = tile_size)
 | 
			
		||||
            reactive = self.vae.encode(reactive, tile_size = tile_size)
 | 
			
		||||
            latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
 | 
			
		||||
 | 
			
		||||
        cat_latents = []
 | 
			
		||||
        for latent, refs in zip(latents, ref_images):
 | 
			
		||||
            if refs is not None:
 | 
			
		||||
                if masks is None:
 | 
			
		||||
                    ref_latent = self.vae.encode(refs, tile_size = tile_size)
 | 
			
		||||
                else:
 | 
			
		||||
                    ref_latent = self.vae.encode(refs, tile_size = tile_size)
 | 
			
		||||
                    ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
 | 
			
		||||
                assert all([x.shape[1] == 1 for x in ref_latent])
 | 
			
		||||
                latent = torch.cat([*ref_latent, latent], dim=1)
 | 
			
		||||
            cat_latents.append(latent)
 | 
			
		||||
        return cat_latents
 | 
			
		||||
 | 
			
		||||
    def vace_encode_masks(self, masks, ref_images=None):
 | 
			
		||||
        if ref_images is None:
 | 
			
		||||
            ref_images = [None] * len(masks)
 | 
			
		||||
        else:
 | 
			
		||||
            assert len(masks) == len(ref_images)
 | 
			
		||||
 | 
			
		||||
        result_masks = []
 | 
			
		||||
        for mask, refs in zip(masks, ref_images):
 | 
			
		||||
            c, depth, height, width = mask.shape
 | 
			
		||||
            new_depth = int((depth + 3) // self.vae_stride[0])
 | 
			
		||||
            height = 2 * (int(height) // (self.vae_stride[1] * 2))
 | 
			
		||||
            width = 2 * (int(width) // (self.vae_stride[2] * 2))
 | 
			
		||||
 | 
			
		||||
            # reshape
 | 
			
		||||
            mask = mask[0, :, :, :]
 | 
			
		||||
            mask = mask.view(
 | 
			
		||||
                depth, height, self.vae_stride[1], width, self.vae_stride[1]
 | 
			
		||||
            )  # depth, height, 8, width, 8
 | 
			
		||||
            mask = mask.permute(2, 4, 0, 1, 3)  # 8, 8, depth, height, width
 | 
			
		||||
            mask = mask.reshape(
 | 
			
		||||
                self.vae_stride[1] * self.vae_stride[2], depth, height, width
 | 
			
		||||
            )  # 8*8, depth, height, width
 | 
			
		||||
 | 
			
		||||
            # interpolation
 | 
			
		||||
            mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
 | 
			
		||||
 | 
			
		||||
            if refs is not None:
 | 
			
		||||
                length = len(refs)
 | 
			
		||||
                mask_pad = torch.zeros_like(mask[:, :length, :, :])
 | 
			
		||||
                mask = torch.cat((mask_pad, mask), dim=1)
 | 
			
		||||
            result_masks.append(mask)
 | 
			
		||||
        return result_masks
 | 
			
		||||
 | 
			
		||||
    def vace_latent(self, z, m):
 | 
			
		||||
        return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
 | 
			
		||||
 | 
			
		||||
    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size,  device, original_video = False, keep_frames= [], start_frame = 0, pre_src_video = None):
 | 
			
		||||
        image_sizes = []
 | 
			
		||||
        trim_video = len(keep_frames)
 | 
			
		||||
 | 
			
		||||
        for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
 | 
			
		||||
            prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
 | 
			
		||||
            num_frames = total_frames - prepend_count 
 | 
			
		||||
            if sub_src_mask is not None and sub_src_video is not None:
 | 
			
		||||
                src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
 | 
			
		||||
                # src_video is [-1, 1], 0 = inpainting area (in fact 127  in [0, 255])
 | 
			
		||||
                # src_mask is [-1, 1], 0 = preserve original video (in fact 127  in [0, 255]) and 1 = Inpainting (in fact 255  in [0, 255])
 | 
			
		||||
                src_video[i] = src_video[i].to(device)
 | 
			
		||||
                src_mask[i] = src_mask[i].to(device)
 | 
			
		||||
                if prepend_count > 0:
 | 
			
		||||
                    src_video[i] =  torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
 | 
			
		||||
                    src_mask[i] =  torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
 | 
			
		||||
                src_video_shape = src_video[i].shape
 | 
			
		||||
                if src_video_shape[1] != total_frames:
 | 
			
		||||
                    src_video[i] =  torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
 | 
			
		||||
                    src_mask[i] =  torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
 | 
			
		||||
                src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
 | 
			
		||||
                image_sizes.append(src_video[i].shape[2:])
 | 
			
		||||
            elif sub_src_video is None:
 | 
			
		||||
                if prepend_count > 0:
 | 
			
		||||
                    src_video[i] =  torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1)
 | 
			
		||||
                    src_mask[i] =  torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1)
 | 
			
		||||
                else:
 | 
			
		||||
                    src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
 | 
			
		||||
                    src_mask[i] = torch.ones_like(src_video[i], device=device)
 | 
			
		||||
                image_sizes.append(image_size)
 | 
			
		||||
            else:
 | 
			
		||||
                src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
 | 
			
		||||
                src_video[i] = src_video[i].to(device)
 | 
			
		||||
                src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
 | 
			
		||||
                if prepend_count > 0:
 | 
			
		||||
                    src_video[i] =  torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
 | 
			
		||||
                    src_mask[i] =  torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
 | 
			
		||||
                src_video_shape = src_video[i].shape
 | 
			
		||||
                if src_video_shape[1] != total_frames:
 | 
			
		||||
                    src_video[i] =  torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
 | 
			
		||||
                    src_mask[i] =  torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
 | 
			
		||||
                image_sizes.append(src_video[i].shape[2:])
 | 
			
		||||
            for k, keep in enumerate(keep_frames):
 | 
			
		||||
                if not keep:
 | 
			
		||||
                    src_video[i][:, k:k+1] = 0
 | 
			
		||||
                    src_mask[i][:, k:k+1] = 1
 | 
			
		||||
 | 
			
		||||
        for i, ref_images in enumerate(src_ref_images):
 | 
			
		||||
            if ref_images is not None:
 | 
			
		||||
                image_size = image_sizes[i]
 | 
			
		||||
                for j, ref_img in enumerate(ref_images):
 | 
			
		||||
                    if ref_img is not None:
 | 
			
		||||
                        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
 | 
			
		||||
                        if ref_img.shape[-2:] != image_size:
 | 
			
		||||
                            canvas_height, canvas_width = image_size
 | 
			
		||||
                            ref_height, ref_width = ref_img.shape[-2:]
 | 
			
		||||
                            white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
 | 
			
		||||
                            scale = min(canvas_height / ref_height, canvas_width / ref_width)
 | 
			
		||||
                            new_height = int(ref_height * scale)
 | 
			
		||||
                            new_width = int(ref_width * scale)
 | 
			
		||||
                            resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
 | 
			
		||||
                            top = (canvas_height - new_height) // 2
 | 
			
		||||
                            left = (canvas_width - new_width) // 2
 | 
			
		||||
                            white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
 | 
			
		||||
                            ref_img = white_canvas
 | 
			
		||||
                        src_ref_images[i][j] = ref_img.to(device)
 | 
			
		||||
        return src_video, src_mask, src_ref_images
 | 
			
		||||
 | 
			
		||||
    def decode_latent(self, zs, ref_images=None, tile_size= 0 ):
 | 
			
		||||
        if ref_images is None:
 | 
			
		||||
            ref_images = [None] * len(zs)
 | 
			
		||||
        else:
 | 
			
		||||
            assert len(zs) == len(ref_images)
 | 
			
		||||
 | 
			
		||||
        trimed_zs = []
 | 
			
		||||
        for z, refs in zip(zs, ref_images):
 | 
			
		||||
            if refs is not None:
 | 
			
		||||
                z = z[:, len(refs):, :, :]
 | 
			
		||||
            trimed_zs.append(z)
 | 
			
		||||
 | 
			
		||||
        return self.vae.decode(trimed_zs, tile_size= tile_size)
 | 
			
		||||
 | 
			
		||||
    def generate_timestep_matrix(
 | 
			
		||||
        self,
 | 
			
		||||
        num_frames,
 | 
			
		||||
        step_template,
 | 
			
		||||
        base_num_frames,
 | 
			
		||||
        ar_step=5,
 | 
			
		||||
        num_pre_ready=0,
 | 
			
		||||
        casual_block_size=1,
 | 
			
		||||
        shrink_interval_with_mask=False,
 | 
			
		||||
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
 | 
			
		||||
        step_matrix, step_index = [], []
 | 
			
		||||
        update_mask, valid_interval = [], []
 | 
			
		||||
        num_iterations = len(step_template) + 1
 | 
			
		||||
        num_frames_block = num_frames // casual_block_size
 | 
			
		||||
        base_num_frames_block = base_num_frames // casual_block_size
 | 
			
		||||
        if base_num_frames_block < num_frames_block:
 | 
			
		||||
            infer_step_num = len(step_template)
 | 
			
		||||
            gen_block = base_num_frames_block
 | 
			
		||||
            min_ar_step = infer_step_num / gen_block
 | 
			
		||||
            assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
 | 
			
		||||
        # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block)
 | 
			
		||||
        step_template = torch.cat(
 | 
			
		||||
            [
 | 
			
		||||
                torch.tensor([999], dtype=torch.int64, device=step_template.device),
 | 
			
		||||
                step_template.long(),
 | 
			
		||||
                torch.tensor([0], dtype=torch.int64, device=step_template.device),
 | 
			
		||||
            ]
 | 
			
		||||
        )  # to handle the counter in row works starting from 1
 | 
			
		||||
        pre_row = torch.zeros(num_frames_block, dtype=torch.long)
 | 
			
		||||
        if num_pre_ready > 0:
 | 
			
		||||
            pre_row[: num_pre_ready // casual_block_size] = num_iterations
 | 
			
		||||
 | 
			
		||||
        while torch.all(pre_row >= (num_iterations - 1)) == False:
 | 
			
		||||
            new_row = torch.zeros(num_frames_block, dtype=torch.long)
 | 
			
		||||
            for i in range(num_frames_block):
 | 
			
		||||
                if i == 0 or pre_row[i - 1] >= (
 | 
			
		||||
                    num_iterations - 1
 | 
			
		||||
                ):  # the first frame or the last frame is completely denoised
 | 
			
		||||
                    new_row[i] = pre_row[i] + 1
 | 
			
		||||
                else:
 | 
			
		||||
                    new_row[i] = new_row[i - 1] - ar_step
 | 
			
		||||
            new_row = new_row.clamp(0, num_iterations)
 | 
			
		||||
 | 
			
		||||
            update_mask.append(
 | 
			
		||||
                (new_row != pre_row) & (new_row != num_iterations)
 | 
			
		||||
            )  # False: no need to update, True: need to update
 | 
			
		||||
            step_index.append(new_row)
 | 
			
		||||
            step_matrix.append(step_template[new_row])
 | 
			
		||||
            pre_row = new_row
 | 
			
		||||
 | 
			
		||||
        # for long video we split into several sequences, base_num_frames is set to the model max length (for training)
 | 
			
		||||
        terminal_flag = base_num_frames_block
 | 
			
		||||
        if shrink_interval_with_mask:
 | 
			
		||||
            idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
 | 
			
		||||
            update_mask = update_mask[0]
 | 
			
		||||
            update_mask_idx = idx_sequence[update_mask]
 | 
			
		||||
            last_update_idx = update_mask_idx[-1].item()
 | 
			
		||||
            terminal_flag = last_update_idx + 1
 | 
			
		||||
        # for i in range(0, len(update_mask)):
 | 
			
		||||
        for curr_mask in update_mask:
 | 
			
		||||
            if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
 | 
			
		||||
                terminal_flag += 1
 | 
			
		||||
            valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))
 | 
			
		||||
 | 
			
		||||
        step_update_mask = torch.stack(update_mask, dim=0)
 | 
			
		||||
        step_index = torch.stack(step_index, dim=0)
 | 
			
		||||
        step_matrix = torch.stack(step_matrix, dim=0)
 | 
			
		||||
 | 
			
		||||
        if casual_block_size > 1:
 | 
			
		||||
            step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
 | 
			
		||||
            step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
 | 
			
		||||
            step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
 | 
			
		||||
            valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]
 | 
			
		||||
 | 
			
		||||
        return step_matrix, step_index, step_update_mask, valid_interval
 | 
			
		||||
    
 | 
			
		||||
    def generate(self,
 | 
			
		||||
                input_prompt,
 | 
			
		||||
                input_frames= None,
 | 
			
		||||
                input_masks = None,
 | 
			
		||||
                input_ref_images = None,      
 | 
			
		||||
                source_video=None,
 | 
			
		||||
                target_camera=None,                  
 | 
			
		||||
                context_scale=1.0,
 | 
			
		||||
                size=(1280, 720),
 | 
			
		||||
                frame_num=81,
 | 
			
		||||
                shift=5.0,
 | 
			
		||||
                sample_solver='unipc',
 | 
			
		||||
                sampling_steps=50,
 | 
			
		||||
                guide_scale=5.0,
 | 
			
		||||
                n_prompt="",
 | 
			
		||||
                seed=-1,
 | 
			
		||||
                offload_model=True,
 | 
			
		||||
                callback = None,
 | 
			
		||||
                enable_RIFLEx = None,
 | 
			
		||||
                VAE_tile_size = 0,
 | 
			
		||||
                joint_pass = False,
 | 
			
		||||
                slg_layers = None,
 | 
			
		||||
                slg_start = 0.0,
 | 
			
		||||
                slg_end = 1.0,
 | 
			
		||||
                cfg_star_switch = True,
 | 
			
		||||
                cfg_zero_step = 5,
 | 
			
		||||
                 ):
 | 
			
		||||
        r"""
 | 
			
		||||
        Generates video frames from text prompt using diffusion process.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            input_prompt (`str`):
 | 
			
		||||
                Text prompt for content generation
 | 
			
		||||
            size (tupele[`int`], *optional*, defaults to (1280,720)):
 | 
			
		||||
                Controls video resolution, (width,height).
 | 
			
		||||
            frame_num (`int`, *optional*, defaults to 81):
 | 
			
		||||
                How many frames to sample from a video. The number should be 4n+1
 | 
			
		||||
            shift (`float`, *optional*, defaults to 5.0):
 | 
			
		||||
                Noise schedule shift parameter. Affects temporal dynamics
 | 
			
		||||
            sample_solver (`str`, *optional*, defaults to 'unipc'):
 | 
			
		||||
                Solver used to sample the video.
 | 
			
		||||
            sampling_steps (`int`, *optional*, defaults to 40):
 | 
			
		||||
                Number of diffusion sampling steps. Higher values improve quality but slow generation
 | 
			
		||||
            guide_scale (`float`, *optional*, defaults 5.0):
 | 
			
		||||
                Classifier-free guidance scale. Controls prompt adherence vs. creativity
 | 
			
		||||
            n_prompt (`str`, *optional*, defaults to ""):
 | 
			
		||||
                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
 | 
			
		||||
            seed (`int`, *optional*, defaults to -1):
 | 
			
		||||
                Random seed for noise generation. If -1, use random seed.
 | 
			
		||||
            offload_model (`bool`, *optional*, defaults to True):
 | 
			
		||||
                If True, offloads models to CPU during generation to save VRAM
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            torch.Tensor:
 | 
			
		||||
                Generated video frames tensor. Dimensions: (C, N H, W) where:
 | 
			
		||||
                - C: Color channels (3 for RGB)
 | 
			
		||||
                - N: Number of frames (81)
 | 
			
		||||
                - H: Frame height (from size)
 | 
			
		||||
                - W: Frame width from size)
 | 
			
		||||
        """
 | 
			
		||||
        # preprocess
 | 
			
		||||
 | 
			
		||||
        if n_prompt == "":
 | 
			
		||||
            n_prompt = self.sample_neg_prompt
 | 
			
		||||
        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
 | 
			
		||||
        seed_g = torch.Generator(device=self.device)
 | 
			
		||||
        seed_g.manual_seed(seed)
 | 
			
		||||
 | 
			
		||||
        frame_num = max(17, frame_num) # must match causal_block_size for value of 5
 | 
			
		||||
        frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 )
 | 
			
		||||
        num_frames = frame_num
 | 
			
		||||
        addnoise_condition = 20
 | 
			
		||||
        causal_attention = True
 | 
			
		||||
        fps = 16
 | 
			
		||||
        ar_step = 5
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        context = self.text_encoder([input_prompt], self.device)
 | 
			
		||||
        context_null = self.text_encoder([n_prompt], self.device)
 | 
			
		||||
        if target_camera != None:
 | 
			
		||||
            size = (source_video.shape[2], source_video.shape[1])
 | 
			
		||||
            source_video = source_video.to(dtype=self.dtype , device=self.device)
 | 
			
		||||
            source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)            
 | 
			
		||||
            source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device)
 | 
			
		||||
            del source_video
 | 
			
		||||
            # Process target camera (recammaster)
 | 
			
		||||
            from wan.utils.cammmaster_tools import get_camera_embedding
 | 
			
		||||
            cam_emb = get_camera_embedding(target_camera)       
 | 
			
		||||
            cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
 | 
			
		||||
 | 
			
		||||
        if input_frames != None:
 | 
			
		||||
            # vace context encode
 | 
			
		||||
            input_frames = [u.to(self.device) for u in input_frames]
 | 
			
		||||
            input_ref_images = [ None if u == None else [v.to(self.device) for v in u]  for u in input_ref_images]
 | 
			
		||||
            input_masks = [u.to(self.device) for u in input_masks]
 | 
			
		||||
 | 
			
		||||
            z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size)
 | 
			
		||||
            m0 = self.vace_encode_masks(input_masks, input_ref_images)
 | 
			
		||||
            z = self.vace_latent(z0, m0)
 | 
			
		||||
 | 
			
		||||
            target_shape = list(z0[0].shape)
 | 
			
		||||
            target_shape[0] = int(target_shape[0] / 2)
 | 
			
		||||
        else:
 | 
			
		||||
            F = frame_num
 | 
			
		||||
            target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
 | 
			
		||||
                            size[1] // self.vae_stride[1],
 | 
			
		||||
                            size[0] // self.vae_stride[2])
 | 
			
		||||
 | 
			
		||||
        seq_len = math.ceil((target_shape[2] * target_shape[3]) /
 | 
			
		||||
                            (self.patch_size[1] * self.patch_size[2]) *
 | 
			
		||||
                            target_shape[1]) 
 | 
			
		||||
 | 
			
		||||
        context  = [u.to(self.dtype) for u in context]
 | 
			
		||||
        context_null  = [u.to(self.dtype) for u in context_null]
 | 
			
		||||
 | 
			
		||||
        noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ]
 | 
			
		||||
 | 
			
		||||
        # evaluation mode
 | 
			
		||||
 | 
			
		||||
        # if sample_solver == 'unipc':
 | 
			
		||||
        #     sample_scheduler = FlowUniPCMultistepScheduler(
 | 
			
		||||
        #         num_train_timesteps=self.num_train_timesteps,
 | 
			
		||||
        #         shift=1,
 | 
			
		||||
        #         use_dynamic_shifting=False)
 | 
			
		||||
        #     sample_scheduler.set_timesteps(
 | 
			
		||||
        #         sampling_steps, device=self.device, shift=shift)
 | 
			
		||||
        #     timesteps = sample_scheduler.timesteps
 | 
			
		||||
        # elif sample_solver == 'dpm++':
 | 
			
		||||
        #     sample_scheduler = FlowDPMSolverMultistepScheduler(
 | 
			
		||||
        #         num_train_timesteps=self.num_train_timesteps,
 | 
			
		||||
        #         shift=1,
 | 
			
		||||
        #         use_dynamic_shifting=False)
 | 
			
		||||
        #     sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
 | 
			
		||||
        #     timesteps, _ = retrieve_timesteps(
 | 
			
		||||
        #         sample_scheduler,
 | 
			
		||||
        #         device=self.device,
 | 
			
		||||
        #         sigmas=sampling_sigmas)
 | 
			
		||||
        # else:
 | 
			
		||||
        #     raise NotImplementedError("Unsupported solver.")
 | 
			
		||||
 | 
			
		||||
        # sample videos
 | 
			
		||||
        latents = noise
 | 
			
		||||
        del noise
 | 
			
		||||
        batch_size =len(latents)
 | 
			
		||||
        if target_camera != None:
 | 
			
		||||
            shape = list(latents[0].shape[1:])
 | 
			
		||||
            shape[0] *= 2
 | 
			
		||||
            freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False) 
 | 
			
		||||
        else:
 | 
			
		||||
            freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx) 
 | 
			
		||||
        # arg_c = {'context': context, 'freqs': freqs, 'pipeline': self, 'callback': callback}
 | 
			
		||||
        # arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
 | 
			
		||||
        # arg_both = {'context': context, 'context2': context_null,  'freqs': freqs, 'pipeline': self, 'callback': callback}
 | 
			
		||||
 | 
			
		||||
        i2v_extra_kwrags = {}
 | 
			
		||||
 | 
			
		||||
        if target_camera != None:
 | 
			
		||||
            recam_dict = {'cam_emb': cam_emb}
 | 
			
		||||
            i2v_extra_kwrags.update(recam_dict)
 | 
			
		||||
 | 
			
		||||
        if input_frames != None:
 | 
			
		||||
            vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale}
 | 
			
		||||
            i2v_extra_kwrags.update(vace_dict)
 | 
			
		||||
 | 
			
		||||
        
 | 
			
		||||
        latent_length = (num_frames - 1) // 4 + 1
 | 
			
		||||
        latent_height = height // 8
 | 
			
		||||
        latent_width = width // 8
 | 
			
		||||
        if ar_step == 0: 
 | 
			
		||||
            causal_block_size = 1
 | 
			
		||||
        fps_embeds = [fps] #* prompt_embeds[0].shape[0]
 | 
			
		||||
        fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
 | 
			
		||||
 | 
			
		||||
        self.scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift)
 | 
			
		||||
        init_timesteps = self.scheduler.timesteps
 | 
			
		||||
        base_num_frames_iter = latent_length
 | 
			
		||||
        latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
 | 
			
		||||
 | 
			
		||||
        prefix_video = None
 | 
			
		||||
        predix_video_latent_length = 0
 | 
			
		||||
 | 
			
		||||
        if prefix_video is not None:
 | 
			
		||||
            latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32)
 | 
			
		||||
        step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
 | 
			
		||||
            base_num_frames_iter,
 | 
			
		||||
            init_timesteps,
 | 
			
		||||
            base_num_frames_iter,
 | 
			
		||||
            ar_step,
 | 
			
		||||
            predix_video_latent_length,
 | 
			
		||||
            causal_block_size,
 | 
			
		||||
        )
 | 
			
		||||
        sample_schedulers = []
 | 
			
		||||
        for _ in range(base_num_frames_iter):
 | 
			
		||||
            sample_scheduler = FlowUniPCMultistepScheduler(
 | 
			
		||||
                num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
 | 
			
		||||
            )
 | 
			
		||||
            sample_scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift)
 | 
			
		||||
            sample_schedulers.append(sample_scheduler)
 | 
			
		||||
        sample_schedulers_counter = [0] * base_num_frames_iter
 | 
			
		||||
 | 
			
		||||
        updated_num_steps=  len(step_matrix)
 | 
			
		||||
 | 
			
		||||
        if callback != None:
 | 
			
		||||
            callback(-1, None, True, override_num_inference_steps = updated_num_steps)
 | 
			
		||||
        if self.model.enable_teacache:
 | 
			
		||||
            self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
 | 
			
		||||
        # if callback != None:
 | 
			
		||||
        #     callback(-1, None, True)
 | 
			
		||||
 | 
			
		||||
        for i, timestep_i in enumerate(tqdm(step_matrix)):
 | 
			
		||||
            update_mask_i = step_update_mask[i]
 | 
			
		||||
            valid_interval_i = valid_interval[i]
 | 
			
		||||
            valid_interval_start, valid_interval_end = valid_interval_i
 | 
			
		||||
            timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
 | 
			
		||||
            latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
 | 
			
		||||
            if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
 | 
			
		||||
                noise_factor = 0.001 * addnoise_condition
 | 
			
		||||
                timestep_for_noised_condition = addnoise_condition
 | 
			
		||||
                latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
 | 
			
		||||
                    latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
 | 
			
		||||
                    * (1.0 - noise_factor)
 | 
			
		||||
                    + torch.randn_like(
 | 
			
		||||
                        latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
 | 
			
		||||
                    )
 | 
			
		||||
                    * noise_factor
 | 
			
		||||
                )
 | 
			
		||||
                timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
 | 
			
		||||
            kwrags = {
 | 
			
		||||
                "x" : torch.stack([latent_model_input[0]]),
 | 
			
		||||
                "t" : timestep,
 | 
			
		||||
                "freqs" :freqs,
 | 
			
		||||
                "fps" : fps_embeds,
 | 
			
		||||
                "causal_block_size" : causal_block_size,
 | 
			
		||||
                "causal_attention" : causal_attention,
 | 
			
		||||
                "callback" : callback,
 | 
			
		||||
                "pipeline" : self,
 | 
			
		||||
                "current_step" : i,                 
 | 
			
		||||
            }   
 | 
			
		||||
            kwrags.update(i2v_extra_kwrags)
 | 
			
		||||
                
 | 
			
		||||
            if not self.do_classifier_free_guidance:
 | 
			
		||||
                noise_pred = self.model(
 | 
			
		||||
                    context=context,
 | 
			
		||||
                    **kwrags,
 | 
			
		||||
                )[0]
 | 
			
		||||
                if self._interrupt:
 | 
			
		||||
                    return None
 | 
			
		||||
                noise_pred= noise_pred.to(torch.float32)                                                                  
 | 
			
		||||
            else:
 | 
			
		||||
                if joint_pass:
 | 
			
		||||
                    noise_pred_cond, noise_pred_uncond = self.model(
 | 
			
		||||
                        context=context,
 | 
			
		||||
                        context2=context_null,
 | 
			
		||||
                        **kwrags,
 | 
			
		||||
                    )
 | 
			
		||||
                    if self._interrupt:
 | 
			
		||||
                        return None                
 | 
			
		||||
                else:
 | 
			
		||||
                    noise_pred_cond = self.model(
 | 
			
		||||
                        context=context,
 | 
			
		||||
                        **kwrags,
 | 
			
		||||
                    )[0]
 | 
			
		||||
                    if self._interrupt:
 | 
			
		||||
                        return None                
 | 
			
		||||
                    noise_pred_uncond = self.model(
 | 
			
		||||
                        context=context_null,
 | 
			
		||||
                    )[0]
 | 
			
		||||
                    if self._interrupt:
 | 
			
		||||
                        return None
 | 
			
		||||
                noise_pred_cond= noise_pred_cond.to(torch.float32)                                          
 | 
			
		||||
                noise_pred_uncond= noise_pred_uncond.to(torch.float32)                                          
 | 
			
		||||
                noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
 | 
			
		||||
                del noise_pred_cond, noise_pred_uncond
 | 
			
		||||
            for idx in range(valid_interval_start, valid_interval_end):
 | 
			
		||||
                if update_mask_i[idx].item():
 | 
			
		||||
                    latents[0][:, idx] = sample_schedulers[idx].step(
 | 
			
		||||
                        noise_pred[:, idx - valid_interval_start],
 | 
			
		||||
                        timestep_i[idx],
 | 
			
		||||
                        latents[0][:, idx],
 | 
			
		||||
                        return_dict=False,
 | 
			
		||||
                        generator=seed_g,
 | 
			
		||||
                    )[0]
 | 
			
		||||
                    sample_schedulers_counter[idx] += 1
 | 
			
		||||
            if callback is not None:
 | 
			
		||||
                callback(i, latents[0].squeeze(0), False)         
 | 
			
		||||
 | 
			
		||||
        # for i, t in enumerate(tqdm(timesteps)):
 | 
			
		||||
        #     if target_camera != None:
 | 
			
		||||
        #         latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
 | 
			
		||||
        #     else:
 | 
			
		||||
        #         latent_model_input = latents
 | 
			
		||||
        #     slg_layers_local = None
 | 
			
		||||
        #     if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
 | 
			
		||||
        #         slg_layers_local = slg_layers
 | 
			
		||||
        #     timestep = [t]
 | 
			
		||||
        #     offload.set_step_no_for_lora(self.model, i)
 | 
			
		||||
        #     timestep = torch.stack(timestep)
 | 
			
		||||
 | 
			
		||||
        #     if joint_pass:
 | 
			
		||||
        #         noise_pred_cond, noise_pred_uncond = self.model(
 | 
			
		||||
        #             latent_model_input, t=timestep,  current_step=i, slg_layers=slg_layers_local, **arg_both)
 | 
			
		||||
        #         if self._interrupt:
 | 
			
		||||
        #             return None
 | 
			
		||||
        #     else:
 | 
			
		||||
        #         noise_pred_cond = self.model(
 | 
			
		||||
        #             latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0]
 | 
			
		||||
        #         if self._interrupt:
 | 
			
		||||
        #             return None               
 | 
			
		||||
        #         noise_pred_uncond = self.model(
 | 
			
		||||
        #             latent_model_input, t=timestep,current_step=i, is_uncond = True, slg_layers=slg_layers_local, **arg_null)[0]
 | 
			
		||||
        #         if self._interrupt:
 | 
			
		||||
        #             return None
 | 
			
		||||
 | 
			
		||||
        #     # del latent_model_input
 | 
			
		||||
 | 
			
		||||
        #     # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
 | 
			
		||||
        #     noise_pred_text = noise_pred_cond
 | 
			
		||||
        #     if cfg_star_switch:
 | 
			
		||||
        #         positive_flat = noise_pred_text.view(batch_size, -1)  
 | 
			
		||||
        #         negative_flat = noise_pred_uncond.view(batch_size, -1)  
 | 
			
		||||
 | 
			
		||||
        #         alpha = optimized_scale(positive_flat,negative_flat)
 | 
			
		||||
        #         alpha = alpha.view(batch_size, 1, 1, 1)
 | 
			
		||||
 | 
			
		||||
        #         if (i <= cfg_zero_step):
 | 
			
		||||
        #             noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred...
 | 
			
		||||
        #         else:
 | 
			
		||||
        #             noise_pred_uncond *= alpha
 | 
			
		||||
        #     noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)            
 | 
			
		||||
        #     del noise_pred_uncond
 | 
			
		||||
 | 
			
		||||
        #     temp_x0 = sample_scheduler.step(
 | 
			
		||||
        #         noise_pred[:, :target_shape[1]].unsqueeze(0),
 | 
			
		||||
        #         t,
 | 
			
		||||
        #         latents[0].unsqueeze(0),
 | 
			
		||||
        #         return_dict=False,
 | 
			
		||||
        #         generator=seed_g)[0]
 | 
			
		||||
        #     latents = [temp_x0.squeeze(0)]
 | 
			
		||||
        #     del temp_x0
 | 
			
		||||
 | 
			
		||||
        #     if callback is not None:
 | 
			
		||||
        #         callback(i, latents[0], False)         
 | 
			
		||||
 | 
			
		||||
        x0 = latents

        if input_frames == None:
            videos = self.vae.decode(x0, VAE_tile_size)
        else:
            videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)

        del latents
        del sample_scheduler

        return videos[0] if self.rank == 0 else None

    def adapt_vace_model(self):
        model = self.model
        modules_dict= { k: m for k, m in model.named_modules()}
        for model_layer, vace_layer in model.vace_layers_mapping.items():
            module = modules_dict[f"vace_blocks.{vace_layer}"]
            target = modules_dict[f"blocks.{model_layer}"]
            setattr(target, "vace", module )
        delattr(model, "vace_blocks")

@ -1,8 +1,12 @@
 | 
			
		||||
import torch
 | 
			
		||||
import numpy as np
 | 
			
		||||
import gradio as gr
 | 
			
		||||
 | 
			
		||||
def test_class_i2v(base_model_type):    
 | 
			
		||||
    return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p",  "fantasy",  "multitalk", "infinitetalk", "i2v_2_2_multitalk" ]
 | 
			
		||||
    return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p",  "fantasy",  "multitalk", "infinitetalk", "i2v_2_2_multitalk", "animate" ]
 | 
			
		||||
 | 
			
		||||
def text_oneframe_overlap(base_model_type):
 | 
			
		||||
    return test_class_i2v(base_model_type) and not (test_multitalk(base_model_type) or base_model_type in ["animate"]) or test_wan_5B(base_model_type)
 | 
			
		||||
 | 
			
		||||
def test_class_1_3B(base_model_type):    
 | 
			
		||||
    return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"]
 | 
			
		||||
@ -13,6 +17,8 @@ def test_multitalk(base_model_type):
 | 
			
		||||
def test_standin(base_model_type):
 | 
			
		||||
    return base_model_type in ["standin", "vace_standin_14B"]
 | 
			
		||||
 | 
			
		||||
def test_wan_5B(base_model_type):
 | 
			
		||||
    return base_model_type in ["ti2v_2_2", "lucy_edit"]
 | 
			
		||||
class family_handler():
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
@ -32,7 +38,7 @@ class family_handler():
 | 
			
		||||
                def_mag_ratios = [1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181]
 | 
			
		||||
            elif base_model_type in ["i2v_2_2"]:
 | 
			
		||||
                def_mag_ratios = [0.99191, 0.99144, 0.99356, 0.99337, 0.99326, 0.99285, 0.99251, 0.99264, 0.99393, 0.99366, 0.9943, 0.9943, 0.99276, 0.99288, 0.99389, 0.99393, 0.99274, 0.99289, 0.99316, 0.9931, 0.99379, 0.99377, 0.99268, 0.99271, 0.99222, 0.99227, 0.99175, 0.9916, 0.91076, 0.91046, 0.98931, 0.98933, 0.99087, 0.99088, 0.98852, 0.98855, 0.98895, 0.98896, 0.98806, 0.98808, 0.9871, 0.98711, 0.98613, 0.98618, 0.98434, 0.98435, 0.983, 0.98307, 0.98185, 0.98187, 0.98131, 0.98131, 0.9783, 0.97835, 0.97619, 0.9762, 0.97264, 0.9727, 0.97088, 0.97098, 0.96568, 0.9658, 0.96045, 0.96055, 0.95322, 0.95335, 0.94579, 0.94594, 0.93297, 0.93311, 0.91699, 0.9172, 0.89174, 0.89202, 0.8541, 0.85446, 0.79823, 0.79902]
 | 
			
		||||
            elif base_model_type in ["ti2v_2_2"]:
 | 
			
		||||
            elif test_wan_5B(base_model_type):
 | 
			
		||||
                if inputs.get("image_start", None) is not None and inputs.get("video_source", None) is not None : # t2v
 | 
			
		||||
                    def_mag_ratios = [0.99505, 0.99389, 0.99441, 0.9957, 0.99558, 0.99551, 0.99499, 0.9945, 0.99534, 0.99548, 0.99468, 0.9946, 0.99463, 0.99458, 0.9946, 0.99453, 0.99408, 0.99404, 0.9945, 0.99441, 0.99409, 0.99398, 0.99403, 0.99397, 0.99382, 0.99377, 0.99349, 0.99343, 0.99377, 0.99378, 0.9933, 0.99328, 0.99303, 0.99301, 0.99217, 0.99216, 0.992, 0.99201, 0.99201, 0.99202, 0.99133, 0.99132, 0.99112, 0.9911, 0.99155, 0.99155, 0.98958, 0.98957, 0.98959, 0.98958, 0.98838, 0.98835, 0.98826, 0.98825, 0.9883, 0.98828, 0.98711, 0.98709, 0.98562, 0.98561, 0.98511, 0.9851, 0.98414, 0.98412, 0.98284, 0.98282, 0.98104, 0.98101, 0.97981, 0.97979, 0.97849, 0.97849, 0.97557, 0.97554, 0.97398, 0.97395, 0.97171, 0.97166, 0.96917, 0.96913, 0.96511, 0.96507, 0.96263, 0.96257, 0.95839, 0.95835, 0.95483, 0.95475, 0.94942, 0.94936, 0.9468, 0.94678, 0.94583, 0.94594, 0.94843, 0.94872, 0.96949, 0.97015]
 | 
			
		||||
                else: # i2v
 | 
			
		||||
@ -79,11 +85,13 @@ class family_handler():
 | 
			
		||||
        vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"] 
 | 
			
		||||
        extra_model_def["vace_class"] = vace_class
 | 
			
		||||
 | 
			
		||||
        if test_multitalk(base_model_type):
 | 
			
		||||
        if base_model_type in ["animate"]:
 | 
			
		||||
            fps = 30
 | 
			
		||||
        elif test_multitalk(base_model_type):
 | 
			
		||||
            fps = 25
 | 
			
		||||
        elif base_model_type in ["fantasy"]:
 | 
			
		||||
            fps = 23
 | 
			
		||||
        elif base_model_type in ["ti2v_2_2"]:
 | 
			
		||||
        elif test_wan_5B(base_model_type):
 | 
			
		||||
            fps = 24
 | 
			
		||||
        else:
 | 
			
		||||
            fps = 16
 | 
			
		||||
@ -96,14 +104,14 @@ class family_handler():
 | 
			
		||||
        extra_model_def.update({
 | 
			
		||||
        "frames_minimum" : frames_minimum,
 | 
			
		||||
        "frames_steps" : frames_steps, 
 | 
			
		||||
        "sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy"] or test_class_i2v(base_model_type) or vace_class,  #"ti2v_2_2",
 | 
			
		||||
        "sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy", "animate"] or test_class_i2v(base_model_type) or test_wan_5B(base_model_type) or vace_class,  #"ti2v_2_2",
 | 
			
		||||
        "multiple_submodels" : multiple_submodels,
 | 
			
		||||
        "guidance_max_phases" : 3,
 | 
			
		||||
        "skip_layer_guidance" : True,        
 | 
			
		||||
        "cfg_zero" : True,
 | 
			
		||||
        "cfg_star" : True,
 | 
			
		||||
        "adaptive_projected_guidance" : True,  
 | 
			
		||||
        "tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels),
 | 
			
		||||
        "tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels),
 | 
			
		||||
        "mag_cache" : True,
 | 
			
		||||
        "keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
 | 
			
		||||
        "sample_solvers":[
 | 
			
		||||
@ -112,9 +120,169 @@ class family_handler():
 | 
			
		||||
                            ("dpm++", "dpm++"),
 | 
			
		||||
                            ("flowmatch causvid", "causvid"), ]
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["t2v"]: 
 | 
			
		||||
            extra_model_def["guide_custom_choices"] = {
 | 
			
		||||
                "choices":[("Use Text Prompt Only", ""),
 | 
			
		||||
                           ("Video to Video guided by Text Prompt", "GUV"),
 | 
			
		||||
                           ("Video to Video guided by Text Prompt and Restricted to the Area of the Video Mask", "GVA")],
 | 
			
		||||
                "default": "",
 | 
			
		||||
                "letters_filter": "GUVA",
 | 
			
		||||
                "label": "Video to Video"
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            extra_model_def["mask_preprocessing"] = {
 | 
			
		||||
                "selection":[ "", "A"],
 | 
			
		||||
                "visible": False
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["infinitetalk"]: 
 | 
			
		||||
            extra_model_def["no_background_removal"] = True
 | 
			
		||||
            extra_model_def["all_image_refs_are_background_ref"] = True
 | 
			
		||||
            extra_model_def["guide_custom_choices"] = {
 | 
			
		||||
            "choices":[
 | 
			
		||||
                ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"),
 | 
			
		||||
                ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"),
 | 
			
		||||
                ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QRUV"),
 | 
			
		||||
                ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "RUV"),
 | 
			
		||||
                ("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "GQUV"),
 | 
			
		||||
                ("Video to Video, amount of motion transferred depends on Denoising Strength - Smooth Transitions", "GUV"),
 | 
			
		||||
            ],
 | 
			
		||||
            "default": "KI",
 | 
			
		||||
            "letters_filter": "RGUVQKI",
 | 
			
		||||
            "label": "Video to Video",
 | 
			
		||||
            "show_label" : False,
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            # extra_model_def["at_least_one_image_ref_needed"] = True
 | 
			
		||||
        if base_model_type in ["lucy_edit"]:
 | 
			
		||||
            extra_model_def["keep_frames_video_guide_not_supported"] = True
 | 
			
		||||
            extra_model_def["guide_preprocessing"] = {
 | 
			
		||||
                    "selection": ["UV"],
 | 
			
		||||
                    "labels" : { "UV": "Control Video"},
 | 
			
		||||
                    "visible": False,
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["animate"]:
 | 
			
		||||
            extra_model_def["guide_custom_choices"] = {
 | 
			
		||||
            "choices":[
 | 
			
		||||
                ("Animate Person in Reference Image using Motion of Whole Control Video", "PVBKI"),
 | 
			
		||||
                ("Animate Person in Reference Image using Motion of Targeted Person in Control Video", "PVBXAKI"),
 | 
			
		||||
                ("Replace Person in Control Video by Person in Reference Image", "PVBAI"),
 | 
			
		||||
                ("Replace Person in Control Video by Person in Reference Image and Apply Relighting Process", "PVBAI1"),
 | 
			
		||||
            ],
 | 
			
		||||
            "default": "PVBKI",
 | 
			
		||||
            "letters_filter": "PVBXAKI1",
 | 
			
		||||
            "label": "Type of Process",
 | 
			
		||||
            "show_label" : False,
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            extra_model_def["mask_preprocessing"] = {
 | 
			
		||||
                "selection":[ "", "A", "XA"],
 | 
			
		||||
                "visible": False
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            extra_model_def["video_guide_outpainting"] = [0,1]
 | 
			
		||||
            extra_model_def["keep_frames_video_guide_not_supported"] = True
 | 
			
		||||
            extra_model_def["extract_guide_from_window_start"] = True
 | 
			
		||||
            extra_model_def["forced_guide_mask_inputs"] = True
 | 
			
		||||
            extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)"
 | 
			
		||||
            extra_model_def["background_ref_outpainted"] = False
 | 
			
		||||
            extra_model_def["return_image_refs_tensor"] = True
 | 
			
		||||
            extra_model_def["guide_inpaint_color"] = 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        if vace_class:
 | 
			
		||||
            extra_model_def["guide_preprocessing"] = {
 | 
			
		||||
                    "selection": ["", "UV", "PV", "DV", "SV", "LV", "CV", "MV", "V", "PDV", "PSV", "PLV" , "DSV", "DLV", "SLV"],
 | 
			
		||||
                    "labels" : { "V": "Use Vace raw format"}
 | 
			
		||||
                }
 | 
			
		||||
            extra_model_def["mask_preprocessing"] = {
 | 
			
		||||
                    "selection": ["", "A", "NA", "XA", "XNA", "YA", "YNA", "WA", "WNA", "ZA", "ZNA"],
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
            extra_model_def["image_ref_choices"] = {
 | 
			
		||||
                    "choices": [("None", ""),
 | 
			
		||||
                    ("People / Objects", "I"),
 | 
			
		||||
                    ("Landscape followed by People / Objects (if any)", "KI"),
 | 
			
		||||
                    ("Positioned Frames followed by People / Objects (if any)", "FI"),
 | 
			
		||||
                    ],
 | 
			
		||||
                    "letters_filter":  "KFI",
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames"
 | 
			
		||||
            extra_model_def["video_guide_outpainting"] = [0,1]
 | 
			
		||||
            extra_model_def["pad_guide_video"] = True
 | 
			
		||||
            extra_model_def["guide_inpaint_color"] = 127.5
 | 
			
		||||
            extra_model_def["forced_guide_mask_inputs"] = True
 | 
			
		||||
            extra_model_def["return_image_refs_tensor"] = True
 | 
			
		||||
            
 | 
			
		||||
        if base_model_type in ["standin"]: 
 | 
			
		||||
            extra_model_def["fit_into_canvas_image_refs"] = 0
 | 
			
		||||
            extra_model_def["image_ref_choices"] = {
 | 
			
		||||
                "choices": [
 | 
			
		||||
                    ("No Reference Image", ""),
 | 
			
		||||
                    ("Reference Image is a Person Face", "I"),
 | 
			
		||||
                    ],
 | 
			
		||||
                "letters_filter":"I",
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["phantom_1.3B", "phantom_14B"]: 
 | 
			
		||||
            extra_model_def["image_ref_choices"] = {
 | 
			
		||||
                "choices": [("Reference Image", "I")],
 | 
			
		||||
                "letters_filter":"I",
 | 
			
		||||
                "visible": False,
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["recam_1.3B"]: 
 | 
			
		||||
            extra_model_def["keep_frames_video_guide_not_supported"] = True
 | 
			
		||||
            extra_model_def["model_modes"] = {
 | 
			
		||||
                        "choices": [
 | 
			
		||||
                            ("Pan Right", 1),
 | 
			
		||||
                            ("Pan Left", 2),
 | 
			
		||||
                            ("Tilt Up", 3),
 | 
			
		||||
                            ("Tilt Down", 4),
 | 
			
		||||
                            ("Zoom In", 5),
 | 
			
		||||
                            ("Zoom Out", 6),
 | 
			
		||||
                            ("Translate Up (with rotation)", 7),
 | 
			
		||||
                            ("Translate Down (with rotation)", 8),
 | 
			
		||||
                            ("Arc Left (with rotation)", 9),
 | 
			
		||||
                            ("Arc Right (with rotation)", 10),
 | 
			
		||||
                        ],
 | 
			
		||||
                        "default": 1,
 | 
			
		||||
                        "label" : "Camera Movement Type"
 | 
			
		||||
            }
 | 
			
		||||
            extra_model_def["guide_preprocessing"] = {
 | 
			
		||||
                    "selection": ["UV"],
 | 
			
		||||
                    "labels" : { "UV": "Control Video"},
 | 
			
		||||
                    "visible" : False,
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
        if vace_class or base_model_type in ["animate"]:
 | 
			
		||||
            image_prompt_types_allowed = "TVL"
 | 
			
		||||
        elif base_model_type in ["infinitetalk"]:
 | 
			
		||||
            image_prompt_types_allowed = "TSVL"
 | 
			
		||||
        elif base_model_type in ["ti2v_2_2"]:
 | 
			
		||||
            image_prompt_types_allowed = "TSVL"
 | 
			
		||||
        elif base_model_type in ["lucy_edit"]:
 | 
			
		||||
            image_prompt_types_allowed = "TVL"
 | 
			
		||||
        elif test_multitalk(base_model_type) or base_model_type in ["fantasy"]:
 | 
			
		||||
            image_prompt_types_allowed = "SVL"
 | 
			
		||||
        elif i2v:
 | 
			
		||||
            image_prompt_types_allowed = "SEVL"
 | 
			
		||||
        else:
 | 
			
		||||
            image_prompt_types_allowed = ""
 | 
			
		||||
        extra_model_def["image_prompt_types_allowed"] = image_prompt_types_allowed
 | 
			
		||||
 | 
			
		||||
        if text_oneframe_overlap(base_model_type):
 | 
			
		||||
            extra_model_def["sliding_window_defaults"] = { "overlap_min" : 1, "overlap_max" : 1, "overlap_step": 0, "overlap_default": 1}
 | 
			
		||||
 | 
			
		||||
        # if base_model_type in ["phantom_1.3B", "phantom_14B"]: 
 | 
			
		||||
        #     extra_model_def["one_image_ref_needed"] = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        return extra_model_def
 | 
			
		||||
        
 | 
			
		||||
@ -122,8 +290,8 @@ class family_handler():
 | 
			
		||||
    def query_supported_types():
 | 
			
		||||
        return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B",
 | 
			
		||||
                    "t2v_1.3B", "standin", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B", 
 | 
			
		||||
                    "recam_1.3B", 
 | 
			
		||||
                    "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]
 | 
			
		||||
                    "recam_1.3B", "animate",
 | 
			
		||||
                    "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "lucy_edit", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
@ -153,11 +321,12 @@ class family_handler():
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_vae_block_size(base_model_type):
 | 
			
		||||
        return 32 if base_model_type == "ti2v_2_2" else 16
 | 
			
		||||
        return 32 if test_wan_5B(base_model_type) else 16
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_rgb_factors(base_model_type ):
 | 
			
		||||
        from shared.RGB_factors import get_rgb_factors
 | 
			
		||||
        if test_wan_5B(base_model_type): base_model_type = "ti2v_2_2"
 | 
			
		||||
        latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type)
 | 
			
		||||
        return latent_rgb_factors, latent_rgb_factors_bias
 | 
			
		||||
    
 | 
			
		||||
@ -171,7 +340,7 @@ class family_handler():
 | 
			
		||||
            "fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors",  "fantasy_proj_model.safetensors" ] +  computeList(model_filename)  ]   
 | 
			
		||||
        }]
 | 
			
		||||
 | 
			
		||||
        if base_model_type == "ti2v_2_2":
 | 
			
		||||
        if test_wan_5B(base_model_type):
 | 
			
		||||
            download_def += [    {
 | 
			
		||||
                "repoId" : "DeepBeepMeep/Wan2.2", 
 | 
			
		||||
                "sourceFolderList" :  [""],
 | 
			
		||||
@ -182,7 +351,7 @@ class family_handler():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False):
 | 
			
		||||
    def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False, submodel_no_list = None):
 | 
			
		||||
        from .configs import WAN_CONFIGS
 | 
			
		||||
 | 
			
		||||
        if test_class_i2v(base_model_type):
 | 
			
		||||
@ -195,6 +364,7 @@ class family_handler():
 | 
			
		||||
            config=cfg,
 | 
			
		||||
            checkpoint_dir="ckpts",
 | 
			
		||||
            model_filename=model_filename,
 | 
			
		||||
            submodel_no_list = submodel_no_list,
 | 
			
		||||
            model_type = model_type,        
 | 
			
		||||
            model_def = model_def,
 | 
			
		||||
            base_model_type=base_model_type,
 | 
			
		||||
@ -235,15 +405,42 @@ class family_handler():
 | 
			
		||||
                if "I" in video_prompt_type:
 | 
			
		||||
                    video_prompt_type = video_prompt_type.replace("KI", "QKI")
 | 
			
		||||
                    ui_defaults["video_prompt_type"] = video_prompt_type 
 | 
			
		||||
 | 
			
		||||
        if settings_version < 2.28:
 | 
			
		||||
            if base_model_type in "infinitetalk":
 | 
			
		||||
                video_prompt_type = ui_defaults.get("video_prompt_type", "")
 | 
			
		||||
                if "U" in video_prompt_type:
 | 
			
		||||
                    video_prompt_type = video_prompt_type.replace("U", "RU")
 | 
			
		||||
                    ui_defaults["video_prompt_type"] = video_prompt_type 
 | 
			
		||||
 | 
			
		||||
        if settings_version < 2.31:
 | 
			
		||||
            if base_model_type in "recam_1.3B":
 | 
			
		||||
                video_prompt_type = ui_defaults.get("video_prompt_type", "")
 | 
			
		||||
                if not "V" in video_prompt_type:
 | 
			
		||||
                    video_prompt_type += "UV"
 | 
			
		||||
                    ui_defaults["video_prompt_type"] = video_prompt_type 
 | 
			
		||||
                    ui_defaults["image_prompt_type"] = ""
 | 
			
		||||
 | 
			
		||||
            if text_oneframe_overlap(base_model_type):
 | 
			
		||||
                ui_defaults["sliding_window_overlap"] = 1
 | 
			
		||||
 | 
			
		||||
        if settings_version < 2.32:
 | 
			
		||||
            image_prompt_type = ui_defaults.get("image_prompt_type", "")
 | 
			
		||||
            if test_class_i2v(base_model_type) and len(image_prompt_type) == 0:
 | 
			
		||||
                ui_defaults["image_prompt_type"] = "S" 
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def update_default_settings(base_model_type, model_def, ui_defaults):
 | 
			
		||||
        ui_defaults.update({
 | 
			
		||||
            "sample_solver": "unipc",
 | 
			
		||||
        })
 | 
			
		||||
        if test_class_i2v(base_model_type) and "S" in model_def["image_prompt_types_allowed"]:
 | 
			
		||||
            ui_defaults["image_prompt_type"] = "S" 
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["fantasy"]:
 | 
			
		||||
            ui_defaults.update({
 | 
			
		||||
                "audio_guidance_scale": 5.0,
 | 
			
		||||
                "sliding_window_size": 1, 
 | 
			
		||||
                "sliding_window_overlap" : 1,
 | 
			
		||||
            })
 | 
			
		||||
 | 
			
		||||
        elif base_model_type in ["multitalk"]:
 | 
			
		||||
@ -260,6 +457,7 @@ class family_handler():
 | 
			
		||||
                "guidance_scale": 5.0,
 | 
			
		||||
                "flow_shift": 7, # 11 for 720p
 | 
			
		||||
                "sliding_window_overlap" : 9,
 | 
			
		||||
                "sliding_window_size": 81, 
 | 
			
		||||
                "sample_solver" : "euler",
 | 
			
		||||
                "video_prompt_type": "QKI",
 | 
			
		||||
                "remove_background_images_ref" : 0,
 | 
			
		||||
@ -293,6 +491,21 @@ class family_handler():
 | 
			
		||||
                "image_prompt_type": "T", 
 | 
			
		||||
            })
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["recam_1.3B", "lucy_edit"]: 
 | 
			
		||||
            ui_defaults.update({
 | 
			
		||||
                "video_prompt_type": "UV", 
 | 
			
		||||
            })
 | 
			
		||||
        elif base_model_type in ["animate"]: 
 | 
			
		||||
            ui_defaults.update({ 
 | 
			
		||||
                "video_prompt_type": "PVBKI", 
 | 
			
		||||
                "mask_expand": 20,
 | 
			
		||||
                "audio_prompt_type": "R",
 | 
			
		||||
            })
 | 
			
		||||
 | 
			
		||||
        if text_oneframe_overlap(base_model_type):
 | 
			
		||||
            ui_defaults["sliding_window_overlap"] = 1
 | 
			
		||||
            ui_defaults["color_correction_strength"]= 0
 | 
			
		||||
 | 
			
		||||
        if test_multitalk(base_model_type):
 | 
			
		||||
            ui_defaults["audio_guidance_scale"] = 4
 | 
			
		||||
 | 
			
		||||
@ -309,3 +522,11 @@ class family_handler():
 | 
			
		||||
            if ("V" in image_prompt_type or "L" in image_prompt_type) and image_refs is None:
 | 
			
		||||
                video_prompt_type = video_prompt_type.replace("I", "").replace("K","")
 | 
			
		||||
                inputs["video_prompt_type"] = video_prompt_type 
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["vace_standin_14B"]:
 | 
			
		||||
            image_refs = inputs["image_refs"]
 | 
			
		||||
            video_prompt_type = inputs["video_prompt_type"]
 | 
			
		||||
            if image_refs is not None and  len(image_refs) == 1 and "K" in video_prompt_type:
 | 
			
		||||
                gr.Info("Warning, Ref Image for Standin Missing: if 'Landscape and then People or Objects' is selected beside the Landscape Image Ref there should be another Image Ref that contains a Face.")
 | 
			
		||||
                    
 | 
			
		||||
 | 
			
		||||
							
								
								
									
69	preprocessing/extract_vocals.py	Normal file
@ -0,0 +1,69 @@
from pathlib import Path
import os, tempfile
import numpy as np
import soundfile as sf
import librosa
import torch
import gc

from audio_separator.separator import Separator

def get_vocals(src_path: str, dst_path: str, min_seconds: float = 8) -> str:
    """
    If the source audio is shorter than `min_seconds`, pad with trailing silence
    in a temporary file, then run separation and save only the vocals to dst_path.
    Returns the full path to the vocals file.
    """

    default_device = torch.get_default_device()
    torch.set_default_device('cpu')

    dst = Path(dst_path)
    dst.parent.mkdir(parents=True, exist_ok=True)

    # Quick duration check
    duration = librosa.get_duration(path=src_path)

    use_path = src_path
    temp_path = None
    try:
        if duration < min_seconds:
            # Load (resample) and pad in memory
            y, sr = librosa.load(src_path, sr=None, mono=False)
            if y.ndim == 1:  # ensure shape (channels, samples)
                y = y[np.newaxis, :]
            target_len = int(min_seconds * sr)
            pad = max(0, target_len - y.shape[1])
            if pad:
                y = np.pad(y, ((0, 0), (0, pad)), mode="constant")

            # Write a temp WAV for the separator
            fd, temp_path = tempfile.mkstemp(suffix=".wav")
            os.close(fd)
            sf.write(temp_path, y.T, sr)  # soundfile expects (frames, channels)
            use_path = temp_path

        # Run separation: emit only the vocals, with your exact filename
        sep = Separator(
            output_dir=str(dst.parent),
            output_format=(dst.suffix.lstrip(".") or "wav"),
            output_single_stem="Vocals",
            model_file_dir="ckpts/roformer/" #model_bs_roformer_ep_317_sdr_12.9755.ckpt"
        )
        sep.load_model()
        out_files = sep.separate(use_path, {"Vocals": dst.stem})

        out = Path(out_files[0])
        return str(out if out.is_absolute() else (dst.parent / out))
    finally:
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)

        torch.cuda.empty_cache()
        gc.collect()
        torch.set_default_device(default_device)

# Example:
# final = get_vocals("in/clip.mp3", "out/vocals.wav")
# print(final)
 | 
			
		||||
 | 
			
		||||
@ -7,7 +7,6 @@ import psutil
 | 
			
		||||
# import ffmpeg
 | 
			
		||||
import imageio
 | 
			
		||||
from PIL import Image
 | 
			
		||||
 | 
			
		||||
import cv2
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
@ -22,6 +21,7 @@ from .utils.get_default_model import get_matanyone_model
 | 
			
		||||
from .matanyone.inference.inference_core import InferenceCore
 | 
			
		||||
from .matanyone_wrapper import matanyone
 | 
			
		||||
from shared.utils.audio_video import save_video, save_image
 | 
			
		||||
from mmgp import offload
 | 
			
		||||
 | 
			
		||||
arg_device = "cuda"
 | 
			
		||||
arg_sam_model_type="vit_h"
 | 
			
		||||
@ -33,6 +33,8 @@ model_in_GPU = False
 | 
			
		||||
matanyone_in_GPU = False
 | 
			
		||||
bfloat16_supported = False
 | 
			
		||||
# SAM generator
 | 
			
		||||
import copy
 | 
			
		||||
 | 
			
		||||
class MaskGenerator():
 | 
			
		||||
    def __init__(self, sam_checkpoint, device):
 | 
			
		||||
        global args_device
 | 
			
		||||
@ -89,6 +91,7 @@ def get_frames_from_image(image_input, image_state):
 | 
			
		||||
        "last_frame_numer": 0,
 | 
			
		||||
        "fps": None
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
    image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size)
 | 
			
		||||
    set_image_encoder_patch()
 | 
			
		||||
    select_SAM()
 | 
			
		||||
@ -537,7 +540,7 @@ def video_matting(video_state,video_input, end_slider, matting_type, interactive
 | 
			
		||||
    file_name = ".".join(file_name.split(".")[:-1]) 
 | 
			
		||||
 
 | 
			
		||||
    from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files    
 | 
			
		||||
    source_audio_tracks, audio_metadata  = extract_audio_tracks(video_input)
 | 
			
		||||
    source_audio_tracks, audio_metadata  = extract_audio_tracks(video_input, verbose= offload.default_verboseLevel )
 | 
			
		||||
    output_fg_path =  f"./mask_outputs/{file_name}_fg.mp4"
 | 
			
		||||
    output_fg_temp_path =  f"./mask_outputs/{file_name}_fg_tmp.mp4"
 | 
			
		||||
    if len(source_audio_tracks) == 0:
 | 
			
		||||
@ -677,7 +680,6 @@ def load_unload_models(selected):
 | 
			
		||||
            }
 | 
			
		||||
            # os.path.join('.')
 | 
			
		||||
 | 
			
		||||
            from mmgp import offload
 | 
			
		||||
 | 
			
		||||
            # sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[arg_sam_model_type], ".")
 | 
			
		||||
            sam_checkpoint = None
 | 
			
		||||
@ -695,7 +697,8 @@ def load_unload_models(selected):
 | 
			
		||||
                model.samcontroler.sam_controler.model.to("cpu").to(torch.bfloat16).to(arg_device)
 | 
			
		||||
                model_in_GPU = True
 | 
			
		||||
                from .matanyone.model.matanyone import MatAnyone
 | 
			
		||||
                matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone")
 | 
			
		||||
                # matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone")
 | 
			
		||||
                matanyone_model = MatAnyone.from_pretrained("ckpts/mask")
 | 
			
		||||
                # pipe ={"mat" : matanyone_model, "sam" :model.samcontroler.sam_controler.model }
 | 
			
		||||
                # offload.profile(pipe)
 | 
			
		||||
                matanyone_model = matanyone_model.to("cpu").eval()
 | 
			
		||||
@ -717,27 +720,33 @@ def load_unload_models(selected):
 | 
			
		||||
def get_vmc_event_handler():
 | 
			
		||||
    return load_unload_models
 | 
			
		||||
 | 
			
		||||
def export_to_vace_video_input(foreground_video_output):
 | 
			
		||||
    gr.Info("Masked Video Input transferred to Vace For Inpainting")
 | 
			
		||||
    return "V#" + str(time.time()), foreground_video_output
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def export_image(image_refs, image_output):
 | 
			
		||||
    gr.Info("Masked Image transferred to Current Video")
 | 
			
		||||
def export_image(state, image_output):
 | 
			
		||||
    ui_settings = get_current_model_settings(state)
 | 
			
		||||
    image_refs = ui_settings["image_refs"]
 | 
			
		||||
    if image_refs == None:
 | 
			
		||||
        image_refs =[]
 | 
			
		||||
    image_refs.append( image_output)
 | 
			
		||||
    return image_refs
 | 
			
		||||
    ui_settings["image_refs"] = image_refs 
 | 
			
		||||
    gr.Info("Masked Image transferred to Current Image Generator")
 | 
			
		||||
    return time.time()
 | 
			
		||||
 | 
			
		||||
def export_image_mask(image_input, image_mask):
 | 
			
		||||
    gr.Info("Input Image & Mask transferred to Current Video")
 | 
			
		||||
    return Image.fromarray(image_input), image_mask
 | 
			
		||||
def export_image_mask(state, image_input, image_mask):
 | 
			
		||||
    ui_settings = get_current_model_settings(state)
 | 
			
		||||
    ui_settings["image_guide"] = Image.fromarray(image_input)
 | 
			
		||||
    ui_settings["image_mask"] = image_mask
 | 
			
		||||
 | 
			
		||||
    gr.Info("Input Image & Mask transferred to Current Image Generator")
 | 
			
		||||
    return time.time()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def export_to_current_video_engine( foreground_video_output, alpha_video_output):
 | 
			
		||||
def export_to_current_video_engine(state, foreground_video_output, alpha_video_output):
 | 
			
		||||
    ui_settings = get_current_model_settings(state)
 | 
			
		||||
    ui_settings["video_guide"] = foreground_video_output
 | 
			
		||||
    ui_settings["video_mask"] = alpha_video_output
 | 
			
		||||
 | 
			
		||||
    gr.Info("Original Video and Full Mask have been transferred")
 | 
			
		||||
    # return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
 | 
			
		||||
    return foreground_video_output, alpha_video_output
 | 
			
		||||
    return time.time()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def teleport_to_video_tab(tab_state):
 | 
			
		||||
@ -746,15 +755,29 @@ def teleport_to_video_tab(tab_state):
 | 
			
		||||
    return gr.Tabs(selected="video_gen")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def display(tabs, tab_state, server_config,  vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs):
 | 
			
		||||
def display(tabs, tab_state, state, refresh_form_trigger, server_config, get_current_model_settings_fn): #,  vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs):
 | 
			
		||||
    # my_tab.select(fn=load_unload_models, inputs=[], outputs=[])
 | 
			
		||||
    global image_output_codec, video_output_codec
 | 
			
		||||
    global image_output_codec, video_output_codec, get_current_model_settings
 | 
			
		||||
    get_current_model_settings = get_current_model_settings_fn
 | 
			
		||||
 | 
			
		||||
    image_output_codec = server_config.get("image_output_codec", None)
 | 
			
		||||
    video_output_codec = server_config.get("video_output_codec", None)
 | 
			
		||||
 | 
			
		||||
    media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/"
 | 
			
		||||
 | 
			
		||||
    click_brush_js = """
 | 
			
		||||
    () => {
 | 
			
		||||
        setTimeout(() => {
 | 
			
		||||
            const brushButton = document.querySelector('button[aria-label="Brush"]');
 | 
			
		||||
            if (brushButton) {
 | 
			
		||||
                brushButton.click();
 | 
			
		||||
                console.log('Brush button clicked');
 | 
			
		||||
            } else {
 | 
			
		||||
                console.log('Brush button not found');
 | 
			
		||||
            }
 | 
			
		||||
        }, 1000);
 | 
			
		||||
    }    """
 | 
			
		||||
 | 
			
		||||
    # download assets
 | 
			
		||||
 | 
			
		||||
    gr.Markdown("<B>Mast Edition is provided by MatAnyone and VRAM optimized by DeepBeepMeep</B>")
 | 
			
		||||
@ -871,7 +894,7 @@ def display(tabs, tab_state, server_config,  vace_video_input, vace_image_input,
 | 
			
		||||
                            template_frame = gr.Image(label="Start Frame", type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
 | 
			
		||||
                            with gr.Row():
 | 
			
		||||
                                clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False,  min_width=100)
 | 
			
		||||
                                add_mask_button = gr.Button(value="Set Mask", interactive=True, visible=False, min_width=100)
 | 
			
		||||
                                add_mask_button = gr.Button(value="Add Mask", interactive=True, visible=False, min_width=100)
 | 
			
		||||
                                remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False,  min_width=100) # no use
 | 
			
		||||
                                matting_button = gr.Button(value="Generate Video Matting", interactive=True, visible=False,  min_width=100)
 | 
			
		||||
                            with gr.Row():
 | 
			
		||||
@ -892,7 +915,7 @@ def display(tabs, tab_state, server_config,  vace_video_input, vace_image_input,
 | 
			
		||||
                            with gr.Row(visible= True):
 | 
			
		||||
                                export_to_current_video_engine_btn = gr.Button("Export to Control Video Input and Video Mask Input", visible= False)
 | 
			
		||||
                                    
 | 
			
		||||
                export_to_current_video_engine_btn.click(  fn=export_to_current_video_engine, inputs= [foreground_video_output, alpha_video_output], outputs= [vace_video_input, vace_video_mask]).then( #video_prompt_video_guide_trigger, 
 | 
			
		||||
                export_to_current_video_engine_btn.click(  fn=export_to_current_video_engine, inputs= [state, foreground_video_output, alpha_video_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger, 
 | 
			
		||||
                    fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1089,10 +1112,10 @@ def display(tabs, tab_state, server_config,  vace_video_input, vace_image_input,
 | 
			
		||||
                    # with gr.Column(scale=2, visible= True):
 | 
			
		||||
                        export_image_mask_btn = gr.Button(value="Set to Control Image & Mask", visible=False, elem_classes="new_button")
 | 
			
		||||
 | 
			
		||||
                export_image_btn.click(  fn=export_image, inputs= [vace_image_refs, foreground_image_output], outputs= [vace_image_refs]).then( #video_prompt_video_guide_trigger, 
 | 
			
		||||
                    fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])
 | 
			
		||||
                export_image_mask_btn.click(  fn=export_image_mask, inputs= [image_input, alpha_image_output], outputs= [vace_image_input, vace_image_mask]).then( #video_prompt_video_guide_trigger, 
 | 
			
		||||
                export_image_btn.click(  fn=export_image, inputs= [state, foreground_image_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger, 
 | 
			
		||||
                    fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])
 | 
			
		||||
                export_image_mask_btn.click(  fn=export_image_mask, inputs= [state, image_input, alpha_image_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger, 
 | 
			
		||||
                    fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs]).then(fn=None, inputs=None, outputs=None, js=click_brush_js)
 | 
			
		||||
 | 
			
		||||
                # first step: get the image information 
 | 
			
		||||
                extract_frames_button.click(
 | 
			
		||||
@ -1148,5 +1171,21 @@ def display(tabs, tab_state, server_config,  vace_video_input, vace_image_input,
 | 
			
		||||
                    outputs=[foreground_image_output, alpha_image_output,foreground_image_output, alpha_image_output,bbox_info, export_image_btn, export_image_mask_btn]
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
                nada = gr.State({})
 | 
			
		||||
                # clear input
 | 
			
		||||
                gr.on(
 | 
			
		||||
                    triggers=[image_input.clear], #image_input.change,
 | 
			
		||||
                    fn=restart,
 | 
			
		||||
                    inputs=[],
 | 
			
		||||
                    outputs=[ 
 | 
			
		||||
                        image_state,
 | 
			
		||||
                        interactive_state,
 | 
			
		||||
                        click_state,
 | 
			
		||||
                        foreground_image_output, alpha_image_output,
 | 
			
		||||
                        template_frame,
 | 
			
		||||
                        image_selection_slider, image_selection_slider, track_pause_number_slider,point_prompt, export_image_btn, export_image_mask_btn, bbox_info, clear_button_click, 
 | 
			
		||||
                        add_mask_button, matting_button, template_frame, foreground_image_output, alpha_image_output, remove_mask_button, export_image_btn, export_image_mask_btn, mask_dropdown, nada, step2_title
 | 
			
		||||
                    ],
 | 
			
		||||
                    queue=False,
 | 
			
		||||
                    show_progress=False)
 | 
			
		||||
                
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,6 @@ import math
 | 
			
		||||
import torch
 | 
			
		||||
from typing import Optional, Union, Tuple
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# @torch.jit.script
 | 
			
		||||
def get_similarity(mk: torch.Tensor,
 | 
			
		||||
                   ms: torch.Tensor,
 | 
			
		||||
@ -59,6 +58,7 @@ def get_similarity(mk: torch.Tensor,
 | 
			
		||||
        del two_ab 
 | 
			
		||||
        # similarity = (-a_sq + two_ab)
 | 
			
		||||
 | 
			
		||||
    similarity =similarity.float()
 | 
			
		||||
    if ms is not None:
 | 
			
		||||
        similarity *= ms
 | 
			
		||||
        similarity /=  math.sqrt(CK)
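
For reference, the affinity computed in this function is the usual XMem-style scaled negative squared L2 distance between memory keys and query keys: the query-norm term is constant per query under the later softmax, so only the -||mk||^2 + 2*mk^T*qk part is kept, optionally weighted by the shrinkage term ms and divided by sqrt(CK). A rough sketch of the same computation (the tensor shapes are assumptions inferred from the surrounding code, not taken from this diff):

    # mk: (B, CK, N_mem) memory keys, qk: (B, CK, N_query) query keys, ms: (B, N_mem, 1) or None
    a_sq = mk.pow(2).sum(1).unsqueeze(2)           # ||mk||^2 for each memory element
    two_ab = 2 * (mk.transpose(1, 2) @ qk)         # 2 * mk^T qk
    similarity = (-a_sq + two_ab) / math.sqrt(CK)  # larger value = closer memory key
    if ms is not None:
        similarity = similarity * ms               # memory "shrinkage" weighting
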
 | 
			
		||||
 | 
			
		||||
@ -73,5 +73,5 @@ def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10):
 | 
			
		||||
        if ti > (n_warmup-1):
 | 
			
		||||
            frames.append((com_np*255).astype(np.uint8))
 | 
			
		||||
            phas.append((pha*255).astype(np.uint8))
 | 
			
		||||
    
 | 
			
		||||
            # phas.append(np.clip(pha * 255, 0, 255).astype(np.uint8))    
 | 
			
		||||
    return frames, phas
 | 
			
		||||
@ -100,7 +100,7 @@ class OptimizedPyannote31SpeakerSeparator:
 | 
			
		||||
        self.hf_token = hf_token
 | 
			
		||||
        self._overlap_pipeline = None
 | 
			
		||||
 | 
			
		||||
    def separate_audio(self, audio_path: str, output1, output2 ) -> Dict[str, str]:
 | 
			
		||||
    def separate_audio(self, audio_path: str, output1, output2, audio_original_path: str = None  ) -> Dict[str, str]:
 | 
			
		||||
        """Optimized main separation function with memory management."""
 | 
			
		||||
        xprint("Starting optimized audio separation...")
 | 
			
		||||
        self._current_audio_path = os.path.abspath(audio_path)        
 | 
			
		||||
@ -128,7 +128,11 @@ class OptimizedPyannote31SpeakerSeparator:
 | 
			
		||||
        gc.collect()
 | 
			
		||||
        
 | 
			
		||||
        # Save outputs efficiently
 | 
			
		||||
        output_paths = self._save_outputs_optimized(waveform, final_masks, sample_rate, audio_path, output1, output2)
 | 
			
		||||
        if audio_original_path is None:
 | 
			
		||||
            waveform_original = waveform
 | 
			
		||||
        else:
 | 
			
		||||
            waveform_original, sample_rate = self.load_audio(audio_original_path)
 | 
			
		||||
        output_paths = self._save_outputs_optimized(waveform_original, final_masks, sample_rate, audio_path, output1, output2)
 | 
			
		||||
        
 | 
			
		||||
        return output_paths
 | 
			
		||||
 | 
			
		||||
@ -835,7 +839,7 @@ class OptimizedPyannote31SpeakerSeparator:
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            xprint(f"{speaker}: {turn.start:.1f}s - {turn.end:.1f}s")

def extract_dual_audio(audio, output1, output2, verbose = False):
def extract_dual_audio(audio, output1, output2, verbose = False, audio_original = None):
    global verbose_output
    verbose_output = verbose
    separator = OptimizedPyannote31SpeakerSeparator(
@ -848,7 +852,7 @@ def extract_dual_audio(audio, output1, output2, verbose = False):
    import time
    start_time = time.time()

    outputs = separator.separate_audio(audio, output1, output2)
    outputs = separator.separate_audio(audio, output1, output2, audio_original)

    elapsed_time = time.time() - start_time
    xprint(f"\n=== SUCCESS (completed in {elapsed_time:.2f}s) ===")
 | 
			
		||||
 | 
			
		||||
@ -20,14 +20,16 @@ soundfile
mutagen
pyloudnorm
librosa==0.11.0
speechbrain==1.0.3
audio-separator==0.36.1

# UI & interaction
gradio==5.23.0
gradio==5.29.0
dashscope
loguru

# Vision & segmentation
opencv-python>=4.9.0.80
opencv-python>=4.12.0.88
segment-anything
rembg[gpu]==2.0.65
onnxruntime-gpu
@ -43,14 +45,14 @@ pydantic==2.10.6
# Math & modeling
torchdiffeq>=0.2.5
tensordict>=0.6.1
mmgp==3.5.10
peft==0.17.0
mmgp==3.6.0
peft==0.15.0
matplotlib

# Utilities
ftfy
piexif
pynvml
nvidia-ml-py 
misaki

# Optional / commented out

@ -1,6 +1,6 @@
# thanks Comfyui for the rgb factors (https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py)
def get_rgb_factors(model_family, model_type = None): 
    if model_family == "wan":
    if model_family in ["wan", "qwen"]:
        if model_type =="ti2v_2_2":
            latent_channels = 48
            latent_dimensions = 3            
@ -261,7 +261,7 @@ def get_rgb_factors(model_family, model_type = None):
            [ 0.0249, -0.0469, -0.1703]
        ]

        latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]        
        latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
    else:
        latent_rgb_factors_bias = latent_rgb_factors = None
    return latent_rgb_factors, latent_rgb_factors_bias
@ -3,6 +3,7 @@ import torch
 | 
			
		||||
from importlib.metadata import version
 | 
			
		||||
from mmgp import offload
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import warnings
 | 
			
		||||
 | 
			
		||||
major, minor = torch.cuda.get_device_capability(None)
 | 
			
		||||
bfloat16_supported =  major >= 8 
 | 
			
		||||
@ -42,34 +43,51 @@ except ImportError:
 | 
			
		||||
    sageattn_varlen_wrapper = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import warnings
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from sageattention import sageattn
 | 
			
		||||
    from .sage2_core import sageattn as alt_sageattn, is_sage2_supported
 | 
			
		||||
    from .sage2_core import sageattn as sageattn2, is_sage2_supported
 | 
			
		||||
    sage2_supported =  is_sage2_supported()
 | 
			
		||||
except ImportError:
 | 
			
		||||
    sageattn = None
 | 
			
		||||
    alt_sageattn = None
 | 
			
		||||
    sageattn2 = None
 | 
			
		||||
    sage2_supported = False
 | 
			
		||||
# @torch.compiler.disable()
 | 
			
		||||
def sageattn_wrapper(
 | 
			
		||||
@torch.compiler.disable()
 | 
			
		||||
def sageattn2_wrapper(
 | 
			
		||||
        qkv_list,
 | 
			
		||||
        attention_length
 | 
			
		||||
    ):
 | 
			
		||||
    q,k, v = qkv_list
 | 
			
		||||
    if True:
 | 
			
		||||
        qkv_list = [q,k,v]
 | 
			
		||||
        del q, k ,v
 | 
			
		||||
        o = alt_sageattn(qkv_list, tensor_layout="NHD")
 | 
			
		||||
    else:
 | 
			
		||||
        o = sageattn(q, k, v, tensor_layout="NHD")
 | 
			
		||||
        del q, k ,v
 | 
			
		||||
 | 
			
		||||
    qkv_list = [q,k,v]
 | 
			
		||||
    del q, k ,v
 | 
			
		||||
    o = sageattn2(qkv_list, tensor_layout="NHD")
 | 
			
		||||
    qkv_list.clear()
 | 
			
		||||
 | 
			
		||||
    return o
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from sageattn import sageattn_blackwell as sageattn3
 | 
			
		||||
except ImportError:
 | 
			
		||||
    sageattn3 = None
 | 
			
		||||
 | 
			
		||||
@torch.compiler.disable()
 | 
			
		||||
def sageattn3_wrapper(
 | 
			
		||||
        qkv_list,
 | 
			
		||||
        attention_length
 | 
			
		||||
    ):
 | 
			
		||||
    q,k, v = qkv_list
 | 
			
		||||
    # qkv_list = [q,k,v]
 | 
			
		||||
    # del q, k ,v
 | 
			
		||||
    # o = sageattn3(qkv_list, tensor_layout="NHD")
 | 
			
		||||
    q = q.transpose(1,2)
 | 
			
		||||
    k = k.transpose(1,2)
 | 
			
		||||
    v = v.transpose(1,2)
 | 
			
		||||
    o = sageattn3(q, k, v)
 | 
			
		||||
    o = o.transpose(1,2)
 | 
			
		||||
    qkv_list.clear()
 | 
			
		||||
 | 
			
		||||
    return o
 | 
			
		||||
 | 
			
		||||
     
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# try:
 | 
			
		||||
# if True:
 | 
			
		||||
    # from .sage2_core import sageattn_qk_int8_pv_fp8_window_cuda
 | 
			
		||||
@ -94,7 +112,7 @@ def sageattn_wrapper(
 | 
			
		||||
 | 
			
		||||
    #     return o
 | 
			
		||||
# except ImportError:
 | 
			
		||||
#     sageattn = sageattn_qk_int8_pv_fp8_window_cuda
 | 
			
		||||
#     sageattn2 = sageattn_qk_int8_pv_fp8_window_cuda
 | 
			
		||||
 | 
			
		||||
@torch.compiler.disable()
 | 
			
		||||
def sdpa_wrapper(
 | 
			
		||||
@ -124,21 +142,28 @@ def get_attention_modes():
 | 
			
		||||
        ret.append("xformers")
 | 
			
		||||
    if sageattn_varlen_wrapper != None:
 | 
			
		||||
        ret.append("sage")
 | 
			
		||||
    if sageattn != None and version("sageattention").startswith("2") :
 | 
			
		||||
    if sageattn2 != None and version("sageattention").startswith("2") :
 | 
			
		||||
        ret.append("sage2")
 | 
			
		||||
    if sageattn3 != None: # and version("sageattention").startswith("3") :
 | 
			
		||||
        ret.append("sage3")
 | 
			
		||||
        
 | 
			
		||||
    return ret
 | 
			
		||||
 | 
			
		||||
def get_supported_attention_modes():
    ret = get_attention_modes()
    major, minor = torch.cuda.get_device_capability()
    if  major < 10:
        if "sage3" in ret:
            ret.remove("sage3")

    if not sage2_supported:
        if "sage2" in ret:
            ret.remove("sage2")

    major, minor = torch.cuda.get_device_capability()
    if  major < 7:
        if "sage" in ret:
            ret.remove("sage")

    return ret
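
The returned list reflects what is installed and supported on the current GPU, not a preference order, so a caller still has to choose among the modes. A minimal sketch of one way to do that (the fastest-first ordering below is an assumption, not something this module defines):

    # Sketch only: pick the first available mode from an assumed preference order.
    def pick_attention_mode(supported, preferred=("sage3", "sage2", "sage", "xformers", "sdpa")):
        for mode in preferred:
            if mode in supported:
                return mode
        return "sdpa"  # PyTorch scaled_dot_product_attention as a safe fallback

    # e.g. attn = pick_attention_mode(get_supported_attention_modes())
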
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
@ -201,7 +226,7 @@ def pay_attention(
 | 
			
		||||
        from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
 | 
			
		||||
        from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG
 | 
			
		||||
 | 
			
		||||
    if b > 1 and k_lens != None and attn in ("sage2", "sdpa"):
 | 
			
		||||
    if b > 1 and k_lens != None and attn in ("sage2", "sage3", "sdpa"):
 | 
			
		||||
        assert attention_mask == None
 | 
			
		||||
        # Poor's man var k len attention
 | 
			
		||||
        assert q_lens == None
 | 
			
		||||
@ -234,7 +259,7 @@ def pay_attention(
 | 
			
		||||
            q_chunks, k_chunks, v_chunks = None, None, None
 | 
			
		||||
            o = torch.cat(o, dim = 0)
 | 
			
		||||
            return o
 | 
			
		||||
    elif (q_lens != None or k_lens != None) and attn in ("sage2", "sdpa"):
 | 
			
		||||
    elif (q_lens != None or k_lens != None) and attn in ("sage2", "sage3", "sdpa"):
 | 
			
		||||
        assert b == 1
 | 
			
		||||
        szq = q_lens[0].item() if q_lens != None else lq
 | 
			
		||||
        szk = k_lens[0].item() if k_lens != None else lk
 | 
			
		||||
@ -284,13 +309,19 @@ def pay_attention(
 | 
			
		||||
            max_seqlen_q=lq,
 | 
			
		||||
            max_seqlen_kv=lk,
 | 
			
		||||
        ).unflatten(0, (b, lq))
 | 
			
		||||
    elif attn=="sage3":
 | 
			
		||||
        import math
 | 
			
		||||
        if cross_attn or True:
 | 
			
		||||
            qkv_list = [q,k,v]
 | 
			
		||||
            del q,k,v
 | 
			
		||||
            x = sageattn3_wrapper(qkv_list, lq)
 | 
			
		||||
    elif attn=="sage2":
 | 
			
		||||
        import math
 | 
			
		||||
        if cross_attn or True:
 | 
			
		||||
            qkv_list = [q,k,v]
 | 
			
		||||
            del q,k,v
 | 
			
		||||
 | 
			
		||||
            x = sageattn_wrapper(qkv_list, lq) #.unsqueeze(0)
 | 
			
		||||
            x = sageattn2_wrapper(qkv_list, lq) #.unsqueeze(0)
 | 
			
		||||
        # else:
 | 
			
		||||
        #     layer =  offload.shared_state["layer"]
 | 
			
		||||
        #     embed_sizes = offload.shared_state["embed_sizes"] 
 | 
			
		||||
 | 
			
		||||
							
								
								
									
342	shared/convert/convert_diffusers_to_flux.py	Normal file
@ -0,0 +1,342 @@
#!/usr/bin/env python3
 | 
			
		||||
"""
 | 
			
		||||
Convert a Flux model from Diffusers (folder or single-file) into the original
 | 
			
		||||
single-file Flux transformer checkpoint used by Black Forest Labs / ComfyUI.
 | 
			
		||||
 | 
			
		||||
Input  : /path/to/diffusers   (root or .../transformer)  OR  /path/to/*.safetensors (single file)
 | 
			
		||||
Output : /path/to/flux1-your-model.safetensors  (transformer only)
 | 
			
		||||
 | 
			
		||||
Usage:
 | 
			
		||||
  python diffusers_to_flux_transformer.py /path/to/diffusers /out/flux1-dev.safetensors
 | 
			
		||||
  python diffusers_to_flux_transformer.py /path/to/diffusion_pytorch_model.safetensors /out/flux1-dev.safetensors
 | 
			
		||||
  # optional quantization:
 | 
			
		||||
  #   --fp8           (float8_e4m3fn, simple)
 | 
			
		||||
  #   --fp8-scaled    (scaled float8 for 2D weights; adds .scale_weight tensors)
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import json
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from collections import OrderedDict
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from safetensors import safe_open
 | 
			
		||||
import safetensors.torch
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def parse_args():
 | 
			
		||||
    ap = argparse.ArgumentParser()
 | 
			
		||||
    ap.add_argument("diffusers_path", type=str,
 | 
			
		||||
                    help="Path to Diffusers checkpoint folder OR a single .safetensors file.")
 | 
			
		||||
    ap.add_argument("output_path", type=str,
 | 
			
		||||
                    help="Output .safetensors path for the Flux transformer.")
 | 
			
		||||
    ap.add_argument("--fp8", action="store_true",
 | 
			
		||||
                    help="Experimental: write weights as float8_e4m3fn via stochastic rounding (transformer only).")
 | 
			
		||||
    ap.add_argument("--fp8-scaled", action="store_true",
 | 
			
		||||
                    help="Experimental: scaled float8_e4m3fn for 2D weight tensors; adds .scale_weight tensors.")
 | 
			
		||||
    return ap.parse_args()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Mapping from original Flux keys -> list of Diffusers keys (per block where applicable).
 | 
			
		||||
DIFFUSERS_MAP = {
 | 
			
		||||
    # global embeds
 | 
			
		||||
    "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
 | 
			
		||||
    "time_in.in_layer.bias":   ["time_text_embed.timestep_embedder.linear_1.bias"],
 | 
			
		||||
    "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
 | 
			
		||||
    "time_in.out_layer.bias":   ["time_text_embed.timestep_embedder.linear_2.bias"],
 | 
			
		||||
 | 
			
		||||
    "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
 | 
			
		||||
    "vector_in.in_layer.bias":   ["time_text_embed.text_embedder.linear_1.bias"],
 | 
			
		||||
    "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
 | 
			
		||||
    "vector_in.out_layer.bias":   ["time_text_embed.text_embedder.linear_2.bias"],
 | 
			
		||||
 | 
			
		||||
    "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
 | 
			
		||||
    "guidance_in.in_layer.bias":   ["time_text_embed.guidance_embedder.linear_1.bias"],
 | 
			
		||||
    "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
 | 
			
		||||
    "guidance_in.out_layer.bias":   ["time_text_embed.guidance_embedder.linear_2.bias"],
 | 
			
		||||
 | 
			
		||||
    "txt_in.weight": ["context_embedder.weight"],
 | 
			
		||||
    "txt_in.bias":   ["context_embedder.bias"],
 | 
			
		||||
    "img_in.weight": ["x_embedder.weight"],
 | 
			
		||||
    "img_in.bias":   ["x_embedder.bias"],
 | 
			
		||||
 | 
			
		||||
    # dual-stream (image/text) blocks
 | 
			
		||||
    "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
 | 
			
		||||
    "double_blocks.().img_mod.lin.bias":   ["norm1.linear.bias"],
 | 
			
		||||
    "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
 | 
			
		||||
    "double_blocks.().txt_mod.lin.bias":   ["norm1_context.linear.bias"],
 | 
			
		||||
 | 
			
		||||
    "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
 | 
			
		||||
    "double_blocks.().img_attn.qkv.bias":   ["attn.to_q.bias",   "attn.to_k.bias",   "attn.to_v.bias"],
 | 
			
		||||
    "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
 | 
			
		||||
    "double_blocks.().txt_attn.qkv.bias":   ["attn.add_q_proj.bias",   "attn.add_k_proj.bias",   "attn.add_v_proj.bias"],
 | 
			
		||||
 | 
			
		||||
    "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
 | 
			
		||||
    "double_blocks.().img_attn.norm.key_norm.scale":   ["attn.norm_k.weight"],
 | 
			
		||||
    "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
 | 
			
		||||
    "double_blocks.().txt_attn.norm.key_norm.scale":   ["attn.norm_added_k.weight"],
 | 
			
		||||
 | 
			
		||||
    "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
 | 
			
		||||
    "double_blocks.().img_mlp.0.bias":   ["ff.net.0.proj.bias"],
 | 
			
		||||
    "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
 | 
			
		||||
    "double_blocks.().img_mlp.2.bias":   ["ff.net.2.bias"],
 | 
			
		||||
 | 
			
		||||
    "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
 | 
			
		||||
    "double_blocks.().txt_mlp.0.bias":   ["ff_context.net.0.proj.bias"],
 | 
			
		||||
    "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
 | 
			
		||||
    "double_blocks.().txt_mlp.2.bias":   ["ff_context.net.2.bias"],
 | 
			
		||||
 | 
			
		||||
    "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
 | 
			
		||||
    "double_blocks.().img_attn.proj.bias":   ["attn.to_out.0.bias"],
 | 
			
		||||
    "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
 | 
			
		||||
    "double_blocks.().txt_attn.proj.bias":   ["attn.to_add_out.bias"],
 | 
			
		||||
 | 
			
		||||
    # single-stream blocks
 | 
			
		||||
    "single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
 | 
			
		||||
    "single_blocks.().modulation.lin.bias":   ["norm.linear.bias"],
 | 
			
		||||
    "single_blocks.().linear1.weight":        ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
 | 
			
		||||
    "single_blocks.().linear1.bias":          ["attn.to_q.bias",   "attn.to_k.bias",   "attn.to_v.bias",   "proj_mlp.bias"],
 | 
			
		||||
    "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
 | 
			
		||||
    "single_blocks.().norm.key_norm.scale":   ["attn.norm_k.weight"],
 | 
			
		||||
    "single_blocks.().linear2.weight":        ["proj_out.weight"],
 | 
			
		||||
    "single_blocks.().linear2.bias":          ["proj_out.bias"],
 | 
			
		||||
 | 
			
		||||
    # final
 | 
			
		||||
    "final_layer.linear.weight":              ["proj_out.weight"],
 | 
			
		||||
    "final_layer.linear.bias":                ["proj_out.bias"],
 | 
			
		||||
    # these two are built from norm_out.linear.{weight,bias} by swapping [shift,scale] -> [scale,shift]
 | 
			
		||||
    "final_layer.adaLN_modulation.1.weight":  ["norm_out.linear.weight"],
 | 
			
		||||
    "final_layer.adaLN_modulation.1.bias":    ["norm_out.linear.bias"],
 | 
			
		||||
}
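# Illustrative note (not part of the conversion logic, which lives in main() below):
# for dual-stream block 3, the original-format key
#     double_blocks.3.img_attn.qkv.weight
# is produced by replacing the "()" placeholder with the block index and concatenating,
# along dim 0, the Diffusers tensors
#     transformer_blocks.3.attn.to_q.weight
#     transformer_blocks.3.attn.to_k.weight
#     transformer_blocks.3.attn.to_v.weight
# Single-entry lists map to a plain rename.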
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DiffusersSource:
 | 
			
		||||
    """
 | 
			
		||||
    Uniform interface over:
 | 
			
		||||
      1) Folder with index JSON + shards
 | 
			
		||||
      2) Folder with exactly one .safetensors (no index)
 | 
			
		||||
      3) Single .safetensors file
 | 
			
		||||
    Provides .has(key), .get(key)->Tensor, .base_keys (keys with 'model.' stripped for scanning)
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    POSSIBLE_PREFIXES = ["", "model."]  # try in this order
 | 
			
		||||
 | 
			
		||||
    def __init__(self, path: Path):
 | 
			
		||||
        p = Path(path)
 | 
			
		||||
        if p.is_dir():
 | 
			
		||||
            # use 'transformer' subfolder if present
 | 
			
		||||
            if (p / "transformer").is_dir():
 | 
			
		||||
                p = p / "transformer"
 | 
			
		||||
            self._init_from_dir(p)
 | 
			
		||||
        elif p.is_file() and p.suffix == ".safetensors":
 | 
			
		||||
            self._init_from_single_file(p)
 | 
			
		||||
        else:
 | 
			
		||||
            raise FileNotFoundError(f"Invalid path: {p}")
 | 
			
		||||
 | 
			
		||||
    # ---------- common helpers ----------
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _strip_prefix(k: str) -> str:
 | 
			
		||||
        return k[6:] if k.startswith("model.") else k
 | 
			
		||||
 | 
			
		||||
    def _resolve(self, want: str):
 | 
			
		||||
        """
 | 
			
		||||
        Return the actual stored key matching `want` by trying known prefixes.
 | 
			
		||||
        """
 | 
			
		||||
        for pref in self.POSSIBLE_PREFIXES:
 | 
			
		||||
            k = pref + want
 | 
			
		||||
            if k in self._all_keys:
 | 
			
		||||
                return k
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    def has(self, want: str) -> bool:
 | 
			
		||||
        return self._resolve(want) is not None
 | 
			
		||||
 | 
			
		||||
    def get(self, want: str) -> torch.Tensor:
 | 
			
		||||
        real_key = self._resolve(want)
 | 
			
		||||
        if real_key is None:
 | 
			
		||||
            raise KeyError(f"Missing key: {want}")
 | 
			
		||||
        return self._get_by_real_key(real_key).to("cpu")
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def base_keys(self):
 | 
			
		||||
        # keys without 'model.' prefix for scanning
 | 
			
		||||
        return [self._strip_prefix(k) for k in self._all_keys]
 | 
			
		||||
 | 
			
		||||
    # ---------- modes ----------
 | 
			
		||||
 | 
			
		||||
    def _init_from_single_file(self, file_path: Path):
 | 
			
		||||
        self._mode = "single"
 | 
			
		||||
        self._file = file_path
 | 
			
		||||
        self._handle = safe_open(file_path, framework="pt", device="cpu")
 | 
			
		||||
        self._all_keys = list(self._handle.keys())
 | 
			
		||||
 | 
			
		||||
        def _get_by_real_key(real_key: str):
 | 
			
		||||
            return self._handle.get_tensor(real_key)
 | 
			
		||||
 | 
			
		||||
        self._get_by_real_key = _get_by_real_key
 | 
			
		||||
 | 
			
		||||
    def _init_from_dir(self, dpath: Path):
 | 
			
		||||
        index_json = dpath / "diffusion_pytorch_model.safetensors.index.json"
 | 
			
		||||
        if index_json.exists():
 | 
			
		||||
            with open(index_json, "r", encoding="utf-8") as f:
 | 
			
		||||
                index = json.load(f)
 | 
			
		||||
            weight_map = index["weight_map"]  # full mapping
 | 
			
		||||
            self._mode = "sharded"
 | 
			
		||||
            self._dpath = dpath
 | 
			
		||||
            self._weight_map = {k: dpath / v for k, v in weight_map.items()}
 | 
			
		||||
            self._all_keys = list(self._weight_map.keys())
 | 
			
		||||
            self._open_handles = {}
 | 
			
		||||
 | 
			
		||||
            def _get_by_real_key(real_key: str):
 | 
			
		||||
                fpath = self._weight_map[real_key]
 | 
			
		||||
                h = self._open_handles.get(fpath)
 | 
			
		||||
                if h is None:
 | 
			
		||||
                    h = safe_open(fpath, framework="pt", device="cpu")
 | 
			
		||||
                    self._open_handles[fpath] = h
 | 
			
		||||
                return h.get_tensor(real_key)
 | 
			
		||||
 | 
			
		||||
            self._get_by_real_key = _get_by_real_key
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        # no index: try exactly one safetensors in folder
 | 
			
		||||
        files = sorted(dpath.glob("*.safetensors"))
 | 
			
		||||
        if len(files) != 1:
 | 
			
		||||
            raise FileNotFoundError(
 | 
			
		||||
                f"No index found and {dpath} does not contain exactly one .safetensors file."
 | 
			
		||||
            )
 | 
			
		||||
        self._init_from_single_file(files[0])
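# Minimal usage sketch for DiffusersSource (hypothetical paths, for illustration only):
#
#     src = DiffusersSource(Path("/path/to/FLUX.1-dev"))   # folder or single .safetensors
#     if src.has("x_embedder.weight"):
#         w = src.get("x_embedder.weight")                 # torch.Tensor on CPU
#     print(len(src.base_keys))                            # keys with any 'model.' prefix stripped
#
# The same interface covers a sharded folder (index JSON + shards), a folder holding
# exactly one .safetensors file, and a bare .safetensors file.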
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    args = parse_args()
 | 
			
		||||
    src = DiffusersSource(Path(args.diffusers_path))
 | 
			
		||||
 | 
			
		||||
    # Count blocks by scanning base keys (with any 'model.' prefix removed)
 | 
			
		||||
    num_dual = 0
 | 
			
		||||
    num_single = 0
 | 
			
		||||
    for k in src.base_keys:
 | 
			
		||||
        if k.startswith("transformer_blocks."):
 | 
			
		||||
            try:
 | 
			
		||||
                i = int(k.split(".")[1])
 | 
			
		||||
                num_dual = max(num_dual, i + 1)
 | 
			
		||||
            except Exception:
 | 
			
		||||
                pass
 | 
			
		||||
        elif k.startswith("single_transformer_blocks."):
 | 
			
		||||
            try:
 | 
			
		||||
                i = int(k.split(".")[1])
 | 
			
		||||
                num_single = max(num_single, i + 1)
 | 
			
		||||
            except Exception:
 | 
			
		||||
                pass
 | 
			
		||||
    print(f"Found {num_dual} dual-stream blocks, {num_single} single-stream blocks")
 | 
			
		||||
 | 
			
		||||
    # Swap [shift, scale] -> [scale, shift] (weights are concatenated along dim=0)
 | 
			
		||||
    def swap_scale_shift(vec: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
        shift, scale = vec.chunk(2, dim=0)
 | 
			
		||||
        return torch.cat([scale, shift], dim=0)
 | 
			
		||||
 | 
			
		||||
    orig = {}
 | 
			
		||||
 | 
			
		||||
    # Per-block (dual)
 | 
			
		||||
    for b in range(num_dual):
 | 
			
		||||
        prefix = f"transformer_blocks.{b}."
 | 
			
		||||
        for okey, dvals in DIFFUSERS_MAP.items():
 | 
			
		||||
            if not okey.startswith("double_blocks."):
 | 
			
		||||
                continue
 | 
			
		||||
            dkeys = [prefix + v for v in dvals]
 | 
			
		||||
            if not all(src.has(k) for k in dkeys):
 | 
			
		||||
                continue
 | 
			
		||||
            if len(dkeys) == 1:
 | 
			
		||||
                orig[okey.replace("()", str(b))] = src.get(dkeys[0])
 | 
			
		||||
            else:
 | 
			
		||||
                orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0)
 | 
			
		||||
 | 
			
		||||
    # Per-block (single)
 | 
			
		||||
    for b in range(num_single):
 | 
			
		||||
        prefix = f"single_transformer_blocks.{b}."
 | 
			
		||||
        for okey, dvals in DIFFUSERS_MAP.items():
 | 
			
		||||
            if not okey.startswith("single_blocks."):
 | 
			
		||||
                continue
 | 
			
		||||
            dkeys = [prefix + v for v in dvals]
 | 
			
		||||
            if not all(src.has(k) for k in dkeys):
 | 
			
		||||
                continue
 | 
			
		||||
            if len(dkeys) == 1:
 | 
			
		||||
                orig[okey.replace("()", str(b))] = src.get(dkeys[0])
 | 
			
		||||
            else:
 | 
			
		||||
                orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0)
 | 
			
		||||
 | 
			
		||||
    # Globals (non-block)
 | 
			
		||||
    for okey, dvals in DIFFUSERS_MAP.items():
 | 
			
		||||
        if okey.startswith(("double_blocks.", "single_blocks.")):
 | 
			
		||||
            continue
 | 
			
		||||
        dkeys = dvals
 | 
			
		||||
        if not all(src.has(k) for k in dkeys):
 | 
			
		||||
            continue
 | 
			
		||||
        if len(dkeys) == 1:
 | 
			
		||||
            orig[okey] = src.get(dkeys[0])
 | 
			
		||||
        else:
 | 
			
		||||
            orig[okey] = torch.cat([src.get(k) for k in dkeys], dim=0)
 | 
			
		||||
 | 
			
		||||
    # Fix final_layer.adaLN_modulation.1.{weight,bias} by swapping scale/shift halves
 | 
			
		||||
    if "final_layer.adaLN_modulation.1.weight" in orig:
 | 
			
		||||
        orig["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(
 | 
			
		||||
            orig["final_layer.adaLN_modulation.1.weight"]
 | 
			
		||||
        )
 | 
			
		||||
    if "final_layer.adaLN_modulation.1.bias" in orig:
 | 
			
		||||
        orig["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(
 | 
			
		||||
            orig["final_layer.adaLN_modulation.1.bias"]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # Optional FP8 variants (experimental; not required for ComfyUI/BFL)
 | 
			
		||||
    if args.fp8 or args.fp8_scaled:
 | 
			
		||||
        dtype = torch.float8_e4m3fn  # noqa
 | 
			
		||||
        minv, maxv = torch.finfo(dtype).min, torch.finfo(dtype).max
 | 
			
		||||
 | 
			
		||||
        def stochastic_round_to(t):
 | 
			
		||||
            t = t.float().clamp(minv, maxv)
 | 
			
		||||
            lower = torch.floor(t * 256) / 256
 | 
			
		||||
            upper = torch.ceil(t * 256) / 256
 | 
			
		||||
            prob = torch.where(upper != lower, (t - lower) / (upper - lower), torch.zeros_like(t))
 | 
			
		||||
            rnd = torch.rand_like(t)
 | 
			
		||||
            out = torch.where(rnd < prob, upper, lower)
 | 
			
		||||
            return out.to(dtype)
 | 
			
		||||
 | 
			
		||||
        def scale_to_8bit(weight, target_max=416.0):
 | 
			
		||||
            absmax = weight.abs().max()
 | 
			
		||||
            scale = absmax / target_max if absmax > 0 else torch.tensor(1.0)
 | 
			
		||||
            scaled = (weight / scale).clamp(minv, maxv).to(dtype)
 | 
			
		||||
            return scaled, scale
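        # Rough intuition for scale_to_8bit (illustrative numbers, not taken from any model):
        # with target_max=416 and a weight whose absmax is 832, scale becomes 2.0, the tensor
        # is stored as (weight / 2.0) cast to float8_e4m3fn, and the saved ".scale_weight"
        # lets a loader dequantize as stored * scale. target_max=416 presumably leaves some
        # headroom below the float8_e4m3fn maximum of 448.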
 | 
			
		||||
 | 
			
		||||
        scales = {}
 | 
			
		||||
        for k in tqdm(list(orig.keys()), desc="Quantizing to fp8"):
 | 
			
		||||
            t = orig[k]
 | 
			
		||||
            if args.fp8:
 | 
			
		||||
                orig[k] = stochastic_round_to(t)
 | 
			
		||||
            else:
 | 
			
		||||
                if k.endswith(".weight") and t.dim() == 2:
 | 
			
		||||
                    qt, s = scale_to_8bit(t)
 | 
			
		||||
                    orig[k] = qt
 | 
			
		||||
                    scales[k[:-len(".weight")] + ".scale_weight"] = s
 | 
			
		||||
                else:
 | 
			
		||||
                    orig[k] = t.clamp(minv, maxv).to(dtype)
 | 
			
		||||
        if args.fp8_scaled:
 | 
			
		||||
            orig.update(scales)
 | 
			
		||||
            orig["scaled_fp8"] = torch.tensor([], dtype=dtype)
 | 
			
		||||
    else:
 | 
			
		||||
        # Default: save in bfloat16
 | 
			
		||||
        for k in list(orig.keys()):
 | 
			
		||||
            orig[k] = orig[k].to(torch.bfloat16).cpu()
 | 
			
		||||
 | 
			
		||||
    out_path = Path(args.output_path)
 | 
			
		||||
    out_path.parent.mkdir(parents=True, exist_ok=True)
 | 
			
		||||
    meta = OrderedDict()
 | 
			
		||||
    meta["format"] = "pt"
 | 
			
		||||
    meta["modelspec.date"] = __import__("datetime").date.today().strftime("%Y-%m-%d")
 | 
			
		||||
    print(f"Saving transformer to: {out_path}")
 | 
			
		||||
    safetensors.torch.save_file(orig, str(out_path), metadata=meta)
 | 
			
		||||
    print("Done.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    main()
 | 
			
		||||
							
								
								
									
532  shared/gradio/gallery.py  Normal file
@@ -0,0 +1,532 @@
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
import os, io, tempfile, mimetypes
 | 
			
		||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Literal
 | 
			
		||||
 | 
			
		||||
import gradio as gr
 | 
			
		||||
import PIL
 | 
			
		||||
import time
 | 
			
		||||
from PIL import Image as PILImage
 | 
			
		||||
 | 
			
		||||
FilePath = str
 | 
			
		||||
ImageLike = Union["PIL.Image.Image", Any]
 | 
			
		||||
 | 
			
		||||
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"}
 | 
			
		||||
VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".m4v", ".mpeg", ".mpg", ".ogv"}
 | 
			
		||||
 | 
			
		||||
def get_state(state):
 | 
			
		||||
    return state if isinstance(state, dict) else state.value
 | 
			
		||||
 | 
			
		||||
def get_list( objs):
 | 
			
		||||
    if objs is None:
 | 
			
		||||
        return []
 | 
			
		||||
    return [ obj[0] if isinstance(obj, tuple) else obj for obj in objs]
 | 
			
		||||
 | 
			
		||||
def record_last_action(st, last_action):
 | 
			
		||||
    st["last_action"] = last_action
 | 
			
		||||
    st["last_time"] = time.time()
 | 
			
		||||
class AdvancedMediaGallery:
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        label: str = "Media",
 | 
			
		||||
        *,
 | 
			
		||||
        media_mode: Literal["image", "video"] = "image",
 | 
			
		||||
        height = None,
 | 
			
		||||
        columns: Union[int, Tuple[int, ...]] = 6,
 | 
			
		||||
        show_label: bool = True,
 | 
			
		||||
        initial: Optional[Sequence[Union[FilePath, ImageLike]]] = None,
 | 
			
		||||
        elem_id: Optional[str] = None,
 | 
			
		||||
        elem_classes: Optional[Sequence[str]] = ("adv-media-gallery",),
 | 
			
		||||
        accept_filter: bool = True,        # restrict Add-button dialog to allowed extensions
 | 
			
		||||
        single_image_mode: bool = False,   # start in single-image mode (Add replaces)
 | 
			
		||||
    ):
 | 
			
		||||
        assert media_mode in ("image", "video")
 | 
			
		||||
        self.label = label
 | 
			
		||||
        self.media_mode = media_mode
 | 
			
		||||
        self.height = height
 | 
			
		||||
        self.columns = columns
 | 
			
		||||
        self.show_label = show_label
 | 
			
		||||
        self.elem_id = elem_id
 | 
			
		||||
        self.elem_classes = list(elem_classes) if elem_classes else None
 | 
			
		||||
        self.accept_filter = accept_filter
 | 
			
		||||
 | 
			
		||||
        items = self._normalize_initial(initial or [], media_mode)
 | 
			
		||||
 | 
			
		||||
        # Components (filled on mount)
 | 
			
		||||
        self.container: Optional[gr.Column] = None
 | 
			
		||||
        self.gallery: Optional[gr.Gallery] = None
 | 
			
		||||
        self.upload_btn: Optional[gr.UploadButton] = None
 | 
			
		||||
        self.btn_remove: Optional[gr.Button] = None
 | 
			
		||||
        self.btn_left: Optional[gr.Button] = None
 | 
			
		||||
        self.btn_right: Optional[gr.Button] = None
 | 
			
		||||
        self.btn_clear: Optional[gr.Button] = None
 | 
			
		||||
 | 
			
		||||
        # Single dict state
 | 
			
		||||
        self.state: Optional[gr.State] = None
 | 
			
		||||
        self._initial_state: Dict[str, Any] = {
 | 
			
		||||
            "items": items,
 | 
			
		||||
            "selected": (len(items) - 1) if items else 0, # None,
 | 
			
		||||
            "single": bool(single_image_mode),
 | 
			
		||||
            "mode": self.media_mode,
 | 
			
		||||
            "last_action": "",
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    # ---------------- helpers ----------------
 | 
			
		||||
 | 
			
		||||
    def _normalize_initial(self, items: Sequence[Union[FilePath, ImageLike]], mode: str) -> List[Any]:
 | 
			
		||||
        out: List[Any] = []
 | 
			
		||||
        if mode == "image":
 | 
			
		||||
            for it in items:
 | 
			
		||||
                p = self._ensure_image_item(it)
 | 
			
		||||
                if p is not None:
 | 
			
		||||
                    out.append(p)
 | 
			
		||||
        else:
 | 
			
		||||
            for it in items:
 | 
			
		||||
                if isinstance(it, tuple): it = it[0]
 | 
			
		||||
                if isinstance(it, str) and self._is_video_path(it):
 | 
			
		||||
                    out.append(os.path.abspath(it))
 | 
			
		||||
        return out
 | 
			
		||||
 | 
			
		||||
    def _ensure_image_item(self, item: Union[FilePath, ImageLike]) -> Optional[Any]:
 | 
			
		||||
        # Accept a path to an image, or a PIL.Image/np.ndarray -> save temp PNG and return its path
 | 
			
		||||
        if isinstance(item, tuple): item = item[0]
 | 
			
		||||
        if isinstance(item, str):
 | 
			
		||||
            return os.path.abspath(item) if self._is_image_path(item) else None
 | 
			
		||||
        if PILImage is None:
 | 
			
		||||
            return None
 | 
			
		||||
        try:
 | 
			
		||||
            if isinstance(item, PILImage.Image):
 | 
			
		||||
                img = item
 | 
			
		||||
            else:
 | 
			
		||||
                import numpy as np  # type: ignore
 | 
			
		||||
                if isinstance(item, np.ndarray):
 | 
			
		||||
                    img = PILImage.fromarray(item)
 | 
			
		||||
                elif hasattr(item, "read"):
 | 
			
		||||
                    data = item.read()
 | 
			
		||||
                    img = PILImage.open(io.BytesIO(data)).convert("RGBA")
 | 
			
		||||
                else:
 | 
			
		||||
                    return None
 | 
			
		||||
            tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
 | 
			
		||||
            img.save(tmp.name)
 | 
			
		||||
            return tmp.name
 | 
			
		||||
        except Exception:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _extract_path(obj: Any) -> Optional[str]:
 | 
			
		||||
        # Try to get a filesystem path (for mode filtering); otherwise None.
 | 
			
		||||
        if isinstance(obj, str):
 | 
			
		||||
            return obj
 | 
			
		||||
        try:
 | 
			
		||||
            import pathlib
 | 
			
		||||
            if isinstance(obj, pathlib.Path):  # type: ignore
 | 
			
		||||
                return str(obj)
 | 
			
		||||
        except Exception:
 | 
			
		||||
            pass
 | 
			
		||||
        if isinstance(obj, dict):
 | 
			
		||||
            return obj.get("path") or obj.get("name")
 | 
			
		||||
        for attr in ("path", "name"):
 | 
			
		||||
            if hasattr(obj, attr):
 | 
			
		||||
                try:
 | 
			
		||||
                    val = getattr(obj, attr)
 | 
			
		||||
                    if isinstance(val, str):
 | 
			
		||||
                        return val
 | 
			
		||||
                except Exception:
 | 
			
		||||
                    pass
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _is_image_path(p: str) -> bool:
 | 
			
		||||
        ext = os.path.splitext(p)[1].lower()
 | 
			
		||||
        if ext in IMAGE_EXTS:
 | 
			
		||||
            return True
 | 
			
		||||
        mt, _ = mimetypes.guess_type(p)
 | 
			
		||||
        return bool(mt and mt.startswith("image/"))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _is_video_path(p: str) -> bool:
 | 
			
		||||
        ext = os.path.splitext(p)[1].lower()
 | 
			
		||||
        if ext in VIDEO_EXTS:
 | 
			
		||||
            return True
 | 
			
		||||
        mt, _ = mimetypes.guess_type(p)
 | 
			
		||||
        return bool(mt and mt.startswith("video/"))
 | 
			
		||||
 | 
			
		||||
    def _filter_items_by_mode(self, items: List[Any]) -> List[Any]:
 | 
			
		||||
        # Enforce image-only or video-only collection regardless of how files were added.
 | 
			
		||||
        out: List[Any] = []
 | 
			
		||||
        if self.media_mode == "image":
 | 
			
		||||
            for it in items:
 | 
			
		||||
                p = self._extract_path(it)
 | 
			
		||||
                if p is None:
 | 
			
		||||
                    # No path: likely an image object added programmatically => keep
 | 
			
		||||
                    out.append(it)
 | 
			
		||||
                elif self._is_image_path(p):
 | 
			
		||||
                    out.append(os.path.abspath(p))
 | 
			
		||||
        else:
 | 
			
		||||
            for it in items:
 | 
			
		||||
                p = self._extract_path(it)
 | 
			
		||||
                if p is not None and self._is_video_path(p):
 | 
			
		||||
                    out.append(os.path.abspath(p))
 | 
			
		||||
        return out
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _concat_and_optionally_dedupe(cur: List[Any], add: List[Any]) -> List[Any]:
 | 
			
		||||
        # Keep it simple: dedupe by path when available, else allow duplicates.
 | 
			
		||||
        seen_paths = set()
 | 
			
		||||
        def key(x: Any) -> Optional[str]:
 | 
			
		||||
            if isinstance(x, str): return os.path.abspath(x)
 | 
			
		||||
            try:
 | 
			
		||||
                import pathlib
 | 
			
		||||
                if isinstance(x, pathlib.Path):  # type: ignore
 | 
			
		||||
                    return os.path.abspath(str(x))
 | 
			
		||||
            except Exception:
 | 
			
		||||
                pass
 | 
			
		||||
            if isinstance(x, dict):
 | 
			
		||||
                p = x.get("path") or x.get("name")
 | 
			
		||||
                return os.path.abspath(p) if isinstance(p, str) else None
 | 
			
		||||
            for attr in ("path", "name"):
 | 
			
		||||
                if hasattr(x, attr):
 | 
			
		||||
                    try:
 | 
			
		||||
                        v = getattr(x, attr)
 | 
			
		||||
                        return os.path.abspath(v) if isinstance(v, str) else None
 | 
			
		||||
                    except Exception:
 | 
			
		||||
                        pass
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        out: List[Any] = []
 | 
			
		||||
        for lst in (cur, add):
 | 
			
		||||
            for it in lst:
 | 
			
		||||
                k = key(it)
 | 
			
		||||
                if k is None or k not in seen_paths:
 | 
			
		||||
                    out.append(it)
 | 
			
		||||
                    if k is not None:
 | 
			
		||||
                        seen_paths.add(k)
 | 
			
		||||
        return out
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _paths_from_payload(payload: Any) -> List[Any]:
 | 
			
		||||
        # Return as raw objects (paths/dicts/UploadedFile) to feed Gallery directly.
 | 
			
		||||
        if payload is None:
 | 
			
		||||
            return []
 | 
			
		||||
        if isinstance(payload, (list, tuple, set)):
 | 
			
		||||
            return list(payload)
 | 
			
		||||
        return [payload]
 | 
			
		||||
 | 
			
		||||
    # ---------------- event handlers ----------------
 | 
			
		||||
 | 
			
		||||
    def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) :
 | 
			
		||||
        # Mirror the selected index into state and the gallery (server-side selected_index)
 | 
			
		||||
 | 
			
		||||
        st = get_state(state)
 | 
			
		||||
        last_time = st.get("last_time", None)
 | 
			
		||||
        if last_time is not None and abs(time.time() - last_time) < 0.5:  # heuristic: ignore spurious select events fired right after a programmatic gallery update
 | 
			
		||||
            # print(f"ignored:{time.time()}, real {st['selected']}")
 | 
			
		||||
            return gr.update(selected_index=st["selected"]), st
 | 
			
		||||
 | 
			
		||||
        idx = None
 | 
			
		||||
        if evt is not None and hasattr(evt, "index"):
 | 
			
		||||
            ix = evt.index
 | 
			
		||||
            if isinstance(ix, int):
 | 
			
		||||
                idx = ix
 | 
			
		||||
            elif isinstance(ix, (tuple, list)) and ix and isinstance(ix[0], int):
 | 
			
		||||
                if isinstance(self.columns, int) and len(ix) >= 2:
 | 
			
		||||
                    idx = ix[0] * max(1, int(self.columns)) + ix[1]
 | 
			
		||||
                else:
 | 
			
		||||
                    idx = ix[0]
 | 
			
		||||
        n = len(get_list(gallery))
 | 
			
		||||
        sel = idx if (idx is not None and 0 <= idx < n) else None
 | 
			
		||||
        # print(f"image selected evt index:{sel}/{evt.selected}")
 | 
			
		||||
        st["selected"] = sel
 | 
			
		||||
        return gr.update(), st
 | 
			
		||||
 | 
			
		||||
    def _on_upload(self, value: List[Any], state: Dict[str, Any]) :
 | 
			
		||||
        # Fires when users upload via the Gallery itself.
 | 
			
		||||
        # items_filtered = self._filter_items_by_mode(list(value or []))
 | 
			
		||||
        items_filtered = list(value or [])
 | 
			
		||||
        st = get_state(state)
 | 
			
		||||
        new_items = self._paths_from_payload(items_filtered)
 | 
			
		||||
        st["items"] = new_items
 | 
			
		||||
        new_sel = len(new_items) - 1
 | 
			
		||||
        st["selected"] = new_sel
 | 
			
		||||
        record_last_action(st,"add")
 | 
			
		||||
        return gr.update(selected_index=new_sel), st
 | 
			
		||||
 | 
			
		||||
    def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) :
 | 
			
		||||
        # Fires when users add/drag/drop/delete via the Gallery itself.
 | 
			
		||||
        # items_filtered = self._filter_items_by_mode(list(value or []))
 | 
			
		||||
        items_filtered = list(value or [])
 | 
			
		||||
        st = get_state(state)
 | 
			
		||||
        st["items"] = items_filtered
 | 
			
		||||
        # Keep selection if still valid, else default to last
 | 
			
		||||
        old_sel = st.get("selected", None)
 | 
			
		||||
        if old_sel is None or not (0 <= old_sel < len(items_filtered)):
 | 
			
		||||
            new_sel = (len(items_filtered) - 1) if items_filtered else None
 | 
			
		||||
        else:
 | 
			
		||||
            new_sel = old_sel
 | 
			
		||||
        st["selected"] = new_sel
 | 
			
		||||
        st["last_action"] ="gallery_change"
 | 
			
		||||
        # print(f"gallery change: set sel {new_sel}")
 | 
			
		||||
        return gr.update(selected_index=new_sel), st
 | 
			
		||||
 | 
			
		||||
    def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery):
 | 
			
		||||
        """
 | 
			
		||||
        Insert added items right AFTER the currently selected index.
 | 
			
		||||
        Keeps the same ordering as chosen in the file picker, dedupes by path,
 | 
			
		||||
        and re-selects the last inserted item.
 | 
			
		||||
        """
 | 
			
		||||
        # New items (respect image/video mode)
 | 
			
		||||
        # new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload))
 | 
			
		||||
        new_items = self._paths_from_payload(files_payload)
 | 
			
		||||
 | 
			
		||||
        st = get_state(state)
 | 
			
		||||
        cur: List[Any] = get_list(gallery)
 | 
			
		||||
        sel = st.get("selected", None)
 | 
			
		||||
        if sel is None:
 | 
			
		||||
            sel = (len(cur) -1) if len(cur)>0 else 0
 | 
			
		||||
        single = bool(st.get("single", False))
 | 
			
		||||
 | 
			
		||||
        # Nothing to add: keep as-is
 | 
			
		||||
        if not new_items:
 | 
			
		||||
            return gr.update(value=cur, selected_index=st.get("selected")), st
 | 
			
		||||
 | 
			
		||||
        # Single-image mode: replace
 | 
			
		||||
        if single:
 | 
			
		||||
            st["items"] = [new_items[-1]]
 | 
			
		||||
            st["selected"] = 0
 | 
			
		||||
            return gr.update(value=st["items"], selected_index=0), st
 | 
			
		||||
 | 
			
		||||
        # ---------- helpers ----------
 | 
			
		||||
        def key_of(it: Any) -> Optional[str]:
 | 
			
		||||
            # Prefer class helper if present
 | 
			
		||||
            if hasattr(self, "_extract_path"):
 | 
			
		||||
                p = self._extract_path(it)  # type: ignore
 | 
			
		||||
            else:
 | 
			
		||||
                p = it if isinstance(it, str) else None
 | 
			
		||||
                if p is None and isinstance(it, dict):
 | 
			
		||||
                    p = it.get("path") or it.get("name")
 | 
			
		||||
                if p is None and hasattr(it, "path"):
 | 
			
		||||
                    try: p = getattr(it, "path")
 | 
			
		||||
                    except Exception: p = None
 | 
			
		||||
                if p is None and hasattr(it, "name"):
 | 
			
		||||
                    try: p = getattr(it, "name")
 | 
			
		||||
                    except Exception: p = None
 | 
			
		||||
            return os.path.abspath(p) if isinstance(p, str) else None
 | 
			
		||||
 | 
			
		||||
        # Dedupe the incoming batch by path, preserve order
 | 
			
		||||
        seen_new = set()
 | 
			
		||||
        incoming: List[Any] = []
 | 
			
		||||
        for it in new_items:
 | 
			
		||||
            k = key_of(it)
 | 
			
		||||
            if k is None or k not in seen_new:
 | 
			
		||||
                incoming.append(it)
 | 
			
		||||
                if k is not None:
 | 
			
		||||
                    seen_new.add(k)
 | 
			
		||||
 | 
			
		||||
        insert_pos = min(sel, len(cur) -1)
 | 
			
		||||
        cur_clean = cur
 | 
			
		||||
        # Build final list and selection
 | 
			
		||||
        merged = cur_clean[:insert_pos+1] + incoming + cur_clean[insert_pos+1:]
 | 
			
		||||
        new_sel = insert_pos + len(incoming)   # select the last inserted item
 | 
			
		||||
 | 
			
		||||
        st["items"] = merged
 | 
			
		||||
        st["selected"] = new_sel
 | 
			
		||||
        record_last_action(st,"add")
 | 
			
		||||
        # print(f"gallery add: set sel {new_sel}")
 | 
			
		||||
        return gr.update(value=merged, selected_index=new_sel), st
 | 
			
		||||
 | 
			
		||||
    def _on_remove(self, state: Dict[str, Any], gallery) :
 | 
			
		||||
        st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None)
 | 
			
		||||
        if sel is None or not (0 <= sel < len(items)):
 | 
			
		||||
            return gr.update(value=items, selected_index=st.get("selected")), st
 | 
			
		||||
        items.pop(sel)
 | 
			
		||||
        if not items:
 | 
			
		||||
            st["items"] = []; st["selected"] = None
 | 
			
		||||
            return gr.update(value=[], selected_index=None), st
 | 
			
		||||
        new_sel = min(sel, len(items) - 1)
 | 
			
		||||
        st["items"] = items; st["selected"] = new_sel
 | 
			
		||||
        record_last_action(st,"remove")
 | 
			
		||||
        # print(f"gallery del: new sel {new_sel}")
 | 
			
		||||
        return gr.update(value=items, selected_index=new_sel), st
 | 
			
		||||
 | 
			
		||||
    def _on_move(self, delta: int, state: Dict[str, Any], gallery) :
 | 
			
		||||
        st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None)
 | 
			
		||||
        if sel is None or not (0 <= sel < len(items)):
 | 
			
		||||
            return gr.update(value=items, selected_index=sel), st
 | 
			
		||||
        j = sel + delta
 | 
			
		||||
        if j < 0 or j >= len(items):
 | 
			
		||||
            return gr.update(value=items, selected_index=sel), st
 | 
			
		||||
        items[sel], items[j] = items[j], items[sel]
 | 
			
		||||
        st["items"] = items; st["selected"] = j
 | 
			
		||||
        record_last_action(st,"move")
 | 
			
		||||
        # print(f"gallery move: set sel {j}")
 | 
			
		||||
        return gr.update(value=items, selected_index=j), st
 | 
			
		||||
 | 
			
		||||
    def _on_clear(self, state: Dict[str, Any]) :
 | 
			
		||||
        st = {"items": [], "selected": None, "single": get_state(state).get("single", False), "mode": self.media_mode}
 | 
			
		||||
        record_last_action(st,"clear")
 | 
			
		||||
        # print(f"Clear all")
 | 
			
		||||
        return gr.update(value=[], selected_index=None), st
 | 
			
		||||
 | 
			
		||||
    def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) :
 | 
			
		||||
        st = get_state(state); st["single"] = bool(to_single)
 | 
			
		||||
        items: List[Any] = list(st["items"]); sel = st.get("selected", None)
 | 
			
		||||
        if st["single"]:
 | 
			
		||||
            keep = items[sel] if (sel is not None and 0 <= sel < len(items)) else (items[-1] if items else None)
 | 
			
		||||
            items = [keep] if keep is not None else []
 | 
			
		||||
            sel = 0 if items else None
 | 
			
		||||
        st["items"] = items; st["selected"] = sel
 | 
			
		||||
 | 
			
		||||
        upload_update = gr.update(file_count=("single" if st["single"] else "multiple"))
 | 
			
		||||
        left_update   = gr.update(visible=not st["single"])
 | 
			
		||||
        right_update  = gr.update(visible=not st["single"])
 | 
			
		||||
        clear_update  = gr.update(visible=not st["single"])
 | 
			
		||||
        gallery_update= gr.update(value=items, selected_index=sel)
 | 
			
		||||
 | 
			
		||||
        return upload_update, left_update, right_update, clear_update, gallery_update, st
 | 
			
		||||
 | 
			
		||||
    # ---------------- build & wire ----------------
 | 
			
		||||
 | 
			
		||||
    def mount(self, parent: Optional[gr.Blocks | gr.Group | gr.Row | gr.Column] = None, update_form = False):
 | 
			
		||||
        if parent is not None:
 | 
			
		||||
            with parent:
 | 
			
		||||
                col = self._build_ui(update_form)
 | 
			
		||||
        else:
 | 
			
		||||
            col = self._build_ui(update_form)
 | 
			
		||||
        if not update_form:
 | 
			
		||||
            self._wire_events()
 | 
			
		||||
        return col
 | 
			
		||||
 | 
			
		||||
    def _build_ui(self, update = False) -> gr.Column:
 | 
			
		||||
        with gr.Column(elem_id=self.elem_id, elem_classes=self.elem_classes) as col:
 | 
			
		||||
            self.container = col
 | 
			
		||||
 | 
			
		||||
            self.state = gr.State(dict(self._initial_state))
 | 
			
		||||
 | 
			
		||||
            if update:
 | 
			
		||||
                self.gallery = gr.update(
 | 
			
		||||
                    value=self._initial_state["items"],
 | 
			
		||||
                    selected_index=self._initial_state["selected"],  # server-side selection
 | 
			
		||||
                    label=self.label,
 | 
			
		||||
                    show_label=self.show_label,
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                self.gallery = gr.Gallery(
 | 
			
		||||
                    value=self._initial_state["items"],
 | 
			
		||||
                    label=self.label,
 | 
			
		||||
                    height=self.height,
 | 
			
		||||
                    columns=self.columns,
 | 
			
		||||
                    show_label=self.show_label,
 | 
			
		||||
                    preview= True,
 | 
			
		||||
                    # type="pil", # very slow
 | 
			
		||||
                    file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS), 
 | 
			
		||||
                    selected_index=self._initial_state["selected"],  # server-side selection
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            # One-line controls
 | 
			
		||||
            exts = sorted(IMAGE_EXTS if self.media_mode == "image" else VIDEO_EXTS) if self.accept_filter else None
 | 
			
		||||
            with gr.Row(equal_height=True, elem_classes=["amg-controls"]):
 | 
			
		||||
                self.upload_btn = gr.UploadButton(
 | 
			
		||||
                    "Set" if self._initial_state["single"] else "Add",
 | 
			
		||||
                    file_types=exts,
 | 
			
		||||
                    file_count=("single" if self._initial_state["single"] else "multiple"),
 | 
			
		||||
                    variant="primary",
 | 
			
		||||
                    size="sm",
 | 
			
		||||
                    min_width=1,
 | 
			
		||||
                )
 | 
			
		||||
                self.btn_remove = gr.Button(" Remove ", size="sm", min_width=1)
 | 
			
		||||
                self.btn_left   = gr.Button("◀ Left",  size="sm", visible=not self._initial_state["single"], min_width=1)
 | 
			
		||||
                self.btn_right  = gr.Button("Right ▶", size="sm", visible=not self._initial_state["single"], min_width=1)
 | 
			
		||||
                self.btn_clear  = gr.Button(" Clear ",   variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1)
 | 
			
		||||
 | 
			
		||||
        return col
 | 
			
		||||
 | 
			
		||||
    def _wire_events(self):
 | 
			
		||||
        # Selection: mirror into state and keep gallery.selected_index in sync
 | 
			
		||||
        self.gallery.select(
 | 
			
		||||
            self._on_select,
 | 
			
		||||
            inputs=[self.state, self.gallery],
 | 
			
		||||
            outputs=[self.gallery, self.state],
 | 
			
		||||
            trigger_mode="always_last",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Files uploaded directly onto the Gallery (drag & drop / click-to-add)
 | 
			
		||||
        self.gallery.upload(
 | 
			
		||||
            self._on_upload,
 | 
			
		||||
            inputs=[self.gallery, self.state],
 | 
			
		||||
            outputs=[self.gallery, self.state],
 | 
			
		||||
            trigger_mode="always_last",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.)
 | 
			
		||||
        self.gallery.change(
 | 
			
		||||
            self._on_gallery_change,
 | 
			
		||||
            inputs=[self.gallery, self.state],
 | 
			
		||||
            outputs=[self.gallery, self.state],
 | 
			
		||||
            trigger_mode="always_last",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Add via UploadButton
 | 
			
		||||
        self.upload_btn.upload(
 | 
			
		||||
            self._on_add,
 | 
			
		||||
            inputs=[self.upload_btn, self.state, self.gallery],
 | 
			
		||||
            outputs=[self.gallery, self.state],
 | 
			
		||||
            trigger_mode="always_last",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Remove selected
 | 
			
		||||
        self.btn_remove.click(
 | 
			
		||||
            self._on_remove,
 | 
			
		||||
            inputs=[self.state, self.gallery],
 | 
			
		||||
            outputs=[self.gallery, self.state],
 | 
			
		||||
            trigger_mode="always_last",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Reorder using selected index, keep same item selected
 | 
			
		||||
        self.btn_left.click(
 | 
			
		||||
            lambda st, gallery: self._on_move(-1, st, gallery),
 | 
			
		||||
            inputs=[self.state, self.gallery],
 | 
			
		||||
            outputs=[self.gallery, self.state],
 | 
			
		||||
            trigger_mode="always_last",
 | 
			
		||||
        )
 | 
			
		||||
        self.btn_right.click(
 | 
			
		||||
            lambda st, gallery: self._on_move(+1, st, gallery),
 | 
			
		||||
            inputs=[self.state, self.gallery],
 | 
			
		||||
            outputs=[self.gallery, self.state],
 | 
			
		||||
            trigger_mode="always_last",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Clear all
 | 
			
		||||
        self.btn_clear.click(
 | 
			
		||||
            self._on_clear,
 | 
			
		||||
            inputs=[self.state],
 | 
			
		||||
            outputs=[self.gallery, self.state],
 | 
			
		||||
            trigger_mode="always_last",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # ---------------- public API ----------------
 | 
			
		||||
 | 
			
		||||
    def set_one_image_mode(self, enabled: bool = True):
 | 
			
		||||
        """Toggle single-image mode at runtime."""
 | 
			
		||||
        return (
 | 
			
		||||
            self._on_toggle_single,
 | 
			
		||||
            [gr.State(enabled), self.state],
 | 
			
		||||
            [self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def get_toggable_elements(self):
 | 
			
		||||
        return [self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state]
 | 
			
		||||
 | 
			
		||||
# import gradio as gr
 | 
			
		||||
 | 
			
		||||
# with gr.Blocks() as demo:
 | 
			
		||||
#     amg = AdvancedMediaGallery(media_mode="image", height=190, columns=8)
 | 
			
		||||
#     amg.mount()
 | 
			
		||||
#     g = amg.gallery
 | 
			
		||||
#     # buttons to switch modes live (optional)
 | 
			
		||||
#     def process(g):
 | 
			
		||||
#         pass
 | 
			
		||||
#     with gr.Row():
 | 
			
		||||
#         gr.Button("toto").click(process, g)
 | 
			
		||||
#         gr.Button("ONE image").click(*amg.set_one_image_mode(True))
 | 
			
		||||
#         gr.Button("MULTI image").click(*amg.set_one_image_mode(False))
 | 
			
		||||
 | 
			
		||||
# demo.launch()
 | 
			
		||||
							
								
								
									
0  shared/inpainting/__init__.py  Normal file

240  shared/inpainting/lanpaint.py  Normal file
@@ -0,0 +1,240 @@
 | 
			
		||||
import torch
 | 
			
		||||
from .utils import *
 | 
			
		||||
from functools import partial
 | 
			
		||||
 | 
			
		||||
# Many thanks to the LanPaint team for this implementation (https://github.com/scraed/LanPaint/)
 | 
			
		||||
 | 
			
		||||
def _pack_latents(latents):
 | 
			
		||||
    batch_size, num_channels_latents, _, height, width = latents.shape 
 | 
			
		||||
 | 
			
		||||
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
 | 
			
		||||
    latents = latents.permute(0, 2, 4, 1, 3, 5)
 | 
			
		||||
    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
 | 
			
		||||
 | 
			
		||||
    return latents
 | 
			
		||||
 | 
			
		||||
def _unpack_latents(latents, height, width, vae_scale_factor=8):
 | 
			
		||||
    batch_size, num_patches, channels = latents.shape
 | 
			
		||||
 | 
			
		||||
    height = 2 * (int(height) // (vae_scale_factor * 2))
 | 
			
		||||
    width = 2 * (int(width) // (vae_scale_factor * 2))
 | 
			
		||||
 | 
			
		||||
    latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
 | 
			
		||||
    latents = latents.permute(0, 3, 1, 4, 2, 5)
 | 
			
		||||
 | 
			
		||||
    latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
 | 
			
		||||
 | 
			
		||||
    return latents
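# Shape sketch (illustrative): latents of shape (B, 16, 1, 90, 160) pack to
# (B, 45*80, 64) -- each 2x2 spatial patch is folded into the channel axis -- and
# _unpack_latents(packed, height=720, width=1280, vae_scale_factor=8) restores
# (B, 16, 1, 90, 160). This matches the packed latent layout expected by the
# Flux-style transformer (an assumption based on the surrounding code).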
 | 
			
		||||
 | 
			
		||||
class LanPaint():
 | 
			
		||||
    def __init__(self, NSteps = 5, Friction = 15, Lambda = 8, Beta = 1, StepSize = 0.15, IS_FLUX = True, IS_FLOW = False):
 | 
			
		||||
        self.n_steps = NSteps
 | 
			
		||||
        self.chara_lamb = Lambda
 | 
			
		||||
        self.IS_FLUX = IS_FLUX
 | 
			
		||||
        self.IS_FLOW = IS_FLOW
 | 
			
		||||
        self.step_size = StepSize
 | 
			
		||||
        self.friction = Friction
 | 
			
		||||
        self.chara_beta = Beta
 | 
			
		||||
        self.img_dim_size = None
 | 
			
		||||
    def add_none_dims(self, array):
 | 
			
		||||
        # Create a tuple with ':' for the first dimension and 'None' repeated num_nones times
 | 
			
		||||
        index = (slice(None),) + (None,) * (self.img_dim_size-1)
 | 
			
		||||
        return array[index]
 | 
			
		||||
    def remove_none_dims(self, array):
 | 
			
		||||
        # Create a tuple with ':' for the first dimension and 'None' repeated num_nones times
 | 
			
		||||
        index = (slice(None),) + (0,) * (self.img_dim_size-1)
 | 
			
		||||
        return array[index]
 | 
			
		||||
    def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_image, noise, sigma, latent_mask, n_steps=None, height =720, width = 1280, vae_scale_factor = 8):
 | 
			
		||||
        latent_image = _unpack_latents(latent_image, height=height, width=width, vae_scale_factor=vae_scale_factor)
 | 
			
		||||
        noise = _unpack_latents(noise, height=height, width=width, vae_scale_factor=vae_scale_factor)
 | 
			
		||||
        x = _unpack_latents(x, height=height, width=width, vae_scale_factor=vae_scale_factor)
 | 
			
		||||
        latent_mask = _unpack_latents(latent_mask, height=height, width=width, vae_scale_factor=vae_scale_factor)
 | 
			
		||||
        self.height = height
 | 
			
		||||
        self.width = width
 | 
			
		||||
        self.vae_scale_factor = vae_scale_factor
 | 
			
		||||
        self.img_dim_size = len(x.shape)
 | 
			
		||||
        self.latent_image = latent_image
 | 
			
		||||
        self.noise = noise
 | 
			
		||||
        if n_steps is None:
 | 
			
		||||
            n_steps = self.n_steps
 | 
			
		||||
        out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, self.IS_FLUX, self.IS_FLOW)
 | 
			
		||||
        out = _pack_latents(out)
 | 
			
		||||
        return out
 | 
			
		||||
    def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG,  x, sigma, latent_mask, n_steps, IS_FLUX, IS_FLOW):
 | 
			
		||||
        if IS_FLUX:
 | 
			
		||||
            cfg_BIG = 1.0
 | 
			
		||||
 | 
			
		||||
        def double_denoise(latents, t):
 | 
			
		||||
            latents = _pack_latents(latents)
 | 
			
		||||
            noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
 | 
			
		||||
            if noise_pred is None: return None, None
 | 
			
		||||
            predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
 | 
			
		||||
            predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor)
 | 
			
		||||
            if true_cfg_scale ==  cfg_BIG:
 | 
			
		||||
                predict_big = predict_std
 | 
			
		||||
            else:
 | 
			
		||||
                predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t)
 | 
			
		||||
                predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor)
 | 
			
		||||
            return predict_std, predict_big
 | 
			
		||||
        
 | 
			
		||||
        if len(sigma.shape) == 0:
 | 
			
		||||
            sigma = torch.tensor([sigma.item()])
 | 
			
		||||
        latent_mask = 1 - latent_mask
 | 
			
		||||
        if IS_FLUX or IS_FLOW:
 | 
			
		||||
            Flow_t = sigma
 | 
			
		||||
            abt = (1 - Flow_t)**2 / ((1 - Flow_t)**2 + Flow_t**2 )
 | 
			
		||||
            VE_Sigma = Flow_t / (1 - Flow_t)
 | 
			
		||||
            #print("t", torch.mean( sigma ).item(), "VE_Sigma", torch.mean( VE_Sigma ).item())
 | 
			
		||||
        else:
 | 
			
		||||
            VE_Sigma = sigma 
 | 
			
		||||
            abt = 1/( 1+VE_Sigma**2 )
 | 
			
		||||
            Flow_t = (1-abt)**0.5 / ( (1-abt)**0.5 + abt**0.5  )
 | 
			
		||||
        # VE_Sigma, abt, Flow_t = current_times
 | 
			
		||||
        current_times =  (VE_Sigma, abt, Flow_t)
 | 
			
		||||
        
 | 
			
		||||
        step_size = self.step_size * (1 - abt)
 | 
			
		||||
        step_size = self.add_none_dims(step_size)
 | 
			
		||||
        # self.inner_model.inner_model.scale_latent_inpaint returns variance exploding x_t values
 | 
			
		||||
        # This is the replace step
 | 
			
		||||
        # x = x * (1 - latent_mask) +  self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image)* latent_mask
 | 
			
		||||
 | 
			
		||||
        noisy_image  = self.latent_image  * (1.0 - sigma) + self.noise * sigma 
 | 
			
		||||
        x = x * (1 - latent_mask) +  noisy_image * latent_mask
 | 
			
		||||
 | 
			
		||||
        if IS_FLUX or IS_FLOW:
 | 
			
		||||
            x_t = x * ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 )
 | 
			
		||||
        else:
 | 
			
		||||
            x_t = x / ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance preserving x_t values
 | 
			
		||||
 | 
			
		||||
        ############ LanPaint Iterations Start ###############
 | 
			
		||||
        # after noise_scaling, noise = latent_image + noise * sigma, which is x_t in the variance exploding diffusion model notation for the known region.
 | 
			
		||||
        args = None
 | 
			
		||||
        for i in range(n_steps):
 | 
			
		||||
            score_func = partial( self.score_model, y = self.latent_image, mask = latent_mask, abt = self.add_none_dims(abt), sigma = self.add_none_dims(VE_Sigma), tflow = self.add_none_dims(Flow_t), denoise_func = double_denoise )
 | 
			
		||||
            if score_func is None: return None
 | 
			
		||||
            x_t, args = self.langevin_dynamics(x_t, score_func , latent_mask, step_size , current_times, sigma_x = self.add_none_dims(self.sigma_x(abt)), sigma_y = self.add_none_dims(self.sigma_y(abt)), args = args)  
 | 
			
		||||
        if IS_FLUX or IS_FLOW:
 | 
			
		||||
            x = x_t / ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 )
 | 
			
		||||
        else:
 | 
			
		||||
            x = x_t * ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values
 | 
			
		||||
        ############ LanPaint Iterations End ###############
 | 
			
		||||
        # out is x_0
 | 
			
		||||
        # out, _ = self.inner_model(x, sigma, model_options=model_options, seed=seed)
 | 
			
		||||
        # out = out * (1-latent_mask) + self.latent_image * latent_mask
 | 
			
		||||
        # return out
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def score_model(self, x_t, y, mask, abt, sigma, tflow, denoise_func):
 | 
			
		||||
        
 | 
			
		||||
        lamb = self.chara_lamb
 | 
			
		||||
        if self.IS_FLUX or self.IS_FLOW:
 | 
			
		||||
            # compute t for flow model, with a small epsilon compensating for numerical error.
 | 
			
		||||
            x = x_t / ( abt**0.5 + (1-abt)**0.5 ) # switch to Gaussian flow matching
 | 
			
		||||
            x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(tflow))
 | 
			
		||||
            if x_0 is None: return None
 | 
			
		||||
        else:
 | 
			
		||||
            x = x_t * ( 1+sigma**2 )**0.5 # switch to variance exploding
 | 
			
		||||
            x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(sigma))
 | 
			
		||||
            if x_0 is None: return None
 | 
			
		||||
 | 
			
		||||
        score_x = -(x_t - x_0)
 | 
			
		||||
        score_y =  - (1 + lamb) * ( x_t - y )  + lamb * (x_t - x_0_BIG)  
 | 
			
		||||
        return score_x * (1 - mask) + score_y * mask
 | 
			
		||||
    def sigma_x(self, abt):
 | 
			
		||||
        # the time scale for the x_t update
 | 
			
		||||
        return abt**0
 | 
			
		||||
    def sigma_y(self, abt):
 | 
			
		||||
        beta = self.chara_beta * abt ** 0
 | 
			
		||||
        return beta
 | 
			
		||||
 | 
			
		||||
    def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=1, sigma_y=0, args=None):
 | 
			
		||||
        # prepare the step size and time parameters
 | 
			
		||||
        with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
 | 
			
		||||
            step_sizes = self.prepare_step_size(current_times, step_size, sigma_x, sigma_y)
 | 
			
		||||
            sigma, abt, dtx, dty, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y = step_sizes
 | 
			
		||||
        # print('mask',mask.device)
 | 
			
		||||
        if torch.mean(dtx) <= 0.:
 | 
			
		||||
            return x_t, args
 | 
			
		||||
        # -------------------------------------------------------------------------
 | 
			
		||||
        # Compute the Langevin dynamics update in variance perserving notation
 | 
			
		||||
        # -------------------------------------------------------------------------
 | 
			
		||||
        #x0 = self.x0_evalutation(x_t, score, sigma, args)
 | 
			
		||||
        #C = abt**0.5 * x0 / (1-abt)
 | 
			
		||||
        A = A_x * (1-mask) + A_y * mask
 | 
			
		||||
        D = D_x * (1-mask) + D_y * mask
 | 
			
		||||
        dt = dtx * (1-mask) + dty * mask
 | 
			
		||||
        Gamma = Gamma_x * (1-mask) + Gamma_y * mask
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        def Coef_C(x_t):
 | 
			
		||||
            x0 = self.x0_evalutation(x_t, score, sigma, args)
 | 
			
		||||
            C = (abt**0.5 * x0  - x_t )/ (1-abt) + A * x_t
 | 
			
		||||
            return C
 | 
			
		||||
        def advance_time(x_t, v, dt, Gamma, A, C, D):
 | 
			
		||||
            dtype = x_t.dtype
 | 
			
		||||
            with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
 | 
			
		||||
                osc = StochasticHarmonicOscillator(Gamma, A, C, D )
 | 
			
		||||
                x_t, v = osc.dynamics(x_t, v, dt )
 | 
			
		||||
            x_t = x_t.to(dtype)
 | 
			
		||||
            v = v.to(dtype)
 | 
			
		||||
            return x_t, v
 | 
			
		||||
        if args is None:
 | 
			
		||||
            #v = torch.zeros_like(x_t)
 | 
			
		||||
            v = None
 | 
			
		||||
            C = Coef_C(x_t)
 | 
			
		||||
            #print(torch.squeeze(dtx), torch.squeeze(dty))
 | 
			
		||||
            x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D)
 | 
			
		||||
        else:
 | 
			
		||||
            v, C = args
 | 
			
		||||
 | 
			
		||||
            x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
 | 
			
		||||
 | 
			
		||||
            C_new = Coef_C(x_t)
 | 
			
		||||
            v = v + Gamma**0.5 * ( C_new - C) *dt
 | 
			
		||||
 | 
			
		||||
            x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
 | 
			
		||||
 | 
			
		||||
            C = C_new
 | 
			
		||||
  
 | 
			
		||||
        return x_t, (v, C)
 | 
			
		||||
 | 
			
		||||
    def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y):
 | 
			
		||||
        # -------------------------------------------------------------------------
 | 
			
		||||
        # Unpack current times parameters (sigma and abt)
 | 
			
		||||
        sigma, abt, flow_t = current_times
 | 
			
		||||
        sigma = self.add_none_dims(sigma)
 | 
			
		||||
        abt = self.add_none_dims(abt)
 | 
			
		||||
        # Compute time step (dtx, dty) for x and y branches.
 | 
			
		||||
        dtx = 2 * step_size * sigma_x
 | 
			
		||||
        dty = 2 * step_size * sigma_y
 | 
			
		||||
        
 | 
			
		||||
        # -------------------------------------------------------------------------
 | 
			
		||||
        # Define friction parameter Gamma_hat for each branch.
 | 
			
		||||
        # Using dtx**0 provides a tensor of the proper device/dtype.
 | 
			
		||||
 | 
			
		||||
        Gamma_hat_x = self.friction **2 * self.step_size * sigma_x / 0.1 * sigma**0
 | 
			
		||||
        Gamma_hat_y = self.friction **2 * self.step_size * sigma_y / 0.1 * sigma**0
 | 
			
		||||
        #print("Gamma_hat_x", torch.mean(Gamma_hat_x).item(), "Gamma_hat_y", torch.mean(Gamma_hat_y).item())
 | 
			
		||||
        # adjust dt to match denoise-addnoise steps sizes
 | 
			
		||||
        Gamma_hat_x /= 2.
 | 
			
		||||
        Gamma_hat_y /= 2.
 | 
			
		||||
        A_t_x = (1) / ( 1 - abt ) * dtx / 2
 | 
			
		||||
        A_t_y =  (1+self.chara_lamb) / ( 1 - abt ) * dty / 2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        A_x = A_t_x / (dtx/2)
 | 
			
		||||
        A_y = A_t_y / (dty/2)
 | 
			
		||||
        Gamma_x = Gamma_hat_x / (dtx/2)
 | 
			
		||||
        Gamma_y = Gamma_hat_y / (dty/2)
 | 
			
		||||
 | 
			
		||||
        #D_x = (2 * (1 + sigma**2) )**0.5
 | 
			
		||||
        #D_y = (2 * (1 + sigma**2) )**0.5
 | 
			
		||||
        D_x = (2 * abt**0 )**0.5
 | 
			
		||||
        D_y = (2 * abt**0 )**0.5
 | 
			
		||||
        return sigma, abt, dtx/2, dty/2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def x0_evalutation(self, x_t, score, sigma, args):
 | 
			
		||||
        x0 = x_t + score(x_t)
 | 
			
		||||
        return x0
 | 
			
		||||
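A quick, hedged illustration of the schedule conversions used by LanPaint above (not part of the repository; the helper names below are made up for this note): a flow-matching time t, the variance-preserving alpha-bar abt and the variance-exploding sigma are three parameterizations of the same noise level, and the two conversion directions in the code are mutually consistent.

# Illustrative example (not part of this commit)
import torch

def flow_to_vp_ve(flow_t):
    # forward direction, as in the IS_FLUX / IS_FLOW branch above
    abt = (1 - flow_t) ** 2 / ((1 - flow_t) ** 2 + flow_t ** 2)
    ve_sigma = flow_t / (1 - flow_t)
    return abt, ve_sigma

def ve_to_vp_flow(ve_sigma):
    # reverse direction, as in the else branch above
    abt = 1 / (1 + ve_sigma ** 2)
    flow_t = (1 - abt) ** 0.5 / ((1 - abt) ** 0.5 + abt ** 0.5)
    return abt, flow_t

t = torch.linspace(0.05, 0.95, 10)
abt, ve_sigma = flow_to_vp_ve(t)
abt2, t2 = ve_to_vp_flow(ve_sigma)
assert torch.allclose(abt, abt2, atol=1e-6)   # same alpha-bar either way
assert torch.allclose(t, t2, atol=1e-6)       # round trip recovers the flow time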
shared/inpainting/utils.py (new file, 301 lines)
@@ -0,0 +1,301 @@
import torch

def epxm1_x(x):
    # Compute the (exp(x) - 1) / x term; fall back to its Taylor series near x = 0 to avoid division by zero.
    result = torch.special.expm1(x) / x
    # replace NaN or inf values with 0
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
    mask = torch.abs(x) < 1e-2
    result = torch.where(mask, 1 + x/2. + x**2 / 6., result)
    return result

def epxm1mx_x2(x):
    # Compute the (exp(x) - 1 - x) / x**2 term; fall back to its Taylor series near x = 0 to avoid division by zero.
    result = (torch.special.expm1(x) - x) / x**2
    # replace NaN or inf values with 0
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
    mask = torch.abs(x**2) < 1e-2
    result = torch.where(mask, 1/2. + x/6 + x**2 / 24 + x**3 / 120, result)
    return result

def expm1mxmhx2_x3(x):
    # Compute the (exp(x) - 1 - x - x**2 / 2) / x**3 term; fall back to its Taylor series near x = 0 to avoid division by zero.
    result = (torch.special.expm1(x) - x - x**2 / 2) / x**3
    # replace NaN or inf values with 0
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
    mask = torch.abs(x**3) < 1e-2
    result = torch.where(mask, 1/6 + x/24 + x**2 / 120 + x**3 / 720 + x**4 / 5040, result)
    return result
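A small, hedged sanity check for the three series helpers above (not part of the commit; it assumes the file is importable as shared.inpainting.utils): away from zero they must reproduce the direct formulas, and at zero they must return the leading Taylor coefficients 1, 1/2 and 1/6.

# Illustrative example (not part of this commit)
import torch
from shared.inpainting.utils import epxm1_x, epxm1mx_x2, expm1mxmhx2_x3

x = torch.tensor([-2.0, -0.5, 0.5, 2.0], dtype=torch.float64)   # far from 0: the direct formulas are safe
assert torch.allclose(epxm1_x(x), torch.special.expm1(x) / x)
assert torch.allclose(epxm1mx_x2(x), (torch.special.expm1(x) - x) / x**2)
assert torch.allclose(expm1mxmhx2_x3(x), (torch.special.expm1(x) - x - x**2 / 2) / x**3)

x0 = torch.zeros(1, dtype=torch.float64)                        # at 0 the series limits take over
assert torch.allclose(epxm1_x(x0), torch.tensor([1.0], dtype=torch.float64))
assert torch.allclose(epxm1mx_x2(x0), torch.tensor([0.5], dtype=torch.float64))
assert torch.allclose(expm1mxmhx2_x3(x0), torch.tensor([1.0 / 6.0], dtype=torch.float64))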
def exp_1mcosh_GD(gamma_t, delta):
    """
    Compute e^(-Γt) * (1 - cosh(Γt√Δ)) / ((Γt)**2 Δ)

    Parameters:
    gamma_t: Γ*t term (could be a scalar or tensor)
    delta: Δ term (could be a scalar or tensor)

    Returns:
    Result of the computation with numerical stability handling
    """
    # Main computation
    is_positive = delta > 0
    sqrt_abs_delta = torch.sqrt(torch.abs(delta))
    gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta
    numerator_pos = torch.exp(-gamma_t) - (torch.exp(gamma_t * (sqrt_abs_delta - 1)) + torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2
    numerator_neg = torch.exp(-gamma_t) * (1 - torch.cos(gamma_t * sqrt_abs_delta))
    numerator = torch.where(is_positive, numerator_pos, numerator_neg)
    result = numerator / (delta * gamma_t**2)
    # Handle NaN/inf cases
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
    # Handle numerical instability for small delta
    mask = torch.abs(gamma_t_sqrt_delta**2) < 5e-2
    taylor = (-0.5 - gamma_t**2 / 24 * delta - gamma_t**4 / 720 * delta**2) * torch.exp(-gamma_t)
    result = torch.where(mask, taylor, result)
    return result

def exp_sinh_GsqrtD(gamma_t, delta):
    """
    Compute e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ)

    Parameters:
    gamma_t: Γ*t term (could be a scalar or tensor)
    delta: Δ term (could be a scalar or tensor)

    Returns:
    Result of the computation with numerical stability handling
    """
    # Main computation
    is_positive = delta > 0
    sqrt_abs_delta = torch.sqrt(torch.abs(delta))
    gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta
    numerator_pos = (torch.exp(gamma_t * (sqrt_abs_delta - 1)) - torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2
    result_pos = numerator_pos / gamma_t_sqrt_delta
    result_pos = torch.where(torch.isfinite(result_pos), result_pos, torch.zeros_like(result_pos))

    # Taylor expansion for small gamma_t_sqrt_delta
    mask = torch.abs(gamma_t_sqrt_delta) < 1e-2
    taylor = (1 + gamma_t**2 / 6 * delta + gamma_t**4 / 120 * delta**2) * torch.exp(-gamma_t)
    result_pos = torch.where(mask, taylor, result_pos)

    # Handle negative delta (sinh becomes sin)
    result_neg = torch.exp(-gamma_t) * torch.special.sinc(gamma_t_sqrt_delta / torch.pi)
    result = torch.where(is_positive, result_pos, result_neg)
    return result

def exp_cosh(gamma_t, delta):
    """
    Compute e^(-Γt) * cosh(Γt√Δ)

    Parameters:
    gamma_t: Γ*t term (could be a scalar or tensor)
    delta: Δ term (could be a scalar or tensor)

    Returns:
    Result of the computation with numerical stability handling
    """
    exp_1mcosh_GD_result = exp_1mcosh_GD(gamma_t, delta)  # e^(-Γt) * (1 - cosh(Γt√Δ)) / ((Γt)**2 Δ)
    result = torch.exp(-gamma_t) - gamma_t**2 * delta * exp_1mcosh_GD_result
    return result

def exp_sinh_sqrtD(gamma_t, delta):
    """
    Compute e^(-Γt) * sinh(Γt√Δ) / √Δ

    Parameters:
    gamma_t: Γ*t term (could be a scalar or tensor)
    delta: Δ term (could be a scalar or tensor)

    Returns:
    Result of the computation with numerical stability handling
    """
    exp_sinh_GsqrtD_result = exp_sinh_GsqrtD(gamma_t, delta)  # e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ)
    result = gamma_t * exp_sinh_GsqrtD_result
    return result
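As a hedged cross-check of the stabilised forms above (again not from the commit, same import-path assumption), both signs of Δ can be compared against the naive closed forms, which are safe to evaluate directly at moderate arguments; for Δ < 0 the hyperbolic functions continue analytically to their trigonometric counterparts.

# Illustrative example (not part of this commit)
import torch
from shared.inpainting.utils import exp_cosh, exp_sinh_sqrtD

gamma_t = torch.tensor([0.3, 1.0, 4.0], dtype=torch.float64)
delta = torch.tensor([0.25, -0.5, 0.9], dtype=torch.float64)
root = torch.sqrt(torch.abs(delta))

# For Δ > 0 the reference is cosh/sinh; for Δ < 0 it becomes cos/sin.
ref_cosh = torch.exp(-gamma_t) * torch.where(delta > 0, torch.cosh(gamma_t * root), torch.cos(gamma_t * root))
ref_sinh = torch.exp(-gamma_t) * torch.where(delta > 0, torch.sinh(gamma_t * root) / root, torch.sin(gamma_t * root) / root)

assert torch.allclose(exp_cosh(gamma_t, delta), ref_cosh, atol=1e-8)
assert torch.allclose(exp_sinh_sqrtD(gamma_t, delta), ref_sinh, atol=1e-8)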
def zeta1(gamma_t, delta):
    # Compute hyperbolic terms and exponential
    half_gamma_t = gamma_t / 2
    exp_cosh_term = exp_cosh(half_gamma_t, delta)
    exp_sinh_term = exp_sinh_sqrtD(half_gamma_t, delta)

    # Main computation
    numerator = 1 - (exp_cosh_term + exp_sinh_term)
    denominator = gamma_t * (1 - delta) / 4
    result = 1 - numerator / denominator

    # Handle numerical instability
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))

    # Taylor expansion for a small denominator (same approach as the epxm1_x helpers above)
    mask = torch.abs(denominator) < 5e-3
    term1 = epxm1_x(-gamma_t)
    term2 = epxm1mx_x2(-gamma_t)
    term3 = expm1mxmhx2_x3(-gamma_t)
    taylor = term1 + (1/2. + term1 - 3*term2)*denominator + (-1/6. + term1/2 - 4 * term2 + 10 * term3) * denominator**2
    result = torch.where(mask, taylor, result)

    return result

def exp_cosh_minus_terms(gamma_t, delta):
    """
    Compute E^(-tΓ) * (Cosh[tΓ] - 1 - (Cosh[tΓ√Δ] - 1)/Δ) / (tΓ(1 - Δ))

    Parameters:
    gamma_t: Γ*t term (could be a scalar or tensor)
    delta: Δ term (could be a scalar or tensor)

    Returns:
    Result of the computation with numerical stability handling
    """
    exp_term = torch.exp(-gamma_t)
    # Compute individual terms
    exp_cosh_term = exp_cosh(gamma_t, gamma_t**0) - exp_term              # E^(-tΓ) (Cosh[tΓ] - 1) term
    exp_cosh_delta_term = - gamma_t**2 * exp_1mcosh_GD(gamma_t, delta)    # E^(-tΓ) (Cosh[tΓ√Δ] - 1)/Δ term

    # exp_1mcosh_GD is e^(-Γt) * (1 - cosh(Γt√Δ)) / ((Γt)**2 Δ)
    # Main computation
    numerator = exp_cosh_term - exp_cosh_delta_term
    denominator = gamma_t * (1 - delta)

    result = numerator / denominator

    # Handle numerical instability
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))

    # Taylor expansion for small gamma_t and delta near 1
    mask = (torch.abs(denominator) < 1e-1)
    exp_1mcosh_GD_term = exp_1mcosh_GD(gamma_t, delta**0)
    taylor = (
       gamma_t*exp_1mcosh_GD_term + 0.5 * gamma_t * exp_sinh_GsqrtD(gamma_t, delta**0)
       - denominator / 4 * ( 0.5 * exp_cosh(gamma_t, delta**0) - 4 * exp_1mcosh_GD_term - 5 /2 * exp_sinh_GsqrtD(gamma_t, delta**0) )
    )
    result = torch.where(mask, taylor, result)

    return result

def zeta2(gamma_t, delta):
    half_gamma_t = gamma_t / 2
    return exp_sinh_GsqrtD(half_gamma_t, delta)

def sig11(gamma_t, delta):
    return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta)

def Zcoefs(gamma_t, delta):
    Zeta1 = zeta1(gamma_t, delta)
    Zeta2 = zeta2(gamma_t, delta)

    sq_total = 1 - Zeta1 + gamma_t * (delta - 1) * (Zeta1 - 1)**2 / 8
    amplitude = torch.sqrt(sq_total)
    Zcoef1 = ( gamma_t**0.5 * Zeta2 / 2**0.5 ) / amplitude
    Zcoef2 = Zcoef1 * gamma_t * ( - 2 * exp_1mcosh_GD(gamma_t, delta) / sig11(gamma_t, delta) ) ** 0.5
    # cterm = exp_cosh_minus_terms(gamma_t, delta)
    # sterm = exp_sinh_sqrtD(gamma_t, delta**0) + exp_sinh_sqrtD(gamma_t, delta)
    # Zcoef3 = 2 * torch.sqrt( cterm / ( gamma_t * (1 - delta) * cterm + sterm ) )
    Zcoef3 = torch.sqrt( torch.maximum(1 - Zcoef1**2 - Zcoef2**2, sq_total.new_zeros(sq_total.shape)) )

    return Zcoef1 * amplitude, Zcoef2 * amplitude, Zcoef3 * amplitude, amplitude

def Zcoefs_asymp(gamma_t, delta):
    A_t = (gamma_t * (1 - delta) ) / 4
    return epxm1_x(- 2 * A_t)
class StochasticHarmonicOscillator:
    """
    Simulates a stochastic harmonic oscillator governed by the equations:
        dy(t) = q(t) dt
        dq(t) = -Γ A y(t) dt + Γ C dt + Γ D dw(t) - Γ q(t) dt

    Also define v(t) = q(t) / √Γ, which is numerically more stable.

    Where:
        y(t) - Position variable
        q(t) - Velocity variable
        Γ - Damping coefficient
        A - Harmonic potential strength
        C - Constant force term
        D - Noise amplitude
        dw(t) - Wiener process (Brownian motion)
    """
    def __init__(self, Gamma, A, C, D):
        self.Gamma = Gamma
        self.A = A
        self.C = C
        self.D = D
        self.Delta = 1 - 4 * A / Gamma

    def sig11(self, gamma_t, delta):
        return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta)

    def sig22(self, gamma_t, delta):
        return 1 - zeta1(2*gamma_t, delta) + 2 * gamma_t * exp_1mcosh_GD(gamma_t, delta)

    def dynamics(self, y0, v0, t):
        """
        Calculates the position and velocity variables at time t.

        Parameters:
            y0 (float): Initial position
            v0 (float): Initial velocity v(0) = q(0) / √Γ
            t (float): Time at which to evaluate the dynamics
        Returns:
            tuple: (y(t), v(t))
        """
        dummyzero = y0.new_zeros(1)  # convert scalars to tensors with the same device and dtype as y0
        Delta = self.Delta + dummyzero
        Gamma_hat = self.Gamma * t + dummyzero
        A = self.A + dummyzero
        C = self.C + dummyzero
        D = self.D + dummyzero
        Gamma = self.Gamma + dummyzero
        zeta_1 = zeta1(Gamma_hat, Delta)
        zeta_2 = zeta2(Gamma_hat, Delta)
        EE = 1 - Gamma_hat * zeta_2

        if v0 is None:
            v0 = torch.randn_like(y0) * D / 2 ** 0.5
            # v0 = (C - A * y0) / Gamma**0.5

        # Calculate mean position and velocity
        term1 = (1 - zeta_1) * (C * t - A * t * y0) + zeta_2 * (Gamma ** 0.5) * v0 * t
        y_mean = term1 + y0
        v_mean = (1 - EE) * (C - A * y0) / (Gamma ** 0.5) + (EE - A * t * (1 - zeta_1)) * v0

        cov_yy = D**2 * t * self.sig22(Gamma_hat, Delta)
        cov_vv = D**2 * self.sig11(Gamma_hat, Delta) / 2
        cov_yv = (zeta2(Gamma_hat, Delta) * Gamma_hat * D )**2 / 2 / (Gamma ** 0.5)

        # sample new position and velocity from a multivariate normal distribution
        batch_shape = y0.shape
        cov_matrix = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype)
        cov_matrix[..., 0, 0] = cov_yy
        cov_matrix[..., 0, 1] = cov_yv
        cov_matrix[..., 1, 0] = cov_yv  # symmetric
        cov_matrix[..., 1, 1] = cov_vv

        # Compute the Cholesky decomposition to get scale_tril
        # scale_tril = torch.linalg.cholesky(cov_matrix)
        scale_tril = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype)
        tol = 1e-8
        cov_yy = torch.clamp(cov_yy, min=tol)
        sd_yy = torch.sqrt(cov_yy)
        inv_sd_yy = 1 / (sd_yy)

        scale_tril[..., 0, 0] = sd_yy
        scale_tril[..., 0, 1] = 0.
        scale_tril[..., 1, 0] = cov_yv * inv_sd_yy
        scale_tril[..., 1, 1] = torch.clamp(cov_vv - cov_yv**2 / cov_yy, min=tol) ** 0.5
        # check that it matches torch.linalg.cholesky
        # assert torch.allclose(torch.linalg.cholesky(cov_matrix), scale_tril, atol=1e-4, rtol=1e-4)
        # Sample correlated noise from the multivariate normal
        mean = torch.zeros(*batch_shape, 2, device=y0.device, dtype=y0.dtype)
        mean[..., 0] = y_mean
        mean[..., 1] = v_mean
        new_yv = torch.distributions.MultivariateNormal(
            loc=mean,
            scale_tril=scale_tril
        ).sample()

        return new_yv[..., 0], new_yv[..., 1]
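For orientation, a hedged usage sketch of the integrator (not from the repository; shapes and coefficient values are invented): langevin_dynamics above builds one oscillator per Langevin half-step with per-element coefficients, which the constant tensors below imitate.

# Illustrative example (not part of this commit)
import torch
from shared.inpainting.utils import StochasticHarmonicOscillator

torch.manual_seed(0)
y0 = torch.randn(1, 4, 16, 16)        # current latent ("position")
Gamma = torch.full_like(y0, 8.0)      # friction Γ
A = torch.full_like(y0, 1.0)          # harmonic strength A, so Δ = 1 - 4A/Γ = 0.5
C = torch.zeros_like(y0)              # constant drive C
D = torch.full_like(y0, 2.0 ** 0.5)   # noise amplitude D

osc = StochasticHarmonicOscillator(Gamma, A, C, D)
y1, v1 = osc.dynamics(y0, None, torch.tensor(0.05))   # v0=None draws a random initial velocity
print(y1.shape, v1.shape)             # both torch.Size([1, 4, 16, 16])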
@@ -232,6 +232,9 @@ def save_video(tensor,
                retry=5):
    """Save tensor as video with configurable codec and container options."""

    if torch.is_tensor(tensor) and len(tensor.shape) == 4:
        tensor = tensor.unsqueeze(0)

    suffix = f'.{container}'
    cache_file = osp.join('/tmp', rand_name(suffix=suffix)) if save_file is None else save_file
    if not cache_file.endswith(suffix):
shared/utils/download.py (new file, 110 lines)
@@ -0,0 +1,110 @@
import sys, time

# Global variables to track download progress
_start_time = None
_last_time = None
_last_downloaded = 0
_speed_history = []
_update_interval = 0.5  # Update speed every 0.5 seconds

def progress_hook(block_num, block_size, total_size, filename=None):
    """
    Simple progress bar hook for urlretrieve

    Args:
        block_num: Number of blocks downloaded so far
        block_size: Size of each block in bytes
        total_size: Total size of the file in bytes
        filename: Name of the file being downloaded (optional)
    """
    global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval

    current_time = time.time()
    downloaded = block_num * block_size

    # Initialize timing on first call
    if _start_time is None or block_num == 0:
        _start_time = current_time
        _last_time = current_time
        _last_downloaded = 0
        _speed_history = []

    # Calculate download speed only at the specified interval
    speed = 0
    if current_time - _last_time >= _update_interval:
        if _last_time > 0:
            current_speed = (downloaded - _last_downloaded) / (current_time - _last_time)
            _speed_history.append(current_speed)
            # Keep only the last 5 speed measurements for smoothing
            if len(_speed_history) > 5:
                _speed_history.pop(0)
            # Average the recent speeds for a smoother display
            speed = sum(_speed_history) / len(_speed_history)

        _last_time = current_time
        _last_downloaded = downloaded
    elif _speed_history:
        # Use the last calculated average speed
        speed = sum(_speed_history) / len(_speed_history)

    # Format file sizes and speed
    def format_bytes(bytes_val):
        for unit in ['B', 'KB', 'MB', 'GB']:
            if bytes_val < 1024:
                return f"{bytes_val:.1f}{unit}"
            bytes_val /= 1024
        return f"{bytes_val:.1f}TB"

    file_display = filename if filename else "Unknown file"

    if total_size <= 0:
        # If the total size is unknown, show downloaded bytes only
        speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
        line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}"
        # Clear any trailing characters by padding with spaces
        sys.stdout.write(line.ljust(80))
        sys.stdout.flush()
        return

    downloaded = block_num * block_size
    percent = min(100, (downloaded / total_size) * 100)

    # Create the progress bar (40 characters wide to leave room for other info)
    bar_length = 40
    filled = int(bar_length * percent / 100)
    bar = '█' * filled + '░' * (bar_length - filled)

    speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""

    # Display progress with the filename first
    line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}"
    # Clear any trailing characters by padding with spaces
    sys.stdout.write(line.ljust(100))
    sys.stdout.flush()

    # Print a newline when complete
    if percent >= 100:
        print()

# Wrapper function to include the filename in the progress hook
def create_progress_hook(filename):
    """Creates a progress hook with the filename included"""
    global _start_time, _last_time, _last_downloaded, _speed_history
    # Reset timing variables for a new download
    _start_time = None
    _last_time = None
    _last_downloaded = 0
    _speed_history = []

    def hook(block_num, block_size, total_size):
        return progress_hook(block_num, block_size, total_size, filename)
    return hook
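A hedged usage sketch (not part of the commit; the URL and destination path are placeholders): urllib's urlretrieve passes (block_num, block_size, total_size) to its reporthook, which is exactly the signature of the closure returned by create_progress_hook.

# Illustrative example (not part of this commit)
from urllib.request import urlretrieve
from shared.utils.download import create_progress_hook

url = "https://example.com/model.safetensors"    # placeholder URL
dest = "ckpts/model.safetensors"                 # placeholder destination
urlretrieve(url, dest, reporthook=create_progress_hook("model.safetensors"))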
@@ -99,7 +99,7 @@ def parse_loras_multipliers(loras_multipliers, nb_loras, num_inference_steps, me

    return  loras_list_mult_choices_nums, slists_dict, ""

-def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_step = None, phase_switch_step2 = None ):
+def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_step = None, phase_switch_step2 = None):
    from mmgp import offload
    sz = len(slists_dict["phase1"])
    slists = [ expand_slist(slists_dict, i, num_inference_steps, phase_switch_step, phase_switch_step2 ) for i in range(sz)  ]
@@ -108,7 +108,8 @@ def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_st


-def get_model_switch_steps(timesteps, total_num_steps, guide_phases, model_switch_phase, switch_threshold, switch2_threshold ):
+def get_model_switch_steps(timesteps, guide_phases, model_switch_phase, switch_threshold, switch2_threshold ):
+    total_num_steps = len(timesteps)
    model_switch_step = model_switch_step2 = None
    for i, t in enumerate(timesteps):
        if guide_phases >=2 and model_switch_step is None and t <= switch_threshold: model_switch_step = i
@@ -9,7 +9,6 @@ import os
import sys
import time
import warnings
from functools import lru_cache
from io import BytesIO

import requests
@@ -257,7 +256,6 @@ VIDEO_READER_BACKENDS = {
FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)


@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
    if FORCE_QWENVL_VIDEO_READER is not None:
        video_reader_backend = FORCE_QWENVL_VIDEO_READER
@@ -1,4 +1,3 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os
import os.path as osp
@@ -18,11 +17,12 @@ import os
import tempfile
import subprocess
import json

from functools import lru_cache
os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")


from PIL import Image

video_info_cache = []
def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
@@ -32,6 +32,14 @@ def seed_everything(seed):
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)

def has_video_file_extension(filename):
    extension = os.path.splitext(filename)[-1].lower()
    return extension in [".mp4"]

def has_image_file_extension(filename):
    extension = os.path.splitext(filename)[-1].lower()
    return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"]

def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ):
    import math
@@ -77,7 +85,9 @@ def truncate_for_filesystem(s, max_bytes=255):
        else: r = m - 1
    return s[:l]

@lru_cache(maxsize=100)
def get_video_info(video_path):
    global video_info_cache
    import cv2
    cap = cv2.VideoCapture(video_path)

@@ -92,7 +102,7 @@ def get_video_info(video_path):

    return fps, width, height, frame_count

-def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, return_PIL = True) -> torch.Tensor:
+def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, target_fps = None, return_PIL = True) -> torch.Tensor:
    """Extract nth frame from video as PyTorch tensor normalized to [-1, 1]."""
    cap = cv2.VideoCapture(file_name)

@@ -100,7 +110,10 @@ def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool
        raise ValueError(f"Cannot open video: {file_name}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    fps = round(cap.get(cv2.CAP_PROP_FPS))
    if target_fps is not None:
        frame_no = round(target_fps * frame_no / fps)

    # Handle out of bounds
    if frame_no >= total_frames or frame_no < 0:
        if return_last_if_missing:
@@ -173,9 +186,15 @@ def remove_background(img, session=None):
def convert_image_to_tensor(image):
    return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0)

-def convert_tensor_to_image(t, frame_no = -1):
-    t = t[:, frame_no] if frame_no >= 0 else t
-    return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
+def convert_tensor_to_image(t, frame_no = 0, mask_levels = False):
+    if len(t.shape) == 4:
+        t = t[:, frame_no]
+    if t.shape[0] == 1:
+        t = t.expand(3,-1,-1)
+    if mask_levels:
+        return Image.fromarray(t.clone().mul_(255).permute(1,2,0).to(torch.uint8).cpu().numpy())
+    else:
+        return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())

def save_image(tensor_image, name, frame_no = -1):
    convert_tensor_to_image(tensor_image, frame_no).save(name)
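A hedged round-trip sketch for the two converters above (not part of the commit; it assumes these helpers live in shared/utils/utils.py, as the in-file import elsewhere in this diff suggests): images map to C,H,W tensors in [-1, 1], and the reworked convert_tensor_to_image now takes frame 0 of a C,F,H,W clip by default.

# Illustrative example (not part of this commit)
from PIL import Image
from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image

img = Image.new("RGB", (64, 48), (200, 30, 30))
t = convert_image_to_tensor(img)      # shape (3, 48, 64), values in [-1, 1]
clip = t.unsqueeze(1)                 # shape (3, 1, 48, 64): a one-frame "video"
back = convert_tensor_to_image(clip)  # PIL.Image taken from frame 0
assert back.size == img.size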
@@ -186,6 +205,14 @@ def get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_d
    frame_width =  int(frame_width * (100 + outpainting_left + outpainting_right) / 100)
    return frame_height, frame_width

def rgb_bw_to_rgba_mask(img, thresh=127):
    a = img.convert('L').point(lambda p: 255 if p > thresh else 0)  # alpha
    out = Image.new('RGBA', img.size, (255, 255, 255, 0))           # white, transparent
    out.putalpha(a)                                                 # white where alpha=255
    return out


def  get_outpainting_frame_location(final_height, final_width,  outpainting_dims, block_size = 8):
    outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= outpainting_dims
    raw_height = int(final_height / ((100 + outpainting_top + outpainting_bottom) / 100))
@@ -205,30 +232,64 @@ def  get_outpainting_frame_location(final_height, final_width,  outpainting_dims
    if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
    return height, width, margin_top, margin_left

-def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16):
-    if fit_into_canvas == None:
def rescale_and_crop(img, w, h):
    ow, oh = img.size
    target_ratio = w / h
    orig_ratio = ow / oh

    if orig_ratio > target_ratio:
        # Crop width first
        nw = int(oh * target_ratio)
        img = img.crop(((ow - nw) // 2, 0, (ow + nw) // 2, oh))
    else:
        # Crop height first
        nh = int(ow / target_ratio)
        img = img.crop((0, (oh - nh) // 2, ow, (oh + nh) // 2))

    return img.resize((w, h), Image.LANCZOS)

def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16):
    if fit_into_canvas == None or fit_into_canvas == 2:
        # return image_height, image_width
        return canvas_height, canvas_width
-    if fit_into_canvas:
    if fit_into_canvas == 1:
        scale1  = min(canvas_height / image_height, canvas_width / image_width)
        scale2  = min(canvas_width / image_height, canvas_height / image_width)
        scale = max(scale1, scale2)
-    else:
    else: #0 or #2 (crop)
        scale = (canvas_height * canvas_width / (image_height * image_width))**(1/2)

    new_height = round( image_height * scale / block_size) * block_size
    new_width = round( image_width * scale / block_size) * block_size
    return new_height, new_width

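A hedged example of the two sizing helpers above (not part of the commit; same shared/utils/utils.py import assumption): rescale_and_crop center-crops to the exact target size, while calculate_new_dimensions with fit_into_canvas=0 keeps the aspect ratio, matches the pixel budget, and rounds each side to the block size.

# Illustrative example (not part of this commit)
from PIL import Image
from shared.utils.utils import rescale_and_crop, calculate_new_dimensions

src = Image.new("RGB", (1280, 720))
print(rescale_and_crop(src, 832, 480).size)               # (832, 480): cropped to the target aspect, then resized
print(calculate_new_dimensions(480, 832, 720, 1280, 0))   # (480, 848): same pixel budget, 16:9 kept, rounded to 16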
-def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, ignore_first, fit_into_canvas = False ):
def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fit_into_canvas, fit_crop, block_size = 16):
    if fit_crop:
        image = rescale_and_crop(image, canvas_width, canvas_height)
        new_width, new_height = image.size
    else:
        image_width, image_height = image.size
        new_height, new_width = calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = block_size )
        image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
    return image, new_height, new_width

def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5, return_tensor = False ):
    if rm_background:
        session = new_session()

    output_list =[]
    output_mask_list =[]
    for i, img in enumerate(img_list):
        width, height =  img.size

-        if fit_into_canvas:
        resized_mask = None
        if any_background_ref == 1 and i==0 or any_background_ref == 2:
            if outpainting_dims is not None and background_ref_outpainted:
                resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True)
            elif img.size != (budget_width, budget_height):
                resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS)
            else:
                resized_image =img
        elif fit_into_canvas == 1:
            white_canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255
            scale = min(budget_height / height, budget_width / width)
            new_height = int(height * scale)
@@ -240,152 +301,112 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
            resized_image = Image.fromarray(white_canvas)
        else:
            scale = (budget_height * budget_width / (height * width))**(1/2)
-            new_height = int( round(height * scale / 16) * 16)
-            new_width = int( round(width * scale / 16) * 16)
            new_height = int( round(height * scale / block_size) * block_size)
            new_width = int( round(width * scale / block_size) * block_size)
            resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
-        if rm_background  and not (ignore_first and i == 0) :
        if rm_background  and not (any_background_ref and i==0 or any_background_ref == 2) :
            # resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
            resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
-        output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,
-    return output_list
        if return_tensor:
            output_list.append(convert_image_to_tensor(resized_image).unsqueeze(1))
        else:
            output_list.append(resized_image)
        output_mask_list.append(resized_mask)
    return output_list, output_mask_list
-def str2bool(v):
-    """
-    Convert a string to a boolean.
-
-    Supported true values: 'yes', 'true', 't', 'y', '1'
-    Supported false values: 'no', 'false', 'f', 'n', '0'
-
-    Args:
-        v (str): String to convert.
-
-    Returns:
-        bool: Converted boolean value.
-
-    Raises:
-        argparse.ArgumentTypeError: If the value cannot be converted to boolean.
-    """
-    if isinstance(v, bool):
-        return v
-    v_lower = v.lower()
-    if v_lower in ('yes', 'true', 't', 'y', '1'):
-        return True
-    elif v_lower in ('no', 'false', 'f', 'n', '0'):
-        return False
-    else:
-        raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
[removed here as well: the download progress-hook helpers (progress_hook, create_progress_hook), identical to the new shared/utils/download.py shown above]

def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu", full_frame = False, outpainting_dims = None, return_mask = False, return_image = False):
    from shared.utils.utils import save_image
    inpaint_color = canvas_tf_bg / 127.5 - 1

    ref_width, ref_height = ref_img.size
    if (ref_height, ref_width) == image_size and outpainting_dims == None:
        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
        canvas = torch.zeros_like(ref_img[:1]) if return_mask else None
    else:
        if outpainting_dims != None:
            final_height, final_width = image_size
            canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width,  outpainting_dims, 1)
        else:
            canvas_height, canvas_width = image_size
        if full_frame:
            new_height = canvas_height
            new_width = canvas_width
            top = left = 0
        else:
            # if fill_max  and (canvas_height - new_height) < 16:
            #     new_height = canvas_height
            # if fill_max  and (canvas_width - new_width) < 16:
            #     new_width = canvas_width
            scale = min(canvas_height / ref_height, canvas_width / ref_width)
            new_height = int(ref_height * scale)
            new_width = int(ref_width * scale)
            top = (canvas_height - new_height) // 2
            left = (canvas_width - new_width) // 2
        ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
        if outpainting_dims != None:
            canvas = torch.full((3, 1, final_height, final_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1]
            canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
        else:
            canvas = torch.full((3, 1, canvas_height, canvas_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1]
            canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
        ref_img = canvas
        canvas = None
        if return_mask:
            if outpainting_dims != None:
                canvas = torch.ones((1, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1]
                canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0
            else:
                canvas = torch.ones((1, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1]
                canvas[:, :, top:top + new_height, left:left + new_width] = 0
            canvas = canvas.to(device)
    if return_image:
        return convert_tensor_to_image(ref_img), canvas

    return ref_img.to(device), canvas

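A hedged usage sketch of fit_image_into_canvas (not part of the commit; same import assumption, and the outpainting percentages below are only meant to show the returned shapes): a reference image is letterboxed into an image_size = (height, width) canvas, optionally enlarged by outpainting margins, and the optional mask is 1 on the padding and 0 where the image landed.

# Illustrative example (not part of this commit)
from PIL import Image
from shared.utils.utils import fit_image_into_canvas

ref = Image.new("RGB", (640, 480))
frame, mask = fit_image_into_canvas(ref, (480, 832), canvas_tf_bg=127.5,
                                    outpainting_dims=(0, 0, 0, 20),   # 20% outpainting on the right
                                    return_mask=True)
print(frame.shape, mask.shape)   # torch.Size([3, 1, 480, 832]) torch.Size([1, 1, 480, 832])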
def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None, device ="cpu"):
    src_videos, src_masks = [], []
    inpaint_color_compressed = guide_inpaint_color/127.5 - 1
    prepend_count = pre_video_guide.shape[1] if pre_video_guide is not None else 0
    for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)):
        src_video, src_mask = cur_video_guide, cur_video_mask
        if pre_video_guide is not None:
            src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1)
            if any_mask:
                src_mask = torch.zeros_like(pre_video_guide[:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[:1]), src_mask], dim=1)

        if any_guide_padding:
            if src_video is None:
                src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color_compressed, dtype = torch.float, device= device)
            elif src_video.shape[1] < current_video_length:
                src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color_compressed, dtype = src_video.dtype, device= src_video.device) ], dim=1)
        elif src_video is not None:
            new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
            src_video = src_video[:, :new_num_frames]

        if any_mask and src_video is not None:
            if src_mask is None:
                src_mask = torch.ones_like(src_video[:1])
            elif src_mask.shape[1] < src_video.shape[1]:
                src_mask = torch.cat([src_mask, torch.full( (1, src_video.shape[1]- src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1)
            else:
                src_mask = src_mask[:, :src_video.shape[1]]

        if src_video is not None :
            for k, keep in enumerate(keep_video_guide_frames):
                if not keep:
                    pos = prepend_count + k
                    src_video[:, pos:pos+1] = inpaint_color_compressed
                    if any_mask: src_mask[:, pos:pos+1] = 1

            for k, frame in enumerate(inject_frames):
                if frame != None:
                    pos = prepend_count + k
                    src_video[:, pos:pos+1], msk = fit_image_into_canvas(frame, image_size, guide_inpaint_color, device, True, outpainting_dims, return_mask= any_mask)
                    if any_mask: src_mask[:, pos:pos+1] = msk
        src_videos.append(src_video)
        src_masks.append(src_mask)
    return src_videos, src_masks
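Finally, a hedged call sketch for prepare_video_guide_and_mask (not part of the commit; same import assumption, with a deliberately small 96x160 resolution): one 33-frame guide plus its single-channel mask is padded up to an 81-frame window with the neutral inpaint color.

# Illustrative example (not part of this commit)
import torch
from shared.utils.utils import prepare_video_guide_and_mask

guide = torch.zeros(3, 33, 96, 160)   # C, F, H, W guide video
mask = torch.ones(1, 33, 96, 160)     # matching single-channel mask
src_videos, src_masks = prepare_video_guide_and_mask(
    [guide], [mask], pre_video_guide=None, image_size=(96, 160),
    current_video_length=81, latent_size=4, any_mask=True, any_guide_padding=True,
    keep_video_guide_frames=[True] * 33)
print(src_videos[0].shape, src_masks[0].shape)   # torch.Size([3, 81, 96, 160]) torch.Size([1, 81, 96, 160])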