diff --git a/README.md b/README.md
index dff2873..d5bb960 100644
--- a/README.md
+++ b/README.md
@@ -19,222 +19,96 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
 **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
-## 🔥 Latest Updates :
+## 🔥 Latest Updates :
-### August 29 2025: WanGP v8.2 - Here Goes Your Weekend
+### September 25 2025: WanGP v8.73 - ~~Here Are Two Three New Contenders in the Vace Arena !~~ The Never Ending Release
-- **InfiniteTalk Video to Video**: this feature can be used for Video Dubbing. Keep in mind that it is a *Sparse Video to Video*, that is internally only image is used by Sliding Window. However thanks to the new *Smooth Transition* mode, each new clip is connected to the previous and all the camera work is done by InfiniteTalk. If you dont get any transition, increase the number of frames of a Sliding Window (81 frames recommended)
+So in ~~today's~~ this release you will find two Vace wannabes that each cover only a subset of the Vace features but offer some interesting advantages:
+- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion* transfers, and it does that very well. You can use it either to *Replace* a person in an existing Video or to *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* appears to be, under the hood, an i2v-derived model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*); see the finetune sketch further below for how an accelerator Lora can be attached. Also, as a WanGP exclusivity, you will find support for *Outpainting*.
-- **StandIn**: very light model specialized in Identity Transfer. I have provided two versions of Standin: a basic one derived from the text 2 video model and another based on Vace. If used with Vace, the last reference frame given to Vace will be also used for StandIn
+In order to use Wan 2.2 Animate you will first need to stop by the embedded *Mat Anyone* tool to extract the *Video Mask* of the person from whom you want to capture the motion.
-- **Flux ESO**: a new Flux dervied *Image Editing tool*, but this one is specialized both in *Identity Transfer* and *Style Transfer*. Style has to be understood in its wide meaning: give a reference picture of a person and another one of Sushis and you will turn this person into Sushis
+With version WanGP 8.74, there is an extra option that allows you to apply *Relighting* when Replacing a person. Also, you can now Animate a person without providing a Video Mask to target the source of the motion (at the risk of being less precise).
-### August 24 2025: WanGP v8.1 - the RAM Liberator
+- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and ask it to change it (it is specialized in clothes changing) and voila ! The nice thing about it is that it is based on the *Wan 2.2 5B* model and is therefore very fast, especially if you use the *FastWan* finetune that is also part of the package.
-- **Reserved RAM entirely freed when switching models**, you should get much less out of memory related to RAM. I have also added a button in *Configuration / Performance* that will release most of the RAM used by WanGP if you want to use another application without quitting WanGP
-- **InfiniteTalk** support: improved version of Multitalk that supposedly supports very long video generations based on an audio track. 
Exists in two flavors (*Single Speaker* and *Multi Speakers*) but doesnt seem to be compatible with Vace. One key new feature compared to Multitalk is that you can have different visual shots associated to the same audio: each Reference frame you provide you will be associated to a new Sliding Window. If only Reference frame is provided, it will be used for all windows. When Continuing a video, you can either continue the current shot (no Reference Frame) or add new shots (one or more Reference Frames).\
-If you are not into audio, you can use still this model to generate infinite long image2video, just select "no speaker". Last but not least, Infinitetalk works works with all the Loras accelerators.
-- **Flux Chroma 1 HD** support: uncensored flux based model and lighter than Flux (8.9B versus 12B) and can fit entirely in VRAM with only 16 GB of VRAM. Unfortunalely it is not distilled and you will need CFG at minimum 20 steps
+Also because I wanted to spoil you:
+- **Qwen Edit Plus**: also known as the *Qwen Edit 25th September Update*, which is specialized in combining multiple Objects / People. There is also new support for *Pose transfer* & *Recolorisation*. All of this is made easy to use in WanGP. For now you will find only the quantized version, since HF crashes when uploading the unquantized version.
-### August 21 2025: WanGP v8.01 - the killer of seven
+- **T2V Video 2 Video Masking**: ever wanted to apply a Lora, a process (for instance Upsampling) or a Text Prompt to only a (moving) part of a Source Video ? Look no further, I have added *Masked Video 2 Video* (which also works in image2image) to the *Text 2 Video* models. As usual you just need to use *Matanyone* to create the mask.
-- **Qwen Image Edit** : Flux Kontext challenger (prompt driven image edition). Best results (including Identity preservation) will be obtained at 720p. Beyond you may get image outpainting and / or lose identity preservation. Below 720p prompt adherence will be worse. Qwen Image Edit works with Qwen Lora Lightning 4 steps. I have also unlocked all the resolutions for Qwen models. Bonus Zone: support for multiple image compositions but identity preservation won't be as good.
-- **On demand Prompt Enhancer** (needs to be enabled in Configuration Tab) that you can use to Enhance a Text Prompt before starting a Generation. You can refine the Enhanced Prompt or change the original Prompt.
-Choice of a **Non censored Prompt Enhancer**. Beware this is one is VRAM hungry and will require 12 GB of VRAM to work
-- **Memory Profile customizable per model** : useful to set for instance Profile 3 (preload the model entirely in VRAM) with only Image Generation models, if you have 24 GB of VRAM. In that case Generation will be much faster because with Image generators (contrary to Video generators) as a lot of time is wasted in offloading
-- **Expert Guidance Mode**: change the Guidance during the generation up to 2 times. Very useful with Wan 2.2 Ligthning to reduce the slow motion effect. The idea is to insert a CFG phase before the 2 accelerated phases that follow and have no Guidance. I have added the finetune *Wan2.2 Vace Lightning 3 Phases 14B* with a prebuilt configuration. Please note that it is a 8 steps process although the lora lightning is 4 steps. This expert guidance mode is also available with Wan 2.1. 
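The release notes above say that accelerator Loras should work with *Wan 2.2 Animate* and that *Lucy Edit* ships with a *FastWan* finetune. As a hedged illustration of what such a combination could look like on disk, here is a minimal sketch of a finetune definition, assuming the same conventions as the definitions added in this very diff (*defaults/lucy_edit_fastwan.json*, *defaults/vace_fun_14B_cocktail_2_2.json*): reuse another definition's *URLs* by name and attach *loras* with matching *loras_multipliers*. The file name *defaults/animate_accelerated.json*, the choice of the CausVid Lora and the step/guidance values are illustrative placeholders, not settings shipped or validated with this release (JSON does not allow comments, so the caveat lives in the description field and in this paragraph).

```json
{
  "model": {
    "name": "Wan2.2 Animate Accelerated (example)",
    "architecture": "animate",
    "description": "Hypothetical sketch only: reuses the Wan 2.2 Animate weights and attaches a t2v accelerator Lora, following the finetune definition pattern used elsewhere in this release.",
    "URLs": "animate",
    "loras": [
      "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors"
    ],
    "loras_multipliers": [
      1
    ],
    "group": "wan2_2"
  },
  "num_inference_steps": 10,
  "guidance_scale": 1
}
```

Whether this particular Lora and these step/guidance values behave well with Animate has not been verified here; the sketch only shows the shape of such a definition.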
-*WanGP 8.01 update, improved Qwen Image Edit Identity Preservation*
-### August 12 2025: WanGP v7.7777 - Lucky Day(s)
+*Update 8.71*: fixed Fast Lucy Edit that didn't contain the Lora
+*Update 8.72*: shadow drop of Qwen Edit Plus
+*Update 8.73*: Qwen Preview & InfiniteTalk Start image
+*Update 8.74*: Animate Relighting / No-mask mode, t2v Masked Video to Video
-This is your lucky day ! thanks to new configuration options that will let you store generated Videos and Images in lossless compressed formats, you will find they in fact they look two times better without doing anything !
+### September 15 2025: WanGP v8.6 - Attack of the Clones
-Just kidding, they will be only marginally better, but at least this opens the way to professionnal editing.
+- The long-awaited **Vace for Wan 2.2** is at last here, or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fun Cocktail**)
-Support:
-- Video: x264, x264 lossless, x265
-- Images: jpeg, png, webp, wbp lossless
-Generation Settings are stored in each of the above regardless of the format (that was the hard part).
+- **First Frame / Last Frame for Vace** : Vace models are so powerful that they could do *First frame / Last frame* since day one using the *Injected Frames* feature. However, this required computing by hand the location of each end frame, since this feature expects frame positions. I have made these locations easier to express with the "L" alias:
-Also you can now choose different output directories for images and videos.
+For a video Gen from scratch, *"1 L L L"* means the 4 Injected Frames will be placed like this: frame no 1 at the first position, the next frame at the end of the first window, the following frame at the end of the next window, and so on.
+If you *Continue a Video*, you just need *"L L L"* since the first frame is the last frame of the *Source Video*. In any case remember that numeric frame positions (like "1") are aligned by default to the beginning of the source window, so low values such as 1 will be considered in the past unless you change this behaviour in *Sliding Window Tab / Control Video, Injected Frames alignment*.
-unexpected luck: fixed lightning 8 steps for Qwen, and lightning 4 steps for Wan 2.2, now you just need 1x multiplier no weird numbers.
-*update 7.777 : oops got a crash a with FastWan ? Luck comes and goes, try a new update, maybe you will have a better chance this time*
-*update 7.7777 : Sometime good luck seems to last forever. For instance what if Qwen Lightning 4 steps could also work with WanGP ?*
-- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors (Qwen Lightning 4 steps)
-- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors (new improved version of Qwen Lightning 8 steps)
+- **Qwen Edit Inpainting** now exists in two versions: the original version of the previous release and a Lora-based version. Each version has its pros and cons. For instance the Lora version also supports **Outpainting** ! However it tends to slightly change the original image even outside the outpainted area.
+- **Better Lipsync with all the Audio to Video models**: you probably noticed that *Multitalk*, *InfiniteTalk* or *Hunyuan Avatar* had so-so lipsync when the provided audio contained some background music. 
The problem should now be solved thanks to automated background music removal, all done by AI. Don't worry, you will still hear the music as it is added back into the generated Video.
-### August 10 2025: WanGP v7.76 - Faster than the VAE ...
-We have a funny one here today: FastWan 2.2 5B, the Fastest Video Generator, only 20s to generate 121 frames at 720p. The snag is that VAE is twice as slow...
-Thanks to Kijai for extracting the Lora that is used to build the corresponding finetune.
+### September 11 2025: WanGP v8.5/8.55 - Wanna be a Cropper or a Painter ?
-*WanGP 7.76: fixed the messed up I did to i2v models (loras path was wrong for Wan2.2 and Clip broken)*
+I have done some intensive internal refactoring of the generation pipeline to make it easier to support existing models and to add new ones. Nothing really visible, but this makes WanGP a little more future proof.
-### August 9 2025: WanGP v7.74 - Qwen Rebirth part 2
-Added support for Qwen Lightning lora for a 8 steps generation (https://huggingface.co/lightx2v/Qwen-Image-Lightning/blob/main/Qwen-Image-Lightning-8steps-V1.0.safetensors). Lora is not normalized and you can use a multiplier around 0.1.
+Otherwise in the news:
+- **Cropped Input Image Prompts**: quite often the *Image Prompts* you provide (*Start Image, Input Video, Reference Image, Control Video, ...*) don't match your requested *Output Resolution*. In that case I used the resolution you gave either as a *Pixels Budget* or as an *Outer Canvas* for the Generated Video. However, on some occasions you really want the requested Output Resolution and nothing else. Besides, some models deliver much better Generations if you stick to one of their supported resolutions. To address this need I have added a new Output Resolution choice in the *Configuration Tab*: **Dimensions Correspond to the Output Width & Height as the Prompt Images will be Cropped to fit Exactly these dimensions**. In short, if needed the *Input Prompt Images* will be cropped (center cropped for the moment). You will see this can make quite a difference for some models (a finetune definition can also pin its supported resolutions; see the sketch a few paragraphs below).
-Mag Cache support for all the Wan2.2 models Don't forget to set guidance to 1 and 8 denoising steps , your gen will be 7x faster !
-### August 8 2025: WanGP v7.73 - Qwen Rebirth
-Ever wondered what impact not using Guidance has on a model that expects it ? Just look at Qween Image in WanGP 7.71 whose outputs were erratic. Somehow I had convinced myself that Qwen was a distilled model. In fact Qwen was dying for a negative prompt. And in WanGP 7.72 there is at last one for him.
+- *Qwen Edit* now has a new sub Tab called **Inpainting** that lets you target with a brush which part of the *Image Prompt* you want to modify. This is quite convenient if you find that Qwen Edit usually modifies too many things. Of course, as there are more constraints for Qwen Edit, don't be surprised if it sometimes returns the original image unchanged. A piece of advice: describe in your *Text Prompt* where the parts you want to modify are located (for instance *to the left of the man*, *top*, ...).
-As Qwen is not so picky after all I have added also quantized text encoder which reduces the RAM requirements of Qwen by 10 GB (the text encoder quantized version produced garbage before)
+The mask inpainting is fully compatible with the *Matanyone Mask generator*: first generate an *Image Mask* with Matanyone, transfer it to the current Image Generator and modify the mask with the *Paint Brush*. 
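Coming back to the supported-resolutions point above: a model definition can also pin the resolution choices offered for it, as *defaults/flux_dev_umo.json* in this very diff does with its *resolutions* field. Below is a minimal, hypothetical sketch assuming that same convention; the file name, the Flux base and the two resolutions picked are placeholders for illustration, and whether other architectures honour the field identically is not shown in this diff.

```json
{
  "model": {
    "name": "Flux 1 Dev 720p Only (example)",
    "architecture": "flux",
    "description": "Hypothetical sketch only: a finetune definition that restricts the offered output resolutions, mirroring the resolutions field of defaults/flux_dev_umo.json.",
    "URLs": "flux",
    "flux-model": "flux-dev",
    "resolutions": [
      ["1280x720 (16:9)", "1280x720"],
      ["720x1280 (9:16)", "720x1280"]
    ]
  },
  "prompt": "add a hat",
  "resolution": "1280x720",
  "batch_size": 1
}
```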
Talking about Matanyone, I have fixed a bug that caused mask degradation with long videos (now WanGP Matanyone is as good as the original app and still requires 3 times less VRAM).
-Unfortunately still the Sage bug for older GPU architectures. Added Sdpa fallback for these architectures.
+- This **Inpainting Mask Editor** has also been added to *Vace Image Mode*. Vace is probably still one of the best Image Editors today. Here is a very simple & efficient workflow that does marvels with Vace:
+Select *Vace Cocktail > Control Image Process = Perform Inpainting & Area Processed = Masked Area > Upload a Control Image, then draw your mask directly on top of the image & enter a text Prompt that describes the expected change > Generate > Below the Video Gallery click 'To Control Image' > Keep on doing more changes*.
-*7.73 update: still Sage / Sage2 bug for GPUs before RTX40xx. I have added a detection mechanism that forces Sdpa attention if that's the case*
+Vace Image Editor works very well for more sophisticated things too: try Image Outpainting, Pose transfer, ...
-### August 6 2025: WanGP v7.71 - Picky, picky
+For the best quality I recommend setting the option "*Generate a 9 Frames Long video...*" in the *Quality Tab*.
+**update 8.55**: Flux Festival
+- **Inpainting Mode** also added for *Flux Kontext*
+- **Flux SRPO** : new finetune with 3x better quality vs Flux Dev according to its authors. I have also created a *Flux SRPO USO* finetune which is certainly the best open source *Style Transfer* tool available
+- **Flux UMO**: model specialized in combining multiple reference objects / people together. Works quite well at 768x768
-This release comes with two new models :
+Good luck finding your way through all the Flux model names !
-- Qwen Image: a Commercial grade Image generator capable to inject full sentences in the generated Image while still offering incredible visuals
-- Wan 2.2 TextImage to Video 5B: the last Wan 2.2 needed if you want to complete your Wan 2.2 collection (loras for this folder can be stored in "\loras\5B" )
+### September 5 2025: WanGP v8.4 - Take me to Outer Space
+You have probably seen those short AI generated movies created using *Nano Banana* and the *First Frame - Last Frame* feature of *Kling 2.0*. The idea is to generate an image, modify a part of it with Nano Banana and give these two images to Kling, which will generate the Video between them; then use the previous Last Frame as the new First Frame, rinse and repeat, and you get a full movie.
-There is catch though, they are very picky if you want to get good generations: first they both need lots of steps (50 ?) to show what they have to offer. Then for Qwen Image I had to hardcode the supported resolutions, because if you try anything else, you will get garbage. Likewise Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p. 
+I have made it easier to do just that with *Qwen Edit* and *Wan*:
+- **End Frames can now be combined with Continue a Video** (and not just a Start Frame)
+- **Multiple End Frames can be input**, and each End Frame will be used for a different Sliding Window
-*7.71 update: Added VAE Tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, for low VRAM during a whole gen.*
+You can plan all your shots in advance (one shot = one Sliding Window): I recommend using Wan 2.2 Image to Image with multiple End Frames (one for each shot / Sliding Window), and a different Text Prompt for each shot / Sliding Window (remember to enable *Sliding Windows/Text Prompts Will be used for a new Sliding Window of the same Video Generation*)
+The results can be quite impressive. However, Wan 2.1 & 2.2 Image 2 Image are restricted to a single overlap frame when using Sliding Windows, which means only one frame is reused for the motion. This may be insufficient if you are trying to connect two shots with fast movement.
-### August 4 2025: WanGP v7.6 - Remuxed
+This is where *InfiniteTalk* comes into play. Besides being one of the best models for generating animated audio-driven avatars, InfiniteTalk internally uses more than one motion frame. It is quite good at maintaining the motion between two shots. I have tweaked InfiniteTalk so that **its motion engine can be used even if no audio is provided**.
+So here is how to use InfiniteTalk: enable *Sliding Windows/Text Prompts Will be used for a new Sliding Window of the same Video Generation*, and if you continue an existing Video, *Misc/Override Frames per Second* should be set to *Source Video*. Each Reference Frame input will play the same role as the End Frame, except it won't be exactly an End Frame (it will correspond more to a middle frame; the actual End Frame will differ but will be close).
-With this new version you won't have any excuse if there is no sound in your video.
+You will find below a 33s movie I have created using these two methods. Quality could be much better as I haven't tuned the settings at all (I couldn't be bothered; I used 10-step generation without Loras Accelerators for most of the gens).
-*Continue Video* now works with any video that has already some sound (hint: Multitalk ).
+### September 2 2025: WanGP v8.31 - At last the pain stops
-Also, on top of MMaudio and the various sound driven models I have added the ability to use your own soundtrack.
+- This single new feature should give you the strength to face all the potential bugs of this new release:
+**Images Management (multiple additions or deletions, reordering) for Start Images / End Images / Images References.**
-As a result you can apply a different sound source on each new video segment when doing a *Continue Video*.
+- Unofficial **Video to Video (Non Sparse this time) for InfiniteTalk**. Use the Strength Noise slider to decide how much motion of the original window you want to keep. I have also *greatly reduced the VRAM requirements for Multitalk / InfiniteTalk* (especially the multispeaker version & when generating at 1080p).
-For instance:
-- first video part: use Multitalk with two people speaking
-- second video part: you apply your own soundtrack which will gently follow the multitalk conversation
-- third video part: you use Vace effect and its corresponding control audio will be concatenated to the rest of the audio
-To multiply the combinations I have also implemented *Continue Video* with the various image2video models. 
+- **Experimental Sage 3 Attention support**: you will need to deserve this one, first you need a Blackwell GPU (RTX50xx) and request an access to Sage 3 Github repo, then you will have to compile Sage 3, install it and cross your fingers ... -Also: -- End Frame support added for LTX Video models -- Loras can now be targetted specifically at the High noise or Low noise models with Wan 2.2, check the Loras and Finetune guides -- Flux Krea Dev support -### July 30 2025: WanGP v7.5: Just another release ... Wan 2.2 part 2 -Here is now Wan 2.2 image2video a very good model if you want to set Start and End frames. Two Wan 2.2 models delivered, only one to go ... +*update 8.31: one shouldnt talk about bugs if one doesn't want to attract bugs* -Please note that although it is an image2video model it is structurally very close to Wan 2.2 text2video (same layers with only a different initial projection). Given that Wan 2.1 image2video loras don't work too well (half of their tensors are not supported), I have decided that this model will look for its loras in the text2video loras folder instead of the image2video folder. - -I have also optimized RAM management with Wan 2.2 so that loras and modules will be loaded only once in RAM and Reserved RAM, this saves up to 5 GB of RAM which can make a difference... - -And this time I really removed Vace Cocktail Light which gave a blurry vision. - -### July 29 2025: WanGP v7.4: Just another release ... Wan 2.2 Preview -Wan 2.2 is here. The good news is that WanGP wont require a single byte of extra VRAM to run it and it will be as fast as Wan 2.1. The bad news is that you will need much more RAM if you want to leverage entirely this new model since it has twice has many parameters. - -So here is a preview version of Wan 2.2 that is without the 5B model and Wan 2.2 image to video for the moment. - -However as I felt bad to deliver only half of the wares, I gave you instead .....** Wan 2.2 Vace Experimental Cocktail** ! - -Very good surprise indeed, the loras and Vace partially work with Wan 2.2. We will need to wait for the official Vace 2.2 release since some Vace features are broken like identity preservation - -Bonus zone: Flux multi images conditions has been added, or maybe not if I broke everything as I have been distracted by Wan... - -7.4 update: I forgot to update the version number. I also removed Vace Cocktail light which didnt work well. - -### July 27 2025: WanGP v7.3 : Interlude -While waiting for Wan 2.2, you will appreciate the model selection hierarchy which is very useful to collect even more models. You will also appreciate that WanGP remembers which model you used last in each model family. - -### July 26 2025: WanGP v7.2 : Ode to Vace -I am really convinced that Vace can do everything the other models can do and in a better way especially as Vace can be combined with Multitalk. - -Here are some new Vace improvements: -- I have provided a default finetune named *Vace Cocktail* which is a model created on the fly using the Wan text 2 video model and the Loras used to build FusioniX. The weight of the *Detail Enhancer* Lora has been reduced to improve identity preservation. Copy the model definition in *defaults/vace_14B_cocktail.json* in the *finetunes/* folder to change the Cocktail composition. Cocktail contains already some Loras acccelerators so no need to add on top a Lora Accvid, Causvid or Fusionix, ... . 
The whole point of Cocktail is to be able to build you own FusioniX (which originally is a combination of 4 loras) but without the inconvenient of FusioniX. -- Talking about identity preservation, it tends to go away when one generates a single Frame instead of a Video which is shame for our Vace photoshop. But there is a solution : I have added an Advanced Quality option, that tells WanGP to generate a little more than a frame (it will still keep only the first frame). It will be a little slower but you will be amazed how Vace Cocktail combined with this option will preserve identities (bye bye *Phantom*). -- As in practise I have observed one switches frequently between *Vace text2video* and *Vace text2image* I have put them in the same place they are now just one tab away, no need to reload the model. Likewise *Wan text2video* and *Wan tex2image* have been merged. -- Color fixing when using Sliding Windows. A new postprocessing *Color Correction* applied automatically by default (you can disable it in the *Advanced tab Sliding Window*) will try to match the colors of the new window with that of the previous window. It doesnt fix all the unwanted artifacts of the new window but at least this makes the transition smoother. Thanks to the multitalk team for the original code. - -Also you will enjoy our new real time statistics (CPU / GPU usage, RAM / VRAM used, ... ). Many thanks to **Redtash1** for providing the framework for this new feature ! You need to go in the Config tab to enable real time stats. - - -### July 21 2025: WanGP v7.12 -- Flux Family Reunion : *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate that, Loras support for the Flux *diffusers* format has also been added. - -- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video !) in one go without a sliding window. With the distilled model it will take only 5 minutes with a RTX 4090 (you will need 22 GB of VRAM though). I have added options to select higher humber frames if you want to experiment (go to Configuration Tab / General / Increase the Max Number of Frames, change the value and restart the App) - -- LTX Video ControlNet : it is a Control Net that allows you for instance to transfer a Human motion or Depth from a control video. It is not as powerful as Vace but can produce interesting things especially as now you can generate quickly a 1 min video. Under the scene IC-Loras (see below) for Pose, Depth and Canny are automatically loaded for you, no need to add them. - -- LTX IC-Lora support: these are special Loras that consumes a conditional image or video -Beside the pose, depth and canny IC-Loras transparently loaded there is the *detailer* (https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8) which is basically an upsampler. Add the *detailer* as a Lora and use LTX Raw Format as control net choice to use it. - -- Matanyone is now also for the GPU Poor as its VRAM requirements have been divided by 2! (7.12 shadow update) - -- Easier way to select video resolution - - -### July 15 2025: WanGP v7.0 is an AI Powered Photoshop -This release turns the Wan models into Image Generators. 
This goes way more than allowing to generate a video made of single frame : -- Multiple Images generated at the same time so that you can choose the one you like best.It is Highly VRAM optimized so that you can generate for instance 4 720p Images at the same time with less than 10 GB -- With the *image2image* the original text2video WanGP becomes an image upsampler / restorer -- *Vace image2image* comes out of the box with image outpainting, person / object replacement, ... -- You can use in one click a newly Image generated as Start Image or Reference Image for a Video generation - -And to complete the full suite of AI Image Generators, Ladies and Gentlemen please welcome for the first time in WanGP : **Flux Kontext**.\ -As a reminder Flux Kontext is an image editor : give it an image and a prompt and it will do the change for you.\ -This highly optimized version of Flux Kontext will make you feel that you have been cheated all this time as WanGP Flux Kontext requires only 8 GB of VRAM to generate 4 images at the same time with no need for quantization. - -WanGP v7 comes with *Image2image* vanilla and *Vace FusinoniX*. However you can build your own finetune where you will combine a text2video or Vace model with any combination of Loras. - -Also in the news: -- You can now enter the *Bbox* for each speaker in *Multitalk* to precisely locate who is speaking. And to save some headaches the *Image Mask generator* will give you the *Bbox* coordinates of an area you have selected. -- *Film Grain* post processing to add a vintage look at your video -- *First Last Frame to Video* model should work much better now as I have discovered rencently its implementation was not complete -- More power for the finetuners, you can now embed Loras directly in the finetune definition. You can also override the default models (titles, visibility, ...) with your own finetunes. Check the doc that has been updated. - - -### July 10 2025: WanGP v6.7, is NAG a game changer ? you tell me -Maybe you knew that already but most *Loras accelerators* we use today (Causvid, FusioniX) don't use *Guidance* at all (that it is *CFG* is set to 1). This helps to get much faster generations but the downside is that *Negative Prompts* are completely ignored (including the default ones set by the models). **NAG** (https://github.com/ChenDarYen/Normalized-Attention-Guidance) aims to solve that by injecting the *Negative Prompt* during the *attention* processing phase. - -So WanGP 6.7 gives you NAG, but not any NAG, a *Low VRAM* implementation, the default one ends being VRAM greedy. You will find NAG in the *General* advanced tab for most Wan models. - -Use NAG especially when Guidance is set to 1. To turn it on set the **NAG scale** to something around 10. There are other NAG parameters **NAG tau** and **NAG alpha** which I recommend to change only if you don't get good results by just playing with the NAG scale. Don't hesitate to share on this discord server the best combinations for these 3 parameters. - -The authors of NAG claim that NAG can also be used when using a Guidance (CFG > 1) and to improve the prompt adherence. - -### July 8 2025: WanGP v6.6, WanGP offers you **Vace Multitalk Dual Voices Fusionix Infinite** : -**Vace** our beloved super Control Net has been combined with **Multitalk** the new king in town that can animate up to two people speaking (**Dual Voices**). 
It is accelerated by the **Fusionix** model and thanks to *Sliding Windows* support and *Adaptive Projected Guidance* (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for very a long time (which is an **Infinite** amount of time in the field of video generation). - -Of course you will get as well *Multitalk* vanilla and also *Multitalk 720p* as a bonus. - -And since I am mister nice guy I have enclosed as an exclusivity an *Audio Separator* that will save you time to isolate each voice when using Multitalk with two people. - -As I feel like resting a bit I haven't produced yet a nice sample Video to illustrate all these new capabilities. But here is the thing, I ams sure you will publish in the *Share Your Best Video* channel your *Master Pieces*. The best ones will be added to the *Announcements Channel* and will bring eternal fame to its authors. - -But wait, there is more: -- Sliding Windows support has been added anywhere with Wan models, so imagine with text2video recently upgraded in 6.5 into a video2video, you can now upsample very long videos regardless of your VRAM. The good old image2video model can now reuse the last image to produce new videos (as requested by many of you) -- I have added also the capability to transfer the audio of the original control video (Misc. advanced tab) and an option to preserve the fps into the generated video, so from now on you will be to upsample / restore your old families video and keep the audio at their original pace. Be aware that the duration will be limited to 1000 frames as I still need to add streaming support for unlimited video sizes. - -Also, of interest too: -- Extract video info from Videos that have not been generated by WanGP, even better you can also apply post processing (Upsampling / MMAudio) on non WanGP videos -- Force the generated video fps to your liking, works wery well with Vace when using a Control Video -- Ability to chain URLs of Finetune models (for instance put the URLs of a model in your main finetune and reference this finetune in other finetune models to save time) - -### July 2 2025: WanGP v6.5.1, WanGP takes care of you: lots of quality of life features: -- View directly inside WanGP the properties (seed, resolutions, length, most settings...) of the past generations -- In one click use the newly generated video as a Control Video or Source Video to be continued -- Manage multiple settings for the same model and switch between them using a dropdown box -- WanGP will keep the last generated videos in the Gallery and will remember the last model you used if you restart the app but kept the Web page open -- Custom resolutions : add a file in the WanGP folder with the list of resolutions you want to see in WanGP (look at the instruction readme in this folder) - -Taking care of your life is not enough, you want new stuff to play with ? -- MMAudio directly inside WanGP : add an audio soundtrack that matches the content of your video. By the way it is a low VRAM MMAudio and 6 GB of VRAM should be sufficient. You will need to go in the *Extensions* tab of the WanGP *Configuration* to enable MMAudio -- Forgot to upsample your video during the generation ? want to try another MMAudio variation ? Fear not you can also apply upsampling or add an MMAudio track once the video generation is done. 
Even better you can ask WangGP for multiple variations of MMAudio to pick the one you like best -- MagCache support: a new step skipping approach, supposed to be better than TeaCache. Makes a difference if you usually generate with a high number of steps -- SageAttention2++ support : not just the compatibility but also a slightly reduced VRAM usage -- Video2Video in Wan Text2Video : this is the paradox, a text2video can become a video2video if you start the denoising process later on an existing video -- FusioniX upsampler: this is an illustration of Video2Video in Text2Video. Use the FusioniX text2video model with an output resolution of 1080p and a denoising strength of 0.25 and you will get one of the best upsamplers (in only 2/3 steps, you will need lots of VRAM though). Increase the denoising strength and you will get one of the best Video Restorer -- Choice of Wan Samplers / Schedulers -- More Lora formats support - -**If you had upgraded to v6.5 please upgrade again to 6.5.1 as this will fix a bug that ignored Loras beyond the first one** See full changelog: **[Changelog](docs/CHANGELOG.md)** diff --git a/configs/animate.json b/configs/animate.json new file mode 100644 index 0000000..7e98ca9 --- /dev/null +++ b/configs/animate.json @@ -0,0 +1,15 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.30.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 36, + "model_type": "i2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + "motion_encoder_dim": 512 +} \ No newline at end of file diff --git a/configs/lucy_edit.json b/configs/lucy_edit.json new file mode 100644 index 0000000..4983ced --- /dev/null +++ b/configs/lucy_edit.json @@ -0,0 +1,14 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.33.0", + "dim": 3072, + "eps": 1e-06, + "ffn_dim": 14336, + "freq_dim": 256, + "in_dim": 96, + "model_type": "ti2v2_2", + "num_heads": 24, + "num_layers": 30, + "out_dim": 48, + "text_len": 512 +} diff --git a/defaults/animate.json b/defaults/animate.json new file mode 100644 index 0000000..bf45f4a --- /dev/null +++ b/defaults/animate.json @@ -0,0 +1,17 @@ +{ + "model": { + "name": "Wan2.2 Animate", + "architecture": "animate", + "description": "Wan-Animate takes a video and a character image as input, and generates a video in either 'Animation' or 'Replacement' mode. 
Sliding Window of 81 frames at least are recommeded to obtain the best Style continuity.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_fp16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_bf16_int8.safetensors" + ], + "preload_URLs" : + [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_relighting_lora.safetensors" + ], + "group": "wan2_2" + } +} \ No newline at end of file diff --git a/defaults/flux_dev_kontext.json b/defaults/flux_dev_kontext.json index 8945918..20b6bc4 100644 --- a/defaults/flux_dev_kontext.json +++ b/defaults/flux_dev_kontext.json @@ -7,8 +7,6 @@ "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors" ], - "image_outputs": true, - "reference_image": true, "flux-model": "flux-dev-kontext" }, "prompt": "add a hat", diff --git a/defaults/flux_dev_umo.json b/defaults/flux_dev_umo.json new file mode 100644 index 0000000..57164bb --- /dev/null +++ b/defaults/flux_dev_umo.json @@ -0,0 +1,24 @@ +{ + "model": { + "name": "Flux 1 Dev UMO 12B", + "architecture": "flux", + "description": "FLUX.1 Dev UMO is a model that can Edit Images with a specialization in combining multiple image references (resized internally at 512x512 max) to produce an Image output. Best Image preservation at 768x768 Resolution Output.", + "URLs": "flux", + "flux-model": "flux-dev-umo", + "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-UMO_dit_lora_bf16.safetensors"], + "resolutions": [ ["1024x1024 (1:1)", "1024x1024"], + ["768x1024 (3:4)", "768x1024"], + ["1024x768 (4:3)", "1024x768"], + ["512x1024 (1:2)", "512x1024"], + ["1024x512 (2:1)", "1024x512"], + ["768x768 (1:1)", "768x768"], + ["768x512 (3:2)", "768x512"], + ["512x768 (2:3)", "512x768"]] + }, + "prompt": "the man is wearing a hat", + "embedded_guidance_scale": 4, + "resolution": "768x768", + "batch_size": 1 +} + + \ No newline at end of file diff --git a/defaults/flux_dev_uso.json b/defaults/flux_dev_uso.json index ab5ac54..806dd7e 100644 --- a/defaults/flux_dev_uso.json +++ b/defaults/flux_dev_uso.json @@ -2,15 +2,13 @@ "model": { "name": "Flux 1 Dev USO 12B", "architecture": "flux", - "description": "FLUX.1 Dev USO is a model specialized to Edit Images with a specialization in Style Transfers (up to two).", + "description": "FLUX.1 Dev USO is a model that can Edit Images with a specialization in Style Transfers (up to two).", "modules": [ ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_projector_bf16.safetensors"]], "URLs": "flux", "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_dit_lora_bf16.safetensors"], - "image_outputs": true, - "reference_image": true, "flux-model": "flux-dev-uso" }, - "prompt": "add a hat", + "prompt": "the man is wearing a hat", "embedded_guidance_scale": 4, "resolution": "1024x1024", "batch_size": 1 diff --git a/defaults/flux_srpo.json b/defaults/flux_srpo.json new file mode 100644 index 0000000..59f07c6 --- /dev/null +++ b/defaults/flux_srpo.json @@ -0,0 +1,15 @@ +{ + "model": { + "name": "Flux 1 SRPO Dev 12B", + "architecture": "flux", + "description": "By fine-tuning the FLUX.1.dev model with optimized denoising and online reward adjustment, SRPO improves its human-evaluated 
realism and aesthetic quality by over 3x.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-srpo-dev_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-srpo-dev_quanto_bf16_int8.safetensors" + ], + "flux-model": "flux-dev" + }, + "prompt": "draw a hat", + "resolution": "1024x1024", + "batch_size": 1 +} \ No newline at end of file diff --git a/defaults/flux_srpo_uso.json b/defaults/flux_srpo_uso.json new file mode 100644 index 0000000..ddfe50d --- /dev/null +++ b/defaults/flux_srpo_uso.json @@ -0,0 +1,17 @@ +{ + "model": { + "name": "Flux 1 SRPO USO 12B", + "architecture": "flux", + "description": "FLUX.1 SRPO USO is a model that can Edit Images with a specialization in Style Transfers (up to two). It leverages the improved Image quality brought by the SRPO process", + "modules": [ "flux_dev_uso"], + "URLs": "flux_srpo", + "loras": "flux_dev_uso", + "flux-model": "flux-dev-uso" + }, + "prompt": "the man is wearing a hat", + "embedded_guidance_scale": 4, + "resolution": "1024x1024", + "batch_size": 1 +} + + \ No newline at end of file diff --git a/defaults/lucy_edit.json b/defaults/lucy_edit.json new file mode 100644 index 0000000..a8f67ad --- /dev/null +++ b/defaults/lucy_edit.json @@ -0,0 +1,19 @@ +{ + "model": { + "name": "Wan2.2 Lucy Edit 5B", + "architecture": "lucy_edit", + "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mfp16_int8.safetensors" + ], + "group": "wan2_2" + }, + "prompt": "change the clothes to red", + "video_length": 81, + "guidance_scale": 5, + "flow_shift": 5, + "num_inference_steps": 30, + "resolution": "1280x720" +} \ No newline at end of file diff --git a/defaults/lucy_edit_fastwan.json b/defaults/lucy_edit_fastwan.json new file mode 100644 index 0000000..d5d47c8 --- /dev/null +++ b/defaults/lucy_edit_fastwan.json @@ -0,0 +1,16 @@ +{ + "model": { + "name": "Wan2.2 FastWan Lucy Edit 5B", + "architecture": "lucy_edit", + "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. 
This is the FastWan version for faster generation.", + "URLs": "lucy_edit", + "group": "wan2_2", + "loras": "ti2v_2_2_fastwan" + }, + "prompt": "change the clothes to red", + "video_length": 81, + "guidance_scale": 1, + "flow_shift": 3, + "num_inference_steps": 5, + "resolution": "1280x720" +} \ No newline at end of file diff --git a/defaults/qwen_image_edit_20B.json b/defaults/qwen_image_edit_20B.json index 2b24c72..04fc573 100644 --- a/defaults/qwen_image_edit_20B.json +++ b/defaults/qwen_image_edit_20B.json @@ -7,11 +7,10 @@ "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_20B_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_20B_quanto_bf16_int8.safetensors" ], + "preload_URLs": ["https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_inpainting.safetensors"], "attention": { "<89": "sdpa" - }, - "reference_image": true, - "image_outputs": true + } }, "prompt": "add a hat", "resolution": "1280x720", diff --git a/defaults/qwen_image_edit_plus_20B.json b/defaults/qwen_image_edit_plus_20B.json new file mode 100644 index 0000000..e10deb2 --- /dev/null +++ b/defaults/qwen_image_edit_plus_20B.json @@ -0,0 +1,17 @@ +{ + "model": { + "name": "Qwen Image Edit Plus 20B", + "architecture": "qwen_image_edit_plus_20B", + "description": "Qwen Image Edit Plus is a generative model that can generate very high quality images with long texts in it. Best results will be at 720p. This model is optimized to combine multiple Subjects & Objects.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_plus_20B_quanto_bf16_int8.safetensors" + ], + "preload_URLs": "qwen_image_edit_20B", + "attention": { + "<89": "sdpa" + } + }, + "prompt": "add a hat", + "resolution": "1024x1024", + "batch_size": 1 +} \ No newline at end of file diff --git a/defaults/standin.json b/defaults/standin.json index 1b5e324..09298e9 100644 --- a/defaults/standin.json +++ b/defaults/standin.json @@ -4,7 +4,7 @@ "name": "Wan2.1 Standin 14B", "modules": [ ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Stand-In_wan2.1_T2V_14B_ver1.0_bf16.safetensors"]], "architecture" : "standin", - "description": "The original Wan Text 2 Video model combined with the StandIn module to improve Identity Preservation. You need to provide a Reference Image with white background which is a close up of person face to transfer this person in the Video.", + "description": "The original Wan Text 2 Video model combined with the StandIn module to improve Identity Preservation. 
You need to provide a Reference Image with white background which is a close up of a person face to transfer this person in the Video.", "URLs": "t2v" } } \ No newline at end of file diff --git a/defaults/ti2v_2_2_fastwan.json b/defaults/ti2v_2_2_fastwan.json index 064c2b4..fa69f82 100644 --- a/defaults/ti2v_2_2_fastwan.json +++ b/defaults/ti2v_2_2_fastwan.json @@ -7,6 +7,7 @@ "loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"], "group": "wan2_2" }, + "prompt" : "Put the person into a clown outfit.", "video_length": 121, "guidance_scale": 1, "flow_shift": 3, diff --git a/defaults/vace_fun_14B_2_2.json b/defaults/vace_fun_14B_2_2.json new file mode 100644 index 0000000..8f22d34 --- /dev/null +++ b/defaults/vace_fun_14B_2_2.json @@ -0,0 +1,24 @@ +{ + "model": { + "name": "Wan2.2 Vace Fun 14B", + "architecture": "vace_14B", + "description": "This is the Fun Vace 2.2 version, that is not the official Vace 2.2", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_quanto_mfp16_int8.safetensors" + ], + "URLs2": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_quanto_mfp16_int8.safetensors" + ], + "group": "wan2_2" + }, + "guidance_phases": 2, + "num_inference_steps": 30, + "guidance_scale": 1, + "guidance2_scale": 1, + "flow_shift": 2, + "switch_threshold": 875 +} \ No newline at end of file diff --git a/defaults/vace_fun_14B_cocktail_2_2.json b/defaults/vace_fun_14B_cocktail_2_2.json new file mode 100644 index 0000000..c587abd --- /dev/null +++ b/defaults/vace_fun_14B_cocktail_2_2.json @@ -0,0 +1,28 @@ +{ + "model": { + "name": "Wan2.2 Vace Fun Cocktail 14B", + "architecture": "vace_14B", + "description": "This model has been created on the fly using the Wan text 2.2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. 
This is the Fun Vace 2.2, which is not the official Vace 2.2",
+        "URLs": "vace_fun_14B_2_2",
+        "URLs2": "vace_fun_14B_2_2",
+        "loras": [
+            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
+            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors",
+            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors",
+            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors"
+        ],
+        "loras_multipliers": [
+            1,
+            0.2,
+            0.5,
+            0.5
+        ],
+        "group": "wan2_2"
+    },
+    "guidance_phases": 2,
+    "num_inference_steps": 10,
+    "guidance_scale": 1,
+    "guidance2_scale": 1,
+    "flow_shift": 2,
+    "switch_threshold": 875
+}
\ No newline at end of file
diff --git a/docs/AMD-INSTALLATION.md b/docs/AMD-INSTALLATION.md
new file mode 100644
index 0000000..4f05589
--- /dev/null
+++ b/docs/AMD-INSTALLATION.md
@@ -0,0 +1,146 @@
+# Installation Guide
+
+This guide covers installation for specific RDNA3 and RDNA3.5 AMD CPUs (APUs) and GPUs
+running under Windows.
+
+tl;dr: Radeon RX 7900 GOOD, RX 9700 BAD, RX 6800 BAD. (I know, life isn't fair).
+
+Currently supported (but not necessarily tested):
+
+**gfx110x**:
+
+* Radeon RX 7600
+* Radeon RX 7700 XT
+* Radeon RX 7800 XT
+* Radeon RX 7900 GRE
+* Radeon RX 7900 XT
+* Radeon RX 7900 XTX
+
+**gfx1151**:
+
+* Ryzen 7000 series APUs (Phoenix)
+* Ryzen Z1 (e.g., handheld devices like the ROG Ally)
+
+**gfx1201**:
+
+* Ryzen 8000 series APUs (Strix Point)
+* A [frame.work](https://frame.work/au/en/desktop) desktop/laptop
+
+
+## Requirements
+
+- Python 3.11 (3.12 might work, 3.10 definitely will not!)
+
+## Installation Environment
+
+This installation uses PyTorch 2.7.0 because that's what's currently available in
+terms of pre-compiled wheels.
+
+### Installing Python
+
+Download Python 3.11 from [python.org/downloads/windows](https://www.python.org/downloads/windows/). Hit Ctrl+F and search for "3.11". Don't use this direct link: [https://www.python.org/ftp/python/3.11.9/python-3.11.9-amd64.exe](https://www.python.org/ftp/python/3.11.9/python-3.11.9-amd64.exe) -- that was an IQ test.
+
+After installing, make sure `python --version` works in your terminal and returns 3.11.x
+
+If not, you probably need to fix your PATH. Go to:
+
+* Windows + Pause/Break
+* Advanced System Settings
+* Environment Variables
+* Edit your `Path` under User Variables
+
+Example correct entries:
+
+```cmd
+C:\Users\YOURNAME\AppData\Local\Programs\Python\Launcher\
+C:\Users\YOURNAME\AppData\Local\Programs\Python\Python311\Scripts\
+C:\Users\YOURNAME\AppData\Local\Programs\Python\Python311\
+```
+
+If that doesn't work, scream into a bucket.
+
+### Installing Git
+
+Get Git from [git-scm.com/downloads/win](https://git-scm.com/downloads/win). Default install is fine.
+
+
+## Install (Windows, using `venv`)
+
+### Step 1: Download and Set Up Environment
+
+```cmd
+:: Navigate to your desired install directory
+cd \your-path-to-wan2gp
+
+:: Clone the repository
+git clone https://github.com/deepbeepmeep/Wan2GP.git
+cd Wan2GP
+
+:: Create virtual environment using Python 3.11
+python -m venv wan2gp-env
+
+:: Activate the virtual environment
+wan2gp-env\Scripts\activate
+```
+
+### Step 2: Install PyTorch
+
+The pre-compiled wheels you need are hosted at [scottt's rocm-TheRock releases](https://github.com/scottt/rocm-TheRock/releases). 
Find the heading that says:
+
+**Pytorch wheels for gfx110x, gfx1151, and gfx1201**
+
+Don't click this link: [https://github.com/scottt/rocm-TheRock/releases/tag/v6.5.0rc-pytorch-gfx110x](https://github.com/scottt/rocm-TheRock/releases/tag/v6.5.0rc-pytorch-gfx110x). It's just here to check if you're skimming.
+
+Copy the links of the binaries closest to the ones in the example below (adjust if you're not running Python 3.11), paste them into the pip command and hit enter.
+
+```cmd
+pip install ^
+    https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torch-2.7.0a0+rocm_git3f903c3-cp311-cp311-win_amd64.whl ^
+    https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchaudio-2.7.0a0+52638ef-cp311-cp311-win_amd64.whl ^
+    https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchvision-0.22.0+9eb57cd-cp311-cp311-win_amd64.whl
+```
+
+### Step 3: Install Dependencies
+
+```cmd
+:: Install core dependencies
+pip install -r requirements.txt
+```
+
+## Attention Modes
+
+WanGP supports several attention implementations, only one of which will work for you:
+
+- **SDPA** (default): Available by default with PyTorch. This uses the built-in aotriton acceleration library, so it is actually pretty fast.
+
+## Performance Profiles
+
+Choose a profile based on your hardware:
+
+- **Profile 3 (LowRAM_HighVRAM)**: Loads the entire model in VRAM, requires 24 GB of VRAM for an 8-bit quantized 14B model
+- **Profile 4 (LowRAM_LowVRAM)**: Default, loads model parts as needed, slower but with a lower VRAM requirement
+
+## Running Wan2GP
+
+In the future, you will need to do this:
+
+```cmd
+cd \path-to\wan2gp
+wan2gp-env\Scripts\activate.bat
+python wgp.py
+```
+
+For now, you should just be able to type `python wgp.py` (because you're already in the virtual environment).
+
+## Troubleshooting
+
+- If you use a HIGH VRAM mode, don't be a fool. Make sure you use VAE Tiled Decoding.
+
+### Memory Issues
+
+- Use lower resolution or shorter videos
+- Enable quantization (default)
+- Use Profile 4 for lower VRAM usage
+- Consider using 1.3B models instead of 14B models
+
+For more troubleshooting, see [TROUBLESHOOTING.md](TROUBLESHOOTING.md)
diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md
index 5a89d93..b0eeae3 100644
--- a/docs/CHANGELOG.md
+++ b/docs/CHANGELOG.md
@@ -1,20 +1,154 @@
 # Changelog
 
 ## 🔥 Latest News
-### July 21 2025: WanGP v7.1
+### August 29 2025: WanGP v8.21 - Here Goes Your Weekend
+
+- **InfiniteTalk Video to Video**: this feature can be used for Video Dubbing. Keep in mind that it is a *Sparse Video to Video*, that is, internally only one image is used per Sliding Window. However thanks to the new *Smooth Transition* mode, each new clip is connected to the previous one and all the camera work is done by InfiniteTalk. If you don't get any transition, increase the number of frames of a Sliding Window (81 frames recommended)
+
+- **StandIn**: very light model specialized in Identity Transfer. I have provided two versions of StandIn: a basic one derived from the text 2 video model and another based on Vace. If used with Vace, the last reference frame given to Vace will also be used for StandIn
+
+- **Flux ESO**: a new Flux-derived *Image Editing tool*, but this one is specialized both in *Identity Transfer* and *Style Transfer*. 
Style has to be understood in its wide meaning: give a reference picture of a person and another one of Sushis and you will turn this person into Sushis + +### August 24 2025: WanGP v8.1 - the RAM Liberator + +- **Reserved RAM entirely freed when switching models**, you should get much less out of memory related to RAM. I have also added a button in *Configuration / Performance* that will release most of the RAM used by WanGP if you want to use another application without quitting WanGP +- **InfiniteTalk** support: improved version of Multitalk that supposedly supports very long video generations based on an audio track. Exists in two flavors (*Single Speaker* and *Multi Speakers*) but doesnt seem to be compatible with Vace. One key new feature compared to Multitalk is that you can have different visual shots associated to the same audio: each Reference frame you provide you will be associated to a new Sliding Window. If only Reference frame is provided, it will be used for all windows. When Continuing a video, you can either continue the current shot (no Reference Frame) or add new shots (one or more Reference Frames).\ +If you are not into audio, you can use still this model to generate infinite long image2video, just select "no speaker". Last but not least, Infinitetalk works works with all the Loras accelerators. +- **Flux Chroma 1 HD** support: uncensored flux based model and lighter than Flux (8.9B versus 12B) and can fit entirely in VRAM with only 16 GB of VRAM. Unfortunalely it is not distilled and you will need CFG at minimum 20 steps + +### August 21 2025: WanGP v8.01 - the killer of seven + +- **Qwen Image Edit** : Flux Kontext challenger (prompt driven image edition). Best results (including Identity preservation) will be obtained at 720p. Beyond you may get image outpainting and / or lose identity preservation. Below 720p prompt adherence will be worse. Qwen Image Edit works with Qwen Lora Lightning 4 steps. I have also unlocked all the resolutions for Qwen models. Bonus Zone: support for multiple image compositions but identity preservation won't be as good. +- **On demand Prompt Enhancer** (needs to be enabled in Configuration Tab) that you can use to Enhance a Text Prompt before starting a Generation. You can refine the Enhanced Prompt or change the original Prompt. +- Choice of a **Non censored Prompt Enhancer**. Beware this is one is VRAM hungry and will require 12 GB of VRAM to work +- **Memory Profile customizable per model** : useful to set for instance Profile 3 (preload the model entirely in VRAM) with only Image Generation models, if you have 24 GB of VRAM. In that case Generation will be much faster because with Image generators (contrary to Video generators) as a lot of time is wasted in offloading +- **Expert Guidance Mode**: change the Guidance during the generation up to 2 times. Very useful with Wan 2.2 Ligthning to reduce the slow motion effect. The idea is to insert a CFG phase before the 2 accelerated phases that follow and have no Guidance. I have added the finetune *Wan2.2 Vace Lightning 3 Phases 14B* with a prebuilt configuration. Please note that it is a 8 steps process although the lora lightning is 4 steps. This expert guidance mode is also available with Wan 2.1. + +*WanGP 8.01 update, improved Qwen Image Edit Identity Preservation* +### August 12 2025: WanGP v7.7777 - Lucky Day(s) + +This is your lucky day ! 
Thanks to new configuration options that let you store generated Videos and Images in lossless compressed formats, you will find that they suddenly look twice as good without you doing anything ! + +Just kidding, they will only be marginally better, but at least this opens the way to professional editing. + +Support: +- Video: x264, x264 lossless, x265 +- Images: jpeg, png, webp, webp lossless +Generation Settings are stored in each of the above regardless of the format (that was the hard part). + +Also, you can now choose different output directories for images and videos. + +Unexpected luck: fixed Lightning 8 steps for Qwen and Lightning 4 steps for Wan 2.2; now you just need a 1x multiplier, no weird numbers. +*update 7.777 : oops, got a crash with FastWan ? Luck comes and goes, try a new update, maybe you will have better luck this time* +*update 7.7777 : Sometimes good luck seems to last forever. For instance, what if Qwen Lightning 4 steps could also work with WanGP ?* +- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors (Qwen Lightning 4 steps) +- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors (new improved version of Qwen Lightning 8 steps) + + +### August 10 2025: WanGP v7.76 - Faster than the VAE ... +We have a funny one here today: FastWan 2.2 5B, the fastest Video Generator, only 20s to generate 121 frames at 720p. The snag is that the VAE is twice as slow... +Thanks to Kijai for extracting the Lora that is used to build the corresponding finetune. + +*WanGP 7.76: fixed the mess I made of the i2v models (the loras path was wrong for Wan 2.2 and Clip was broken)* + +### August 9 2025: WanGP v7.74 - Qwen Rebirth part 2 +Added support for the Qwen Lightning lora for 8-step generation (https://huggingface.co/lightx2v/Qwen-Image-Lightning/blob/main/Qwen-Image-Lightning-8steps-V1.0.safetensors). The Lora is not normalized, so use a multiplier around 0.1. + +Mag Cache support for all the Wan 2.2 models. Don't forget to set guidance to 1 and 8 denoising steps, and your gen will be 7x faster ! + +### August 8 2025: WanGP v7.73 - Qwen Rebirth +Ever wondered what impact not using Guidance has on a model that expects it ? Just look at Qwen Image in WanGP 7.71, whose outputs were erratic. Somehow I had convinced myself that Qwen was a distilled model. In fact Qwen was dying for a negative prompt, and in WanGP 7.72 there is at last one for it. + +As Qwen is not so picky after all, I have also added a quantized text encoder, which reduces the RAM requirements of Qwen by 10 GB (the quantized text encoder produced garbage before). + +Unfortunately the Sage bug is still there for older GPU architectures. Added an Sdpa fallback for these architectures. + +*7.73 update: still the Sage / Sage2 bug for GPUs before the RTX 40xx series. I have added a detection mechanism that forces Sdpa attention in that case* + + +### August 6 2025: WanGP v7.71 - Picky, picky + +This release comes with two new models : +- Qwen Image: a commercial-grade Image generator capable of injecting full sentences into the generated Image while still offering incredible visuals +- Wan 2.2 TextImage to Video 5B: the last Wan 2.2 model needed to complete your Wan 2.2 collection (loras for this model can be stored in "\loras\5B" ) + +There is a catch though: they are very picky if you want to get good generations. First, they both need lots of steps (50 ?) to show what they have to offer. 
Then for Qwen Image I had to hardcode the supported resolutions, because if you try anything else you will get garbage. Likewise, Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p. + +*7.71 update: Added VAE Tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, for low VRAM during a whole gen.* + + +### August 4 2025: WanGP v7.6 - Remuxed + +With this new version you won't have any excuse if there is no sound in your video. + +*Continue Video* now works with any video that already has some sound (hint: Multitalk). + +Also, on top of MMaudio and the various sound-driven models, I have added the ability to use your own soundtrack. + +As a result you can apply a different sound source to each new video segment when doing a *Continue Video*. + +For instance: +- first video part: use Multitalk with two people speaking +- second video part: apply your own soundtrack, which will gently follow the Multitalk conversation +- third video part: use a Vace effect, and its corresponding control audio will be concatenated to the rest of the audio + +To multiply the combinations, I have also implemented *Continue Video* with the various image2video models. + +Also: +- End Frame support added for LTX Video models +- Loras can now be targeted specifically at the High noise or Low noise model with Wan 2.2, check the Loras and Finetune guides +- Flux Krea Dev support + +### July 30 2025: WanGP v7.5: Just another release ... Wan 2.2 part 2 +Here is now Wan 2.2 image2video, a very good model if you want to set Start and End frames. Two Wan 2.2 models delivered, only one to go ... + +Please note that although it is an image2video model, it is structurally very close to Wan 2.2 text2video (same layers with only a different initial projection). Given that Wan 2.1 image2video loras don't work too well with it (half of their tensors are not supported), I have decided that this model will look for its loras in the text2video loras folder instead of the image2video folder. + +I have also optimized RAM management with Wan 2.2 so that loras and modules are loaded only once in RAM and Reserved RAM; this saves up to 5 GB of RAM, which can make a difference... + +And this time I really removed Vace Cocktail Light, which produced blurry output. + +### July 29 2025: WanGP v7.4: Just another release ... Wan 2.2 Preview +Wan 2.2 is here. The good news is that WanGP won't require a single byte of extra VRAM to run it and it will be as fast as Wan 2.1. The bad news is that you will need much more RAM if you want to leverage this new model entirely, since it has twice as many parameters. + +So here is a preview version of Wan 2.2, without the 5B model and Wan 2.2 image to video for the moment. + +However, as I felt bad delivering only half of the wares, I gave you instead ..... **Wan 2.2 Vace Experimental Cocktail** ! + +Very good surprise indeed: the loras and Vace partially work with Wan 2.2. We will need to wait for the official Vace 2.2 release, since some Vace features are broken, like identity preservation. + +Bonus zone: Flux multi image conditions have been added, or maybe not, if I broke everything while being distracted by Wan... + +7.4 update: I forgot to update the version number. I also removed Vace Cocktail Light, which didn't work well. + +### July 27 2025: WanGP v7.3 : Interlude +While waiting for Wan 2.2, you will appreciate the model selection hierarchy, which is very useful for collecting even more models. 
You will also appreciate that WanGP remembers which model you used last in each model family. + +### July 26 2025: WanGP v7.2 : Ode to Vace +I am really convinced that Vace can do everything the other models can do, and in a better way, especially as Vace can be combined with Multitalk. + +Here are some new Vace improvements: +- I have provided a default finetune named *Vace Cocktail*, which is a model created on the fly using the Wan text 2 video model and the Loras used to build FusioniX. The weight of the *Detail Enhancer* Lora has been reduced to improve identity preservation. Copy the model definition *defaults/vace_14B_cocktail.json* into the *finetunes/* folder to change the Cocktail composition. Cocktail already contains some Lora accelerators, so there is no need to add an Accvid, Causvid or FusioniX Lora, etc., on top. The whole point of Cocktail is to be able to build your own FusioniX (which originally is a combination of 4 loras) but without the inconveniences of FusioniX. +- Talking about identity preservation, it tends to go away when one generates a single Frame instead of a Video, which is a shame for our Vace photoshop. But there is a solution: I have added an Advanced Quality option that tells WanGP to generate a little more than one frame (it will still keep only the first frame). It will be a little slower, but you will be amazed how Vace Cocktail combined with this option preserves identities (bye bye *Phantom*). +- As I have observed that in practice one switches frequently between *Vace text2video* and *Vace text2image*, I have put them in the same place: they are now just one tab away, no need to reload the model. Likewise, *Wan text2video* and *Wan text2image* have been merged. +- Color fixing when using Sliding Windows. A new postprocessing step, *Color Correction*, applied automatically by default (you can disable it in the *Advanced tab Sliding Window*), will try to match the colors of the new window with those of the previous window. It doesn't fix all the unwanted artifacts of the new window, but at least it makes the transition smoother. Thanks to the Multitalk team for the original code. + +Also, you will enjoy our new real time statistics (CPU / GPU usage, RAM / VRAM used, ...). Many thanks to **Redtash1** for providing the framework for this new feature ! You need to go to the Config tab to enable real time stats. + + +### July 21 2025: WanGP v7.12 - Flux Family Reunion : *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate that, Loras support for the Flux *diffusers* format has also been added. -- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video !) in one go without a sliding window. With the distilled model it will take only 5 minutes with a RTX 4090 (you will need 22 GB of VRAM though). I have added options to select higher humber frames if you want to experiment +- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video !) in one go without a sliding window. With the distilled model it will take only 5 minutes with a RTX 4090 (you will need 22 GB of VRAM though). I have added options to select a higher number of frames if you want to experiment (go to Configuration Tab / General / Increase the Max Number of Frames, change the value and restart the App) - LTX Video ControlNet : a Control Net that lets you, for instance, transfer Human motion or Depth from a control video. 
It is not as powerful as Vace but can produce interesting things especially as now you can generate quickly a 1 min video. Under the scene IC-Loras (see below) for Pose, Depth and Canny are automatically loaded for you, no need to add them. - LTX IC-Lora support: these are special Loras that consumes a conditional image or video Beside the pose, depth and canny IC-Loras transparently loaded there is the *detailer* (https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8) which is basically an upsampler. Add the *detailer* as a Lora and use LTX Raw Format as control net choice to use it. -And Also: -- easier way to select video resolution -- started to optimize Matanyone to reduce VRAM requirements +- Matanyone is now also for the GPU Poor as its VRAM requirements have been divided by 2! (7.12 shadow update) +- Easier way to select video resolution ### July 15 2025: WanGP v7.0 is an AI Powered Photoshop This release turns the Wan models into Image Generators. This goes way more than allowing to generate a video made of single frame : diff --git a/docs/INSTALLATION.md b/docs/INSTALLATION.md index 9f66422..361f266 100644 --- a/docs/INSTALLATION.md +++ b/docs/INSTALLATION.md @@ -27,7 +27,7 @@ conda activate wan2gp ### Step 2: Install PyTorch ```shell -# Install PyTorch 2.7.0 with CUDA 12.4 +# Install PyTorch 2.7.0 with CUDA 12.8 pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128 ``` diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py index 162ec4c..471c339 100644 --- a/models/flux/flux_handler.py +++ b/models/flux/flux_handler.py @@ -13,27 +13,52 @@ class family_handler(): flux_schnell = flux_model == "flux-schnell" flux_chroma = flux_model == "flux-chroma" flux_uso = flux_model == "flux-dev-uso" - model_def_output = { + flux_umo = flux_model == "flux-dev-umo" + flux_kontext = flux_model == "flux-dev-kontext" + + extra_model_def = { "image_outputs" : True, "no_negative_prompt" : not flux_chroma, } if flux_chroma: - model_def_output["guidance_max_phases"] = 1 + extra_model_def["guidance_max_phases"] = 1 elif not flux_schnell: - model_def_output["embedded_guidance"] = True + extra_model_def["embedded_guidance"] = True if flux_uso : - model_def_output["any_image_refs_relative_size"] = True - model_def_output["no_background_removal"] = True - - model_def_output["image_ref_choices"] = { - "choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "I"), - ("Up to two Images are Style Images", "IJ")], - "default": "I", - "letters_filter": "IJ", + extra_model_def["any_image_refs_relative_size"] = True + extra_model_def["no_background_removal"] = True + extra_model_def["image_ref_choices"] = { + "choices":[("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"), + ("Up to two Images are Style Images", "KIJ")], + "default": "KI", + "letters_filter": "KIJ", "label": "Reference Images / Style Images" } + + if flux_kontext: + extra_model_def["inpaint_support"] = True + extra_model_def["image_ref_choices"] = { + "choices": [ + ("None", ""), + ("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"), + ("Conditional Images are People / Objects", "I"), + ], + "letters_filter": "KI", + } + extra_model_def["background_removal_label"]= "Remove Backgrounds only behind People / Objects except main Subject / Landscape" + elif flux_umo: + extra_model_def["image_ref_choices"] = { 
+ "choices": [ + ("Conditional Images are People / Objects", "I"), + ], + "letters_filter": "I", + "visible": False + } - return model_def_output + + extra_model_def["fit_into_canvas_image_refs"] = 0 + + return extra_model_def @staticmethod def query_supported_types(): @@ -82,7 +107,7 @@ class family_handler(): ] @staticmethod - def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None): from .flux_main import model_factory flux_model = model_factory( @@ -107,15 +132,38 @@ class family_handler(): pipe["feature_embedder"] = flux_model.feature_embedder return flux_model, pipe + @staticmethod + def fix_settings(base_model_type, settings_version, model_def, ui_defaults): + flux_model = model_def.get("flux-model", "flux-dev") + flux_uso = flux_model == "flux-dev-uso" + if flux_uso and settings_version < 2.29: + video_prompt_type = ui_defaults.get("video_prompt_type", "") + if "I" in video_prompt_type: + video_prompt_type = video_prompt_type.replace("I", "KI") + ui_defaults["video_prompt_type"] = video_prompt_type + + if settings_version < 2.34: + ui_defaults["denoising_strength"] = 1. + @staticmethod def update_default_settings(base_model_type, model_def, ui_defaults): flux_model = model_def.get("flux-model", "flux-dev") flux_uso = flux_model == "flux-dev-uso" + flux_umo = flux_model == "flux-dev-umo" + flux_kontext = flux_model == "flux-dev-kontext" ui_defaults.update({ "embedded_guidance": 2.5, - }) - if model_def.get("reference_image", False): - ui_defaults.update({ - "video_prompt_type": "I" if flux_uso else "KI", - }) + }) + + if flux_kontext or flux_uso: + ui_defaults.update({ + "video_prompt_type": "KI", + "denoising_strength": 1., + }) + elif flux_umo: + ui_defaults.update({ + "video_prompt_type": "I", + "remove_background_images_ref": 0, + }) + diff --git a/models/flux/flux_main.py b/models/flux/flux_main.py index 55a2b91..6746a23 100644 --- a/models/flux/flux_main.py +++ b/models/flux/flux_main.py @@ -9,6 +9,9 @@ from shared.utils.utils import calculate_new_dimensions from .sampling import denoise, get_schedule, prepare_kontext, prepare_prompt, prepare_multi_ip, unpack from .modules.layers import get_linear_split_map from transformers import SiglipVisionModel, SiglipImageProcessor +import torchvision.transforms.functional as TVF +import math +from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image from .util import ( aspect_ratio_to_height_width, @@ -20,6 +23,35 @@ from .util import ( ) from PIL import Image +def preprocess_ref(raw_image: Image.Image, long_size: int = 512): + # 获取原始图像的宽度和高度 + image_w, image_h = raw_image.size + + # 计算长边和短边 + if image_w >= image_h: + new_w = long_size + new_h = int((long_size / image_w) * image_h) + else: + new_h = long_size + new_w = int((long_size / image_h) * image_w) + + # 按新的宽高进行等比例缩放 + raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS) + target_w = new_w // 16 * 16 + target_h = new_h // 16 * 16 + + # 计算裁剪的起始坐标以实现中心裁剪 + left = (new_w - target_w) // 2 + top = (new_h - target_h) // 2 + right = left + target_w + bottom = top + target_h + + # 进行中心裁剪 + 
raw_image = raw_image.crop((left, top, right, bottom)) + + # 转换为 RGB 模式 + raw_image = raw_image.convert("RGB") + return raw_image def stitch_images(img1, img2): # Resize img2 to match img1's height @@ -64,7 +96,7 @@ class model_factory: # self.name= "flux-schnell" source = model_def.get("source", None) self.model = load_flow_model(self.name, model_filename[0] if source is None else source, torch_device) - + self.model_def = model_def self.vae = load_ae(self.name, device=torch_device) siglip_processor = siglip_model = feature_embedder = None @@ -106,10 +138,12 @@ class model_factory: def generate( self, seed: int | None = None, - input_prompt: str = "replace the logo with the text 'Black Forest Labs'", + input_prompt: str = "replace the logo with the text 'Black Forest Labs'", n_prompt: str = None, sampling_steps: int = 20, input_ref_images = None, + input_frames= None, + input_masks= None, width= 832, height=480, embedded_guidance_scale: float = 2.5, @@ -120,7 +154,8 @@ class model_factory: batch_size = 1, video_prompt_type = "", joint_pass = False, - image_refs_relative_size = 100, + image_refs_relative_size = 100, + denoising_strength = 1., **bbargs ): if self._interrupt: @@ -129,11 +164,17 @@ class model_factory: if n_prompt is None or len(n_prompt) == 0: n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors" device="cuda" flux_dev_uso = self.name in ['flux-dev-uso'] - image_stiching = not self.name in ['flux-dev-uso'] + flux_dev_umo = self.name in ['flux-dev-umo'] + latent_stiching = self.name in ['flux-dev-uso', 'flux-dev-umo'] + + lock_dimensions= False input_ref_images = [] if input_ref_images is None else input_ref_images[:] + if flux_dev_umo: + ref_long_side = 512 if len(input_ref_images) <= 1 else 320 + input_ref_images = [preprocess_ref(img, ref_long_side) for img in input_ref_images] + lock_dimensions = True ref_style_imgs = [] - if "I" in video_prompt_type and len(input_ref_images) > 0: if flux_dev_uso : if "J" in video_prompt_type: @@ -142,33 +183,28 @@ class model_factory: elif len(input_ref_images) > 1 : ref_style_imgs = input_ref_images[-1:] input_ref_images = input_ref_images[:-1] - if image_stiching: + + if latent_stiching: + # latents stiching with resize + if not lock_dimensions : + for i in range(len(input_ref_images)): + w, h = input_ref_images[i].size + image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, 0) + input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS) + else: # image stiching method stiched = input_ref_images[0] - if "K" in video_prompt_type : - w, h = input_ref_images[0].size - height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) - for new_img in input_ref_images[1:]: stiched = stitch_images(stiched, new_img) input_ref_images = [stiched] - else: - first_ref = 0 - if "K" in video_prompt_type: - # image latents tiling method - w, h = input_ref_images[0].size - height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) - input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS) - first_ref = 1 - - for i in range(first_ref,len(input_ref_images)): - w, h = input_ref_images[i].size - image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas) - 
input_ref_images[0] = input_ref_images[0].resize((image_width, image_height), resample=Image.Resampling.LANCZOS) + elif input_frames is not None: + input_ref_images = [convert_tensor_to_image(input_frames) ] else: input_ref_images = None + image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True) + - if flux_dev_uso : + if self.name in ['flux-dev-uso', 'flux-dev-umo'] : inp, height, width = prepare_multi_ip( ae=self.vae, img_cond_list=input_ref_images, @@ -187,6 +223,7 @@ class model_factory: bs=batch_size, seed=seed, device=device, + img_mask=image_mask, ) inp.update(prepare_prompt(self.t5, self.clip, batch_size, input_prompt)) @@ -208,13 +245,19 @@ class model_factory: return unpack(x.float(), height, width) # denoise initial noise - x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass) + x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass, denoising_strength = denoising_strength) if x==None: return None # decode latents to pixel space x = unpack_latent(x) with torch.autocast(device_type=device, dtype=torch.bfloat16): x = self.vae.decode(x) + if image_mask is not None: + from shared.utils.utils import convert_image_to_tensor + img_msk_rebuilt = inp["img_msk_rebuilt"] + img= input_frames.squeeze(1).unsqueeze(0) # convert_image_to_tensor(image_guide) + x = img * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt + x = x.clamp(-1, 1) x = x.transpose(0, 1) return x diff --git a/models/flux/model.py b/models/flux/model.py index c4642d0..c5f7a24 100644 --- a/models/flux/model.py +++ b/models/flux/model.py @@ -190,6 +190,21 @@ class Flux(nn.Module): v = swap_scale_shift(v) k = k.replace("norm_out.linear", "final_layer.adaLN_modulation.1") new_sd[k] = v + # elif not first_key.startswith("diffusion_model.") and not first_key.startswith("transformer."): + # for k,v in sd.items(): + # if "double" in k: + # k = k.replace(".processor.proj_lora1.", ".img_attn.proj.lora_") + # k = k.replace(".processor.proj_lora2.", ".txt_attn.proj.lora_") + # k = k.replace(".processor.qkv_lora1.", ".img_attn.qkv.lora_") + # k = k.replace(".processor.qkv_lora2.", ".txt_attn.qkv.lora_") + # else: + # k = k.replace(".processor.qkv_lora.", ".linear1_qkv.lora_") + # k = k.replace(".processor.proj_lora.", ".linear2.lora_") + + # k = "diffusion_model." 
+ k + # new_sd[k] = v + # from mmgp import safetensors2 + # safetensors2.torch_write_file(new_sd, "fff.safetensors") else: new_sd = sd return new_sd diff --git a/models/flux/sampling.py b/models/flux/sampling.py index 5534e9f..939543c 100644 --- a/models/flux/sampling.py +++ b/models/flux/sampling.py @@ -138,10 +138,12 @@ def prepare_kontext( target_width: int | None = None, target_height: int | None = None, bs: int = 1, - + img_mask = None, ) -> tuple[dict[str, Tensor], int, int]: # load and encode the conditioning image + res_match_output = img_mask is not None + img_cond_seq = None img_cond_seq_ids = None if img_cond_list == None: img_cond_list = [] @@ -150,10 +152,11 @@ def prepare_kontext( for cond_no, img_cond in enumerate(img_cond_list): width, height = img_cond.size aspect_ratio = width / height - - # Kontext is trained on specific resolutions, using one of them is recommended - _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) - + if res_match_output: + width, height = target_width, target_height + else: + # Kontext is trained on specific resolutions, using one of them is recommended + _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) width = 2 * int(width / 16) height = 2 * int(height / 16) @@ -194,6 +197,19 @@ def prepare_kontext( "img_cond_seq": img_cond_seq, "img_cond_seq_ids": img_cond_seq_ids, } + if img_mask is not None: + from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image + # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of) + image_mask_latents = convert_image_to_tensor(img_mask.resize((target_width // 16, target_height // 16), resample=Image.Resampling.LANCZOS)) + image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1] + image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0) + # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png") + image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device) + return_dict.update({ + "img_msk_latents": image_mask_latents, + "img_msk_rebuilt": image_mask_rebuilt, + }) + img = get_noise( bs, target_height, @@ -265,6 +281,9 @@ def denoise( loras_slists=None, unpack_latent = None, joint_pass= False, + img_msk_latents = None, + img_msk_rebuilt = None, + denoising_strength = 1, ): kwargs = {'pipeline': pipeline, 'callback': callback, "img_len" : img.shape[1], "siglip_embedding": siglip_embedding, "siglip_embedding_ids": siglip_embedding_ids} @@ -272,19 +291,38 @@ def denoise( if callback != None: callback(-1, None, True) + original_image_latents = None if img_cond_seq is None else img_cond_seq.clone() + original_timesteps = timesteps + morph, first_step = False, 0 + if img_msk_latents is not None: + randn = torch.randn_like(original_image_latents) + if denoising_strength < 1.: + first_step = int(len(timesteps) * (1. 
- denoising_strength)) + if not morph: + latent_noise_factor = timesteps[first_step] + latents = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor + img = latents.to(img) + latents = None + timesteps = timesteps[first_step:] + + updated_num_steps= len(timesteps) -1 if callback != None: from shared.utils.loras_mutipliers import update_loras_slists - update_loras_slists(model, loras_slists, updated_num_steps) + update_loras_slists(model, loras_slists, len(original_timesteps)) callback(-1, None, True, override_num_inference_steps = updated_num_steps) from mmgp import offload # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): - offload.set_step_no_for_lora(model, i) + offload.set_step_no_for_lora(model, first_step + i) if pipeline._interrupt: return None + if img_msk_latents is not None and denoising_strength <1. and i == first_step and morph: + latent_noise_factor = t_curr/1000 + img = original_image_latents * (1.0 - latent_noise_factor) + img * latent_noise_factor + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) img_input = img img_input_ids = img_ids @@ -334,6 +372,14 @@ def denoise( pred = neg_pred + real_guidance_scale * (pred - neg_pred) img += (t_prev - t_curr) * pred + + if img_msk_latents is not None: + latent_noise_factor = t_prev + # noisy_image = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor + noisy_image = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor + img = noisy_image * (1-img_msk_latents) + img_msk_latents * img + noisy_image = None + if callback is not None: preview = unpack_latent(img).transpose(0,1) callback(i, preview, False) diff --git a/models/flux/util.py b/models/flux/util.py index 0f96103..af75f62 100644 --- a/models/flux/util.py +++ b/models/flux/util.py @@ -640,6 +640,38 @@ configs = { shift_factor=0.1159, ), ), + "flux-dev-umo": ModelSpec( + repo_id="", + repo_flow="", + repo_ae="ckpts/flux_vae.safetensors", + params=FluxParams( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + eso= True, + ), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), } diff --git a/models/hyvideo/hunyuan.py b/models/hyvideo/hunyuan.py index a38a7bd..aa6c3b3 100644 --- a/models/hyvideo/hunyuan.py +++ b/models/hyvideo/hunyuan.py @@ -861,16 +861,11 @@ class HunyuanVideoSampler(Inference): freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_RIFLEx) else: if self.avatar: - w, h = input_ref_images.size - target_height, target_width = calculate_new_dimensions(target_height, target_width, h, w, fit_into_canvas) - if target_width != w or target_height != h: - input_ref_images = input_ref_images.resize((target_width,target_height), resample=Image.Resampling.LANCZOS) - concat_dict = {'mode': 'timecat', 'bias': -1} freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict) else: if input_frames != None: - target_height, target_width = 
input_frames.shape[-3:-1] + target_height, target_width = input_frames.shape[-2:] elif input_video != None: target_height, target_width = input_video.shape[-2:] @@ -899,9 +894,10 @@ class HunyuanVideoSampler(Inference): pixel_value_bg = input_video.unsqueeze(0) pixel_value_mask = torch.zeros_like(input_video).unsqueeze(0) if input_frames != None: - pixel_value_video_bg = input_frames.permute(-1,0,1,2).unsqueeze(0).float() - pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float() - pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.) + pixel_value_video_bg = input_frames.unsqueeze(0) #.permute(-1,0,1,2).unsqueeze(0).float() + # pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.) + # pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float() + pixel_value_video_mask = input_masks.repeat(3,1,1,1).unsqueeze(0) if input_video != None: pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2) pixel_value_mask = torch.cat([ pixel_value_mask, pixel_value_video_mask], dim=2) @@ -913,10 +909,11 @@ class HunyuanVideoSampler(Inference): if pixel_value_bg.shape[2] < frame_num: padding_shape = list(pixel_value_bg.shape[0:2]) + [frame_num-pixel_value_bg.shape[2]] + list(pixel_value_bg.shape[3:]) pixel_value_bg = torch.cat([pixel_value_bg, torch.full(padding_shape, -1, dtype=pixel_value_bg.dtype, device= pixel_value_bg.device ) ], dim=2) - pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2) + # pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2) + pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 1, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2) bg_latents = self.vae.encode(pixel_value_bg).latent_dist.sample() - pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.) + pixel_value_mask = pixel_value_mask.mul_(2).add_(-1.) 
# unmasked pixels is -1 (no 0 as usual) and masked is 1 mask_latents = self.vae.encode(pixel_value_mask).latent_dist.sample() bg_latents = torch.cat([bg_latents, mask_latents], dim=1) bg_latents.mul_(self.vae.config.scaling_factor) diff --git a/models/hyvideo/hunyuan_handler.py b/models/hyvideo/hunyuan_handler.py index d95bd7e..ebe07d2 100644 --- a/models/hyvideo/hunyuan_handler.py +++ b/models/hyvideo/hunyuan_handler.py @@ -51,11 +51,38 @@ class family_handler(): extra_model_def["tea_cache"] = True extra_model_def["mag_cache"] = True - if base_model_type in ["hunyuan_avatar"]: extra_model_def["no_background_removal"] = True + if base_model_type in ["hunyuan_custom_edit"]: + extra_model_def["guide_preprocessing"] = { + "selection": ["MV", "PV"], + } - if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_audio", "hunyuan_avatar"]: + extra_model_def["mask_preprocessing"] = { + "selection": ["A", "NA"], + "default" : "NA" + } + + if base_model_type in ["hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_custom"]: + extra_model_def["image_ref_choices"] = { + "choices": [("Reference Image", "I")], + "letters_filter":"I", + "visible": False, + } + + if base_model_type in ["hunyuan_avatar"]: + extra_model_def["image_ref_choices"] = { + "choices": [("Start Image", "KI")], + "letters_filter":"KI", + "visible": False, + } + extra_model_def["no_background_removal"] = True + + if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]: extra_model_def["one_image_ref_needed"] = True + + if base_model_type in ["hunyuan_i2v"]: + extra_model_def["image_prompt_types_allowed"] = "S" + return extra_model_def @staticmethod @@ -102,7 +129,7 @@ class family_handler(): } @staticmethod - def load_model(model_filename, model_type = None, base_model_type = None, model_def = None, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + def load_model(model_filename, model_type = None, base_model_type = None, model_def = None, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None): from .hunyuan import HunyuanVideoSampler from mmgp import offload @@ -137,6 +164,24 @@ class family_handler(): return hunyuan_model, pipe + @staticmethod + def fix_settings(base_model_type, settings_version, model_def, ui_defaults): + if settings_version<2.33: + if base_model_type in ["hunyuan_custom_edit"]: + video_prompt_type= ui_defaults["video_prompt_type"] + if "P" in video_prompt_type and "M" in video_prompt_type: + video_prompt_type = video_prompt_type.replace("M","") + ui_defaults["video_prompt_type"] = video_prompt_type + + if settings_version < 2.36: + if base_model_type in ["hunyuan_avatar", "hunyuan_custom_audio"]: + audio_prompt_type= ui_defaults["audio_prompt_type"] + if "A" not in audio_prompt_type: + audio_prompt_type += "A" + ui_defaults["audio_prompt_type"] = audio_prompt_type + + + @staticmethod def update_default_settings(base_model_type, model_def, ui_defaults): ui_defaults["embedded_guidance_scale"]= 6.0 @@ -158,6 +203,7 @@ class family_handler(): "guidance_scale": 7.5, "flow_shift": 13, "video_prompt_type": "I", + "audio_prompt_type": "A", }) elif base_model_type in ["hunyuan_custom_edit"]: ui_defaults.update({ @@ -174,4 +220,5 @@ class family_handler(): "skip_steps_start_step_perc": 
25, "video_length": 129, "video_prompt_type": "KI", + "audio_prompt_type": "A", }) diff --git a/models/hyvideo/modules/utils.py b/models/hyvideo/modules/utils.py index 02a733e..c263997 100644 --- a/models/hyvideo/modules/utils.py +++ b/models/hyvideo/modules/utils.py @@ -14,7 +14,7 @@ from torch.nn.attention.flex_attention import ( ) -@lru_cache +# @lru_cache def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", _compile=False): block_mask = create_block_mask(score_mod, B, H, M, N, device=device, _compile=_compile) return block_mask diff --git a/models/ltx_video/ltxv.py b/models/ltx_video/ltxv.py index e71ac4f..080860c 100644 --- a/models/ltx_video/ltxv.py +++ b/models/ltx_video/ltxv.py @@ -300,9 +300,6 @@ class LTXV: prefix_size, height, width = input_video.shape[-3:] else: if image_start != None: - frame_width, frame_height = image_start.size - if fit_into_canvas != None: - height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas, 32) conditioning_media_paths.append(image_start.unsqueeze(1)) conditioning_start_frames.append(0) conditioning_control_frames.append(False) @@ -479,14 +476,14 @@ class LTXV: images = images.sub_(0.5).mul_(2).squeeze(0) return images - def get_loras_transformer(self, get_model_recursive_prop, video_prompt_type, **kwargs): + def get_loras_transformer(self, get_model_recursive_prop, model_type, video_prompt_type, **kwargs): map = { "P" : "pose", "D" : "depth", "E" : "canny", } loras = [] - preloadURLs = get_model_recursive_prop(self.model_type, "preload_URLs") + preloadURLs = get_model_recursive_prop(model_type, "preload_URLs") lora_file_name = "" for letter, signature in map.items(): if letter in video_prompt_type: diff --git a/models/ltx_video/ltxv_handler.py b/models/ltx_video/ltxv_handler.py index d35bcd4..8c322e1 100644 --- a/models/ltx_video/ltxv_handler.py +++ b/models/ltx_video/ltxv_handler.py @@ -24,7 +24,19 @@ class family_handler(): extra_model_def["frames_minimum"] = 17 extra_model_def["frames_steps"] = 8 extra_model_def["sliding_window"] = True + extra_model_def["image_prompt_types_allowed"] = "TSEV" + extra_model_def["guide_preprocessing"] = { + "selection": ["", "PV", "DV", "EV", "V"], + "labels" : { "V": "Use LTXV raw format"} + } + + extra_model_def["mask_preprocessing"] = { + "selection": ["", "A", "NA", "XA", "XNA"], + } + + extra_model_def["extra_control_frames"] = 1 + extra_model_def["dont_cat_preguide"]= True return extra_model_def @staticmethod @@ -64,7 +76,7 @@ class family_handler(): @staticmethod - def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None): from .ltxv import LTXV ltxv_model = LTXV( diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py index 07bdbd4..02fd473 100644 --- a/models/qwen/pipeline_qwenimage.py +++ b/models/qwen/pipeline_qwenimage.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- from mmgp import offload import inspect from typing import Any, Callable, Dict, List, Optional, Union @@ -28,7 +27,7 @@ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Aut from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage from diffusers import FlowMatchEulerDiscreteScheduler from PIL import Image -from shared.utils.utils import calculate_new_dimensions +from shared.utils.utils import calculate_new_dimensions, convert_image_to_tensor, convert_tensor_to_image XLA_AVAILABLE = False @@ -201,7 +200,8 @@ class QwenImagePipeline(): #DiffusionPipeline self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) self.tokenizer_max_length = 1024 if processor is not None: - self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + # self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" self.prompt_template_encode_start_idx = 64 else: self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" @@ -233,6 +233,21 @@ class QwenImagePipeline(): #DiffusionPipeline txt = [template.format(e) for e in prompt] if self.processor is not None and image is not None: + img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" + if isinstance(image, list): + base_img_prompt = "" + for i, img in enumerate(image): + base_img_prompt += img_prompt_template.format(i + 1) + elif image is not None: + base_img_prompt = img_prompt_template.format(1) + else: + base_img_prompt = "" + + template = self.prompt_template_encode + + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(base_img_prompt + e) for e in prompt] + model_inputs = self.processor( text=txt, images=image, @@ -387,7 +402,8 @@ class QwenImagePipeline(): #DiffusionPipeline return latent_image_ids.to(device=device, dtype=dtype) @staticmethod - def _pack_latents(latents, batch_size, num_channels_latents, height, width): + def _pack_latents(latents): + batch_size, num_channels_latents, _, height, width = latents.shape latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) latents = latents.permute(0, 2, 4, 1, 3, 5) latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) @@ -464,7 +480,7 @@ class QwenImagePipeline(): #DiffusionPipeline def prepare_latents( self, - image, + images, batch_size, num_channels_latents, height, @@ -479,30 +495,33 @@ class QwenImagePipeline(): #DiffusionPipeline height = 2 * (int(height) // (self.vae_scale_factor * 2)) width = 2 * (int(width) // (self.vae_scale_factor * 2)) - shape = (batch_size, 1, num_channels_latents, height, width) + shape = (batch_size, num_channels_latents, 1, height, width) image_latents = None - if image is not None: - image = image.to(device=device, dtype=dtype) - if image.shape[1] != self.latent_channels: - image_latents = self._encode_vae_image(image=image, generator=generator) - else: - image_latents = image - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
- ) - else: - image_latents = torch.cat([image_latents], dim=0) + if images is not None and len(images ) > 0: + if not isinstance(images, list): + images = [images] + all_image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) - image_latent_height, image_latent_width = image_latents.shape[3:] - image_latents = self._pack_latents( - image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width - ) + image_latents = self._pack_latents(image_latents) + all_image_latents.append(image_latents) + image_latents = torch.cat(all_image_latents, dim=1) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -511,7 +530,7 @@ class QwenImagePipeline(): #DiffusionPipeline ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents) else: latents = latents.to(device=device, dtype=dtype) @@ -563,10 +582,15 @@ class QwenImagePipeline(): #DiffusionPipeline callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, image = None, + image_mask = None, + denoising_strength = 0, callback=None, pipeline=None, loras_slists=None, joint_pass= True, + lora_inpaint = False, + outpainting_dims = None, + qwen_edit_plus = False, ): r""" Function invoked when calling the pipeline for generation. 
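(Not part of the patch, for the reader only: a minimal sketch of how the masked-denoise parameters added to the call signature above might be driven. Only the keyword names `prompt`, `image`, `image_mask`, `denoising_strength` and `num_inference_steps` come from the diff; the `pipe` object and the file names are assumptions.)

```python
# Illustrative sketch only, assuming `pipe` is an already-loaded QwenImagePipeline.
# The diff above suggests `image` accepts PIL images (or a list of them) and that
# `image_mask` is a PIL mask resized internally to the latent grid.
from PIL import Image

source = Image.open("source.png").convert("RGB")   # hypothetical input image
mask = Image.open("mask.png").convert("RGB")       # hypothetical mask (area to regenerate)

result = pipe(
    prompt="replace the jacket with a red raincoat",
    image=[source],            # the mask is applied to the first image (ref_no == 0)
    image_mask=mask,           # restricts the edit to the masked region
    denoising_strength=0.6,    # < 1.0 skips the earliest steps and keeps part of the original latents
    num_inference_steps=8,
)
```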
@@ -682,33 +706,54 @@ class QwenImagePipeline(): #DiffusionPipeline batch_size = prompt_embeds.shape[0] device = "cuda" - prompt_image = None - if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): - image = image[0] if isinstance(image, list) else image - image_height, image_width = self.image_processor.get_default_height_width(image) - aspect_ratio = image_width / image_height - if False : - _, image_width, image_height = min( - (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS - ) - image_width = image_width // multiple_of * multiple_of - image_height = image_height // multiple_of * multiple_of - ref_height, ref_width = 1568, 672 - if height * width < ref_height * ref_width: ref_height , ref_width = height , width - if image_height * image_width > ref_height * ref_width: - image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of) + condition_images = [] + vae_image_sizes = [] + vae_images = [] + image_mask_latents = None + ref_size = 1024 + ref_text_encoder_size = 384 if qwen_edit_plus else 1024 + if image is not None: + if not isinstance(image, list): image = [image] + if height * width < ref_size * ref_size: ref_size = round(math.sqrt(height * width)) + for ref_no, img in enumerate(image): + image_width, image_height = img.size + any_mask = ref_no == 0 and image_mask is not None + if (image_height * image_width > ref_size * ref_size) and not any_mask: + vae_height, vae_width =calculate_new_dimensions(ref_size, ref_size, image_height, image_width, False, block_size=multiple_of) + else: + vae_height, vae_width = image_height, image_width + vae_width = vae_width // multiple_of * multiple_of + vae_height = vae_height // multiple_of * multiple_of + vae_image_sizes.append((vae_width, vae_height)) + condition_height, condition_width =calculate_new_dimensions(ref_text_encoder_size, ref_text_encoder_size, image_height, image_width, False, block_size=multiple_of) + condition_images.append(img.resize((condition_width, condition_height), resample=Image.Resampling.LANCZOS) ) + if img.size != (vae_width, vae_height): + img = img.resize((vae_width, vae_height), resample=Image.Resampling.LANCZOS) + if any_mask : + if lora_inpaint: + image_mask_rebuilt = torch.where(convert_image_to_tensor(image_mask)>-0.5, 1., 0. )[0:1] + img = convert_image_to_tensor(img) + green = torch.tensor([-1.0, 1.0, -1.0]).to(img) + green_image = green[:, None, None] .expand_as(img) + img = torch.where(image_mask_rebuilt > 0, green_image, img) + img = convert_tensor_to_image(img) + else: + image_mask_latents = convert_image_to_tensor(image_mask.resize((vae_width // 8, vae_height // 8), resample=Image.Resampling.LANCZOS)) + image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. 
)[0:1] + image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0) + # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png") + image_mask_latents = image_mask_latents.to(device).unsqueeze(0).unsqueeze(0).repeat(1,16,1,1,1) + image_mask_latents = self._pack_latents(image_mask_latents) + # img.save("nnn.png") + vae_images.append( convert_image_to_tensor(img).unsqueeze(0).unsqueeze(2) ) - image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS) - prompt_image = image - image = self.image_processor.preprocess(image, image_height, image_width) - image = image.unsqueeze(2) has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt prompt_embeds, prompt_embeds_mask = self.encode_prompt( - image=prompt_image, + image=condition_images, prompt=prompt, prompt_embeds=prompt_embeds, prompt_embeds_mask=prompt_embeds_mask, @@ -718,7 +763,7 @@ class QwenImagePipeline(): #DiffusionPipeline ) if do_true_cfg: negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( - image=prompt_image, + image=condition_images, prompt=negative_prompt, prompt_embeds=negative_prompt_embeds, prompt_embeds_mask=negative_prompt_embeds_mask, @@ -734,7 +779,7 @@ class QwenImagePipeline(): #DiffusionPipeline # 4. Prepare latent variables num_channels_latents = self.transformer.in_channels // 4 latents, image_latents = self.prepare_latents( - image, + vae_images, batch_size * num_images_per_prompt, num_channels_latents, height, @@ -744,11 +789,18 @@ class QwenImagePipeline(): #DiffusionPipeline generator, latents, ) + original_image_latents = None if image_latents is None else image_latents.clone() + if image is not None: img_shapes = [ [ (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), - (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2), + # (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2), + *[ + (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2) + for vae_width, vae_height in vae_image_sizes + ], + ] ] * batch_size else: @@ -773,7 +825,7 @@ class QwenImagePipeline(): #DiffusionPipeline ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - + original_timesteps = timesteps # handle guidance if self.transformer.guidance_embeds: guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) @@ -788,56 +840,80 @@ class QwenImagePipeline(): #DiffusionPipeline negative_txt_seq_lens = ( negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None ) - + morph, first_step = False, 0 + lanpaint_proc = None + if image_mask_latents is not None: + randn = torch.randn_like(original_image_latents) + if denoising_strength < 1.: + first_step = int(len(timesteps) * (1. 
- denoising_strength)) + if not morph: + latent_noise_factor = timesteps[first_step]/1000 + # latents = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor + latents = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor + timesteps = timesteps[first_step:] + self.scheduler.timesteps = timesteps + self.scheduler.sigmas= self.scheduler.sigmas[first_step:] + # from shared.inpainting.lanpaint import LanPaint + # lanpaint_proc = LanPaint() # 6. Denoising loop self.scheduler.set_begin_index(0) updated_num_steps= len(timesteps) if callback != None: from shared.utils.loras_mutipliers import update_loras_slists - update_loras_slists(self.transformer, loras_slists, updated_num_steps) + update_loras_slists(self.transformer, loras_slists, len(original_timesteps)) callback(-1, None, True, override_num_inference_steps = updated_num_steps) + for i, t in enumerate(timesteps): + offload.set_step_no_for_lora(self.transformer, first_step + i) if self.interrupt: continue - self._current_timestep = t # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - latent_model_input = latents - if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1) + if image_mask_latents is not None and denoising_strength <1. and i == first_step and morph: + latent_noise_factor = t/1000 + latents = original_image_latents * (1.0 - latent_noise_factor) + latents * latent_noise_factor - if do_true_cfg and joint_pass: - noise_pred, neg_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask], - encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds], - img_shapes=img_shapes, - txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens], - attention_kwargs=self.attention_kwargs, - **kwargs - ) - if noise_pred == None: return None - noise_pred = noise_pred[:, : latents.size(1)] - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - else: - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask_list=[prompt_embeds_mask], - encoder_hidden_states_list=[prompt_embeds], - img_shapes=img_shapes, - txt_seq_lens_list=[txt_seq_lens], - attention_kwargs=self.attention_kwargs, - **kwargs - )[0] - if noise_pred == None: return None - noise_pred = noise_pred[:, : latents.size(1)] + + latents_dtype = latents.dtype + + # latent_model_input = latents + def denoise(latent_model_input, true_cfg_scale): + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + do_true_cfg = true_cfg_scale > 1 + if do_true_cfg and joint_pass: + noise_pred, neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, #!!!! 
+ encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask], + encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds], + img_shapes=img_shapes, + txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens], + attention_kwargs=self.attention_kwargs, + **kwargs + ) + if noise_pred == None: return None, None + noise_pred = noise_pred[:, : latents.size(1)] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + else: + neg_noise_pred = None + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask_list=[prompt_embeds_mask], + encoder_hidden_states_list=[prompt_embeds], + img_shapes=img_shapes, + txt_seq_lens_list=[txt_seq_lens], + attention_kwargs=self.attention_kwargs, + **kwargs + )[0] + if noise_pred == None: return None, None + noise_pred = noise_pred[:, : latents.size(1)] if do_true_cfg: neg_noise_pred = self.transformer( @@ -851,20 +927,43 @@ class QwenImagePipeline(): #DiffusionPipeline attention_kwargs=self.attention_kwargs, **kwargs )[0] - if neg_noise_pred == None: return None + if neg_noise_pred == None: return None, None neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + return noise_pred, neg_noise_pred + def cfg_predictions( noise_pred, neg_noise_pred, guidance, t): + if do_true_cfg: + comb_pred = neg_noise_pred + guidance * (noise_pred - neg_noise_pred) + if comb_pred == None: return None - if do_true_cfg: - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - if comb_pred == None: return None + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - neg_noise_pred = None + return noise_pred + + + if lanpaint_proc is not None and i<=3: + latents = lanpaint_proc(denoise, cfg_predictions, true_cfg_scale, 1., latents, original_image_latents, randn, t/1000, image_mask_latents, height=height , width= width, vae_scale_factor= 8) + if latents is None: return None + + noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale) + if noise_pred == None: return None + noise_pred = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t) + neg_noise_pred = None # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + noise_pred = None + + if image_mask_latents is not None: + if lanpaint_proc is not None: + latents = original_image_latents * (1-image_mask_latents) + image_mask_latents * latents + else: + next_t = timesteps[i+1] if i 0: + dim -= 1 + return TensorList([u.squeeze(dim) for u in self.tensors]) + + def type(self, *args, **kwargs): + return TensorList([u.type(*args, **kwargs) for u in self.tensors]) + + def type_as(self, other): + assert isinstance(other, (torch.Tensor, TensorList)) + if isinstance(other, torch.Tensor): + return TensorList([u.type_as(other) for u in self.tensors]) + else: + return TensorList([u.type(other.dtype) for u in self.tensors]) + + @property + def dtype(self): + return self.tensors[0].dtype + + @property + def device(self): + return self.tensors[0].device + + @property + def ndim(self): + return 1 + self.tensors[0].ndim + + def __getitem__(self, index): + return self.tensors[index] + + def __len__(self): + return 
len(self.tensors) + + def __add__(self, other): + return self._apply(other, lambda u, v: u + v) + + def __radd__(self, other): + return self._apply(other, lambda u, v: v + u) + + def __sub__(self, other): + return self._apply(other, lambda u, v: u - v) + + def __rsub__(self, other): + return self._apply(other, lambda u, v: v - u) + + def __mul__(self, other): + return self._apply(other, lambda u, v: u * v) + + def __rmul__(self, other): + return self._apply(other, lambda u, v: v * u) + + def __floordiv__(self, other): + return self._apply(other, lambda u, v: u // v) + + def __truediv__(self, other): + return self._apply(other, lambda u, v: u / v) + + def __rfloordiv__(self, other): + return self._apply(other, lambda u, v: v // u) + + def __rtruediv__(self, other): + return self._apply(other, lambda u, v: v / u) + + def __pow__(self, other): + return self._apply(other, lambda u, v: u ** v) + + def __rpow__(self, other): + return self._apply(other, lambda u, v: v ** u) + + def __neg__(self): + return TensorList([-u for u in self.tensors]) + + def __iter__(self): + for tensor in self.tensors: + yield tensor + + def __repr__(self): + return 'TensorList: \n' + repr(self.tensors) + + def _apply(self, other, op): + if isinstance(other, (list, tuple, TensorList)) or ( + isinstance(other, torch.Tensor) and ( + other.numel() > 1 or other.ndim > 1 + ) + ): + assert len(other) == len(self.tensors) + return TensorList([op(u, v) for u, v in zip(self.tensors, other)]) + elif isinstance(other, numbers.Number) or ( + isinstance(other, torch.Tensor) and ( + other.numel() == 1 and other.ndim <= 1 + ) + ): + return TensorList([op(u, other) for u in self.tensors]) + else: + raise TypeError( + f'unsupported operand for *: "TensorList" and "{type(other)}"' + ) \ No newline at end of file diff --git a/models/wan/animate/face_blocks.py b/models/wan/animate/face_blocks.py new file mode 100644 index 0000000..8ddb829 --- /dev/null +++ b/models/wan/animate/face_blocks.py @@ -0,0 +1,382 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from torch import nn +import torch +from typing import Tuple, Optional +from einops import rearrange +import torch.nn.functional as F +import math +from shared.attention import pay_attention + +MEMORY_LAYOUT = { + "flash": ( + lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), + lambda x: x, + ), + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} + + +def attention( + q, + k, + v, + mode="torch", + drop_rate=0, + attn_mask=None, + causal=False, + max_seqlen_q=None, + batch_size=1, +): + """ + Perform QKV self attention. + + Args: + q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. + k (torch.Tensor): Key tensor with shape [b, s1, a, d] + v (torch.Tensor): Value tensor with shape [b, s1, a, d] + mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. + drop_rate (float): Dropout rate in attention map. (default: 0) + attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). + (default: None) + causal (bool): Whether to use causal attention. (default: False) + cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into q. + cu_seqlens_kv (torch.Tensor): dtype torch.int32. 
The cumulative sequence lengths of the sequences in the batch, + used to index into kv. + max_seqlen_q (int): The maximum sequence length in the batch of q. + max_seqlen_kv (int): The maximum sequence length in the batch of k and v. + + Returns: + torch.Tensor: Output tensor after self attention with shape [b, s, ad] + """ + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + + if mode == "torch": + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + + elif mode == "flash": + x = flash_attn_func( + q, + k, + v, + ) + x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d] + elif mode == "vanilla": + scale_factor = 1 / math.sqrt(q.size(-1)) + + b, a, s, _ = q.shape + s1 = k.size(2) + attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device) + if causal: + # Only applied to self attention + assert attn_mask is None, "Causal mask and attn_mask cannot be used together" + temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + attn = (q @ k.transpose(-2, -1)) * scale_factor + attn += attn_bias + attn = attn.softmax(dim=-1) + attn = torch.dropout(attn, p=drop_rate, train=True) + x = attn @ v + else: + raise NotImplementedError(f"Unsupported attention mode: {mode}") + + x = post_attn_layout(x) + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +class CausalConv1d(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + + +class FaceEncoder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1) + self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(1024, 1024, 3, stride=2) + self.conv3 = CausalConv1d(1024, 1024, 3, stride=2) + + self.out_proj = nn.Linear(1024, hidden_dim) + self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + + x = rearrange(x, "b t c -> b c t") + b, c, t = x.shape + + x = self.conv1_local(x) + x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads) + + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv2(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") 
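# (Editorial sketch, not part of the patch.) CausalConv1d above makes an ordinary Conv1d causal
# by left-padding the time axis with kernel_size - 1 replicated samples, so the output at time t
# only sees inputs at times <= t. Minimal equivalent:
import torch
import torch.nn.functional as F
from torch import nn

class TinyCausalConv1d(nn.Module):
    def __init__(self, c_in, c_out, kernel_size=3):
        super().__init__()
        self.pad = (kernel_size - 1, 0)            # pad only on the left (the past)
        self.conv = nn.Conv1d(c_in, c_out, kernel_size)

    def forward(self, x):                          # x: [batch, channels, time]
        return self.conv(F.pad(x, self.pad, mode="replicate"))

# TinyCausalConv1d(8, 8)(torch.randn(1, 8, 16)).shape == (1, 8, 16)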
+ x = self.conv3(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm3(x) + x = self.act(x) + x = self.out_proj(x) + x = rearrange(x, "(b n) t c -> b t n c", b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + return x_local + + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. 
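# (Editorial sketch, not part of the patch.) RMSNorm above normalizes by the root-mean-square of
# the last dimension (no mean subtraction): y = x / sqrt(mean(x^2) + eps), optionally scaled by a
# learned weight, with the statistics computed in float32. Functional equivalent:
import torch

def rms_norm(x, weight=None, eps=1e-6):
    x32 = x.float()
    y = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    y = y.type_as(x)
    return y if weight is None else y * weight

# after normalization the mean of squares along the last dim is ~1:
# rms_norm(torch.randn(2, 4, 64)).pow(2).mean(-1)  # ~= ones(2, 4)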
+ """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +class FaceAdapter(nn.Module): + def __init__( + self, + hidden_dim: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + num_adapter_layers: int = 1, + dtype=None, + device=None, + ): + + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.hidden_size = hidden_dim + self.heads_num = heads_num + self.fuser_blocks = nn.ModuleList( + [ + FaceBlock( + self.hidden_size, + self.heads_num, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + **factory_kwargs, + ) + for _ in range(num_adapter_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + motion_embed: torch.Tensor, + idx: int, + freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None, + freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.Tensor: + + return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k) + + + +class FaceBlock(nn.Module): + def __init__( + self, + hidden_size: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + self.scale = qk_scale or head_dim**-0.5 + + self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs) + self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + qk_norm_layer = get_norm_layer(qk_norm_type) + self.q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + def forward( + self, + x: torch.Tensor, + motion_vec: torch.Tensor, + motion_mask: Optional[torch.Tensor] = None, + use_context_parallel=False, + ) -> torch.Tensor: + + B, T, N, C = motion_vec.shape + T_comp = T + + x_motion = self.pre_norm_motion(motion_vec) + x_feat = self.pre_norm_feat(x) + + kv = self.linear1_kv(x_motion) + q = self.linear1_q(x_feat) + + k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num) + q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num) + + # Apply QK-Norm if needed. + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + k = rearrange(k, "B L N H D -> (B L) N H D") + v = rearrange(v, "B L N H D -> (B L) N H D") + + if use_context_parallel: + q = gather_forward(q, dim=1) + + q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp) + # Compute attention. 
+ # Size([batches, tokens, heads, head_features]) + qkv_list = [q, k, v] + del q,k,v + attn = pay_attention(qkv_list) + # attn = attention( + # q, + # k, + # v, + # max_seqlen_q=q.shape[1], + # batch_size=q.shape[0], + # ) + + attn = attn.reshape(*attn.shape[:2], -1) + attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp) + # if use_context_parallel: + # attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()] + + output = self.linear2(attn) + + if motion_mask is not None: + output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1) + + return output \ No newline at end of file diff --git a/models/wan/animate/model_animate.py b/models/wan/animate/model_animate.py new file mode 100644 index 0000000..d07f762 --- /dev/null +++ b/models/wan/animate/model_animate.py @@ -0,0 +1,31 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math +import types +from copy import deepcopy +from einops import rearrange +from typing import List +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn + +def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values): + pose_latents = self.pose_patch_embedding(pose_latents) + x[:, :, 1:] += pose_latents + + b,c,T,h,w = face_pixel_values.shape + face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w") + encode_bs = 8 + face_pixel_values_tmp = [] + for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)): + face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs])) + + motion_vec = torch.cat(face_pixel_values_tmp) + + motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T) + motion_vec = self.face_encoder(motion_vec) + + B, L, H, C = motion_vec.shape + pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec) + motion_vec = torch.cat([pad_face, motion_vec], dim=1) + return x, motion_vec diff --git a/models/wan/animate/motion_encoder.py b/models/wan/animate/motion_encoder.py new file mode 100644 index 0000000..02b0040 --- /dev/null +++ b/models/wan/animate/motion_encoder.py @@ -0,0 +1,308 @@ +# Modified from ``https://github.com/wyhsirius/LIA`` +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
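# (Editorial sketch, not part of the patch.) after_patch_embedding above pushes the face frames
# through the motion encoder in chunks of 8 to bound peak memory, then concatenates the per-frame
# motion vectors. The generic pattern, with `encode_fn` standing in for motion_encoder.get_motion:
import math
import torch

def encode_in_chunks(frames, encode_fn, chunk=8):
    # frames: [N, C, H, W]; returns the concatenation of encode_fn over slices of <= chunk frames
    outs = [encode_fn(frames[i * chunk:(i + 1) * chunk])
            for i in range(math.ceil(frames.shape[0] / chunk))]
    return torch.cat(outs)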
+import torch +import torch.nn as nn +from torch.nn import functional as F +import math + +def custom_qr(input_tensor): + original_dtype = input_tensor.dtype + if original_dtype in [torch.bfloat16, torch.float16]: + q, r = torch.linalg.qr(input_tensor.to(torch.float32)) + return q.to(original_dtype), r.to(original_dtype) + return torch.linalg.qr(input_tensor) + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return F.leaky_relu(input + bias, negative_slope) * scale + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ] + + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, ) + return out[:, :, ::down_y, ::down_x] + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + if k.ndim == 1: + k = k[None, :] * k[:, None] + k /= k.sum() + return k + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + return upfirdn2d(input, self.kernel, pad=self.pad) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + return F.leaky_relu(input, negative_slope=self.negative_slope) + + +class EqualConv2d(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + else: + self.bias = None + + def forward(self, input): + + return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding) + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None): + 
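# (Editorial sketch, not part of the patch.) EqualConv2d above (and EqualLinear, which starts
# right below) use the StyleGAN-style "equalized learning rate": weights are stored as unit
# Gaussians and rescaled at runtime by 1/sqrt(fan_in) (times lr_mul), so layers of different
# width see comparable effective step sizes. Minimal linear version:
import math
import torch
import torch.nn.functional as F
from torch import nn

class EqualizedLinear(nn.Module):
    def __init__(self, in_dim, out_dim, lr_mul=1.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim) / lr_mul)
        self.bias = nn.Parameter(torch.zeros(out_dim))
        self.scale = (1.0 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, x):
        return F.linear(x, self.weight * self.scale, self.bias * self.lr_mul)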
super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + else: + out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})') + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, + bias=bias and not activate)) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class EncoderApp(nn.Module): + def __init__(self, size, w_dim=512): + super(EncoderApp, self).__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256, + 128: 128, + 256: 64, + 512: 32, + 1024: 16 + } + + self.w_dim = w_dim + log_size = int(math.log(size, 2)) + + self.convs = nn.ModuleList() + self.convs.append(ConvLayer(3, channels[size], 1)) + + in_channel = channels[size] + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + self.convs.append(ResBlock(in_channel, out_channel)) + in_channel = out_channel + + self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False)) + + def forward(self, x): + + res = [] + h = x + for conv in self.convs: + h = conv(h) + res.append(h) + + return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:] + + +class Encoder(nn.Module): + def __init__(self, size, dim=512, dim_motion=20): + super(Encoder, self).__init__() + + # appearance netmork + self.net_app = EncoderApp(size, dim) + + # motion network + fc = [EqualLinear(dim, dim)] + for i in range(3): + fc.append(EqualLinear(dim, dim)) + + fc.append(EqualLinear(dim, dim_motion)) + self.fc = nn.Sequential(*fc) + + def enc_app(self, x): + h_source = self.net_app(x) + return h_source + + def enc_motion(self, x): + h, _ = self.net_app(x) + h_motion = self.fc(h) + return h_motion + + +class Direction(nn.Module): + def __init__(self, motion_dim): + super(Direction, self).__init__() + self.weight = nn.Parameter(torch.randn(512, motion_dim)) + + def forward(self, input): + + weight = self.weight + 1e-8 + Q, R = custom_qr(weight) + if input is 
None: + return Q + else: + input_diag = torch.diag_embed(input) # alpha, diagonal matrix + out = torch.matmul(input_diag, Q.T) + out = torch.sum(out, dim=1) + return out + + +class Synthesis(nn.Module): + def __init__(self, motion_dim): + super(Synthesis, self).__init__() + self.direction = Direction(motion_dim) + + +class Generator(nn.Module): + def __init__(self, size, style_dim=512, motion_dim=20): + super().__init__() + + self.enc = Encoder(size, style_dim, motion_dim) + self.dec = Synthesis(motion_dim) + + def get_motion(self, img): + #motion_feat = self.enc.enc_motion(img) + # motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True) + with torch.cuda.amp.autocast(dtype=torch.float32): + motion_feat = self.enc.enc_motion(img) + motion = self.dec.direction(motion_feat) + return motion \ No newline at end of file diff --git a/models/wan/any2video.py b/models/wan/any2video.py index abe5249..285340d 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -19,7 +19,8 @@ from PIL import Image import torchvision.transforms.functional as TF import torch.nn.functional as F from .distributed.fsdp import shard_model -from .modules.model import WanModel, clear_caches +from .modules.model import WanModel +from mmgp.offload import get_cache, clear_caches from .modules.t5 import T5EncoderModel from .modules.vae import WanVAE from .modules.vae2_2 import Wan2_2_VAE @@ -31,9 +32,11 @@ from shared.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from .modules.posemb_layers import get_rotary_pos_embed, get_nd_rotary_pos_embed from shared.utils.vace_preprocessor import VaceVideoProcessor from shared.utils.basic_flowmatch import FlowMatchScheduler -from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor +from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor, fit_image_into_canvas from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance, match_and_blend_colors, match_and_blend_colors_with_mask +from shared.utils.audio_video import save_video from mmgp import safetensors2 +from shared.utils.audio_video import save_video def optimized_scale(positive_flat, negative_flat): @@ -63,6 +66,7 @@ class WanAny2V: config, checkpoint_dir, model_filename = None, + submodel_no_list = None, model_type = None, model_def = None, base_model_type = None, @@ -91,7 +95,7 @@ class WanAny2V: shard_fn= None) # base_model_type = "i2v2_2" - if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"]: + if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"] or base_model_type in ["animate"]: self.clip = CLIPModel( dtype=config.clip_dtype, device=self.device, @@ -100,7 +104,7 @@ class WanAny2V: tokenizer_path=os.path.join(checkpoint_dir , "xlm-roberta-large")) - if base_model_type in ["ti2v_2_2"]: + if base_model_type in ["ti2v_2_2", "lucy_edit"]: self.vae_stride = (4, 16, 16) vae_checkpoint = "Wan2.2_VAE.safetensors" vae = Wan2_2_VAE @@ -125,75 +129,82 @@ class WanAny2V: forcedConfigPath = base_config_file if len(model_filename) > 1 else None # forcedConfigPath = base_config_file = f"configs/flf2v_720p.json" # model_filename[1] = xmodel_filename - + self.model = self.model2 = None source = model_def.get("source", None) + source2 = model_def.get("source2", None) module_source = model_def.get("module_source", None) + 
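# (Editorial sketch, not part of the patch.) Direction above keeps a learnable 512 x motion_dim
# matrix, orthonormalizes its columns with QR (in float32, via custom_qr), and turns a motion
# code `alpha` into a latent offset as a weighted sum of those orthonormal directions. The same
# computation without the diag_embed round-trip:
import torch

def motion_to_offset(alpha, weight):
    # alpha: [B, motion_dim], weight: [512, motion_dim]
    q, _ = torch.linalg.qr(weight + 1e-8)   # q: [512, motion_dim], orthonormal columns
    return alpha @ q.T                      # [B, 512]; equals sum(diag_embed(alpha) @ q.T, dim=1)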
module_source2 = model_def.get("module_source2", None) if module_source is not None: - model_filename = [] + model_filename - model_filename[1] = module_source - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) - elif source is not None: + self.model = offload.fast_load_transformers_model(model_filename[:1] + [module_source], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + if module_source2 is not None: + self.model2 = offload.fast_load_transformers_model(model_filename[1:2] + [module_source2], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + if source is not None: self.model = offload.fast_load_transformers_model(source, modelClass=WanModel, writable_tensors= False, forcedConfigPath= base_config_file) - elif self.transformer_switch: - shared_modules= {} - self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, return_shared_modules= shared_modules) - self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) - shared_modules = None - else: - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) - - # self.model = offload.load_model_data(self.model, xmodel_filename ) - # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth") + if source2 is not None: + self.model2 = offload.fast_load_transformers_model(source2, modelClass=WanModel, writable_tensors= False, forcedConfigPath= base_config_file) - self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) - offload.change_dtype(self.model, dtype, True) + if self.model is not None or self.model2 is not None: + from wgp import save_model + from mmgp.safetensors2 import torch_load_file + else: + if self.transformer_switch: + if 0 in submodel_no_list[2:] and 1 in submodel_no_list[2:]: + raise Exception("Shared and non shared modules at the same time across multipe models is not supported") + + if 0 in submodel_no_list[2:]: + shared_modules= {} + self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, return_shared_modules= shared_modules) + self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + shared_modules = None + else: + 
modules_for_1 =[ file_name for file_name, submodel_no in zip(model_filename[2:],submodel_no_list[2:] ) if submodel_no ==1 ] + modules_for_2 =[ file_name for file_name, submodel_no in zip(model_filename[2:],submodel_no_list[2:] ) if submodel_no ==2 ] + self.model = offload.fast_load_transformers_model(model_filename[:1], modules = modules_for_1, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = modules_for_2, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + + else: + self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + + + if self.model is not None: + self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) + offload.change_dtype(self.model, dtype, True) + self.model.eval().requires_grad_(False) if self.model2 is not None: self.model2.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) offload.change_dtype(self.model2, dtype, True) - - # offload.save_model(self.model, "wan2.1_text2video_1.3B_mbf16.safetensors", do_quantize= False, config_file_path=base_config_file, filter_sd=sd) - # offload.save_model(self.model, "wan2.2_image2video_14B_low_mbf16.safetensors", config_file_path=base_config_file) - # offload.save_model(self.model, "wan2.2_image2video_14B_low_quanto_mbf16_int8.safetensors", do_quantize=True, config_file_path=base_config_file) - self.model.eval().requires_grad_(False) - if self.model2 is not None: self.model2.eval().requires_grad_(False) + if module_source is not None: - from wgp import save_model - from mmgp.safetensors2 import torch_load_file - filter = list(torch_load_file(module_source)) - save_model(self.model, model_type, dtype, None, is_module=True, filter=filter) - elif not source is None: - from wgp import save_model - save_model(self.model, model_type, dtype, None) + save_model(self.model, model_type, dtype, None, is_module=True, filter=list(torch_load_file(module_source)), module_source_no=1) + if module_source2 is not None: + save_model(self.model2, model_type, dtype, None, is_module=True, filter=list(torch_load_file(module_source2)), module_source_no=2) + if not source is None: + save_model(self.model, model_type, dtype, None, submodel_no= 1) + if not source2 is None: + save_model(self.model2, model_type, dtype, None, submodel_no= 2) if save_quantized: from wgp import save_quantized_model - save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file) + if self.model is not None: + save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file) if self.model2 is not None: save_quantized_model(self.model2, model_type, model_filename[1], dtype, base_config_file, submodel_no=2) self.sample_neg_prompt = config.sample_neg_prompt - if self.model.config.get("vace_in_dim", None) != None: - self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]), - min_area=480*832, - max_area=480*832, - min_fps=config.sample_fps, - max_fps=config.sample_fps, - zero_start=True, - seq_len=32760, - keep_last=True) - 
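# (Editorial sketch, not part of the patch.) The branch above distributes the extra module
# checkpoints between the two transformers according to submodel_no_list: files flagged 1 go to
# the first (high-noise) model, files flagged 2 to the second (low-noise) model, and 0 marks
# modules shared by both. Generic helper with assumed names:
def split_modules(filenames, submodel_nos):
    for_model_1 = [f for f, n in zip(filenames, submodel_nos) if n == 1]
    for_model_2 = [f for f, n in zip(filenames, submodel_nos) if n == 2]
    shared = [f for f, n in zip(filenames, submodel_nos) if n == 0]
    return for_model_1, for_model_2, shared

# modules_1, modules_2, shared = split_modules(model_filename[2:], submodel_no_list[2:])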
+ if hasattr(self.model, "vace_blocks"): self.adapt_vace_model(self.model) if self.model2 is not None: self.adapt_vace_model(self.model2) + if hasattr(self.model, "face_adapter"): + self.adapt_animate_model(self.model) + if self.model2 is not None: self.adapt_animate_model(self.model2) + self.num_timesteps = 1000 self.use_timestep_transform = True def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = None): - if ref_images is None: - ref_images = [None] * len(frames) - else: - assert len(frames) == len(ref_images) + ref_images = [ref_images] * len(frames) if masks is None: latents = self.vae.encode(frames, tile_size = tile_size) @@ -225,11 +236,7 @@ class WanAny2V: return cat_latents def vace_encode_masks(self, masks, ref_images=None): - if ref_images is None: - ref_images = [None] * len(masks) - else: - assert len(masks) == len(ref_images) - + ref_images = [ref_images] * len(masks) result_masks = [] for mask, refs in zip(masks, ref_images): c, depth, height, width = mask.shape @@ -257,119 +264,6 @@ class WanAny2V: result_masks.append(mask) return result_masks - def vace_latent(self, z, m): - return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] - - def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None, return_mask = False): - from shared.utils.utils import save_image - ref_width, ref_height = ref_img.size - if (ref_height, ref_width) == image_size and outpainting_dims == None: - ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) - canvas = torch.zeros_like(ref_img) if return_mask else None - else: - if outpainting_dims != None: - final_height, final_width = image_size - canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 8) - else: - canvas_height, canvas_width = image_size - scale = min(canvas_height / ref_height, canvas_width / ref_width) - new_height = int(ref_height * scale) - new_width = int(ref_width * scale) - if fill_max and (canvas_height - new_height) < 16: - new_height = canvas_height - if fill_max and (canvas_width - new_width) < 16: - new_width = canvas_width - top = (canvas_height - new_height) // 2 - left = (canvas_width - new_width) // 2 - ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) - ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) - if outpainting_dims != None: - canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1] - canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img - else: - canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1] - canvas[:, :, top:top + new_height, left:left + new_width] = ref_img - ref_img = canvas - canvas = None - if return_mask: - if outpainting_dims != None: - canvas = torch.ones((3, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1] - canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0 - else: - canvas = torch.ones((3, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1] - canvas[:, :, top:top + new_height, left:left + new_width] = 0 - canvas = canvas.to(device) - return ref_img.to(device), canvas - - def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, 
device, keep_video_guide_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False): - image_sizes = [] - trim_video_guide = len(keep_video_guide_frames) - def conv_tensor(t, device): - return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device) - - for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)): - prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1] - num_frames = total_frames - prepend_count - num_frames = min(num_frames, trim_video_guide) if trim_video_guide > 0 and sub_src_video != None else num_frames - if sub_src_mask is not None and sub_src_video is not None: - src_video[i] = conv_tensor(sub_src_video[:num_frames], device) - src_mask[i] = conv_tensor(sub_src_mask[:num_frames], device) - # src_video is [-1, 1] (at this function output), 0 = inpainting area (in fact 127 in [0, 255]) - # src_mask is [-1, 1] (at this function output), 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255]) - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) - src_mask[i] = torch.cat( [torch.full_like(sub_pre_src_video, -1.0), src_mask[i]] ,1) - src_video_shape = src_video[i].shape - if src_video_shape[1] != total_frames: - src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.clamp((src_mask[i][:, :, :, :] + 1) / 2, min=0, max=1) - image_sizes.append(src_video[i].shape[2:]) - elif sub_src_video is None: - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1) - else: - src_video[i] = torch.zeros((3, total_frames, image_size[0], image_size[1]), device=device) - src_mask[i] = torch.ones_like(src_video[i], device=device) - image_sizes.append(image_size) - else: - src_video[i] = conv_tensor(sub_src_video[:num_frames], device) - src_mask[i] = torch.ones_like(src_video[i], device=device) - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) - src_video_shape = src_video[i].shape - if src_video_shape[1] != total_frames: - src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - image_sizes.append(src_video[i].shape[2:]) - for k, keep in enumerate(keep_video_guide_frames): - if not keep: - pos = prepend_count + k - src_video[i][:, pos:pos+1] = 0 - src_mask[i][:, pos:pos+1] = 1 - - for k, frame in enumerate(inject_frames): - if frame != None: - pos = prepend_count + k - src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True) - - - self.background_mask = None - for i, ref_images in 
enumerate(src_ref_images): - if ref_images is not None: - image_size = image_sizes[i] - for j, ref_img in enumerate(ref_images): - if ref_img is not None and not torch.is_tensor(ref_img): - if j==0 and any_background_ref: - if self.background_mask == None: self.background_mask = [None] * len(src_ref_images) - src_ref_images[i][j], self.background_mask[i] = self.fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True) - else: - src_ref_images[i][j], _ = self.fit_image_into_canvas(ref_img, image_size, 1, device) - if self.background_mask != None: - self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # deplicate background mask with double control net since first controlnet image ref modifed by ref - return src_video, src_mask, src_ref_images def get_vae_latents(self, ref_images, device, tile_size= 0): ref_vae_latents = [] @@ -380,12 +274,28 @@ class WanAny2V: return torch.cat(ref_vae_latents, dim=1) + def get_i2v_mask(self, lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=None, lat_t =0, device="cuda"): + if mask_pixel_values is None: + msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device) + else: + msk = F.interpolate(mask_pixel_values.to(device), size=(lat_h, lat_w), mode='nearest') + if nb_frames_unchanged >0: + msk[:, :nb_frames_unchanged] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1,2)[0] + return msk + def generate(self, input_prompt, input_frames= None, + input_frames2= None, input_masks = None, - input_ref_images = None, + input_masks2 = None, + input_ref_images = None, + input_ref_masks = None, + input_faces = None, input_video = None, image_start = None, image_end = None, @@ -453,7 +363,8 @@ class WanAny2V: timesteps.append(0.) 
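# (Editorial sketch, not part of the patch.) get_i2v_mask above builds the image-to-video
# conditioning mask in latent space: 1 marks frames to keep, 0 marks frames to generate. The
# first temporal entry is repeated 4 times because the VAE packs 4 pixel frames into the first
# latent frame, and the result is reshaped to [4, lat_t, lat_h, lat_w]. Standalone version:
import torch

def make_i2v_mask(lat_t, lat_h, lat_w, keep_first_frames=1):
    msk = torch.zeros(1, (lat_t - 1) * 4 + 1, lat_h, lat_w)
    msk[:, :keep_first_frames] = 1
    msk = torch.cat([msk[:, 0:1].repeat_interleave(4, dim=1), msk[:, 1:]], dim=1)
    return msk.view(1, lat_t, 4, lat_h, lat_w).transpose(1, 2)[0]   # [4, lat_t, lat_h, lat_w]

# an 81-frame window gives lat_t = 21: make_i2v_mask(21, 60, 104).shape == (4, 21, 60, 104)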
timesteps = [torch.tensor([t], device=self.device) for t in timesteps] if self.use_timestep_transform: - timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps][:-1] + timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps][:-1] + timesteps = torch.tensor(timesteps) sample_scheduler = None elif sample_solver == 'causvid': sample_scheduler = FlowMatchScheduler(num_inference_steps=sampling_steps, shift=shift, sigma_min=0, extra_one_step=True) @@ -496,6 +407,8 @@ class WanAny2V: text_len = self.model.text_len context = torch.cat([context, context.new_zeros(text_len -context.size(0), context.size(1)) ]).unsqueeze(0) context_null = torch.cat([context_null, context_null.new_zeros(text_len -context_null.size(0), context_null.size(1)) ]).unsqueeze(0) + if input_video is not None: height, width = input_video.shape[-2:] + # NAG_prompt = "static, low resolution, blurry" # context_NAG = self.text_encoder([NAG_prompt], self.device)[0] # context_NAG = context_NAG.to(self.dtype) @@ -516,109 +429,76 @@ class WanAny2V: infinitetalk = model_type in ["infinitetalk"] standin = model_type in ["standin", "vace_standin_14B"] recam = model_type in ["recam_1.3B"] - ti2v = model_type in ["ti2v_2_2"] + ti2v = model_type in ["ti2v_2_2", "lucy_edit"] + lucy_edit= model_type in ["lucy_edit"] + animate= model_type in ["animate"] start_step_no = 0 ref_images_count = 0 trim_frames = 0 - extended_overlapped_latents = None + extended_overlapped_latents = clip_image_start = clip_image_end = image_mask_latents = None no_noise_latents_injection = infinitetalk timestep_injection = False lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1 + extended_input_dim = 0 + ref_images_before = False # image2video if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "flf2v_720p"]: any_end_frame = False - if image_start is None: - if infinitetalk: - if input_frames is not None: - image_ref = input_frames[:, -1] - if input_video is None: input_video = input_frames[:, -1:] - new_shot = "Q" in video_prompt_type - else: - if pre_video_frame is None: - new_shot = True - else: - if input_ref_images is None: - input_ref_images, new_shot = [pre_video_frame], False - else: - input_ref_images, new_shot = [img.resize(pre_video_frame.size, resample=Image.Resampling.LANCZOS) for img in input_ref_images], "Q" in video_prompt_type - if input_ref_images is None: raise Exception("Missing Reference Image") - new_shot = new_shot and window_no <= len(input_ref_images) - image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ]) - if new_shot: - input_video = image_ref.unsqueeze(1) - else: - color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot - _ , preframes_count, height, width = input_video.shape - input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype) - if infinitetalk: - image_for_clip = image_ref.to(input_video) - control_pre_frames_count = 1 - control_video = image_for_clip.unsqueeze(1) + if infinitetalk: + new_shot = "Q" in video_prompt_type + if input_frames is not None: + image_ref = input_frames[:, 0] else: - image_for_clip = input_video[:, -1] - control_pre_frames_count = preframes_count - control_video = input_video - lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] - if hasattr(self, 
"clip"): - clip_image_size = self.clip.model.image_size - clip_image = resize_lanczos(image_for_clip, clip_image_size, clip_image_size)[:, None, :, :] - clip_context = self.clip.visual([clip_image]) if model_type != "flf2v_720p" else self.clip.visual([clip_image , clip_image ]) - clip_image = None + if input_ref_images is None: + if pre_video_frame is None: raise Exception("Missing Reference Image") + input_ref_images, new_shot = [pre_video_frame], False + new_shot = new_shot and window_no <= len(input_ref_images) + image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ]) + if new_shot or input_video is None: + input_video = image_ref.unsqueeze(1) else: - clip_context = None - enc = torch.concat( [control_video, torch.zeros( (3, frame_num-control_pre_frames_count, height, width), - device=self.device, dtype= self.VAE_dtype)], - dim = 1).to(self.device) - color_reference_frame = image_for_clip.unsqueeze(1).clone() + color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot + _ , preframes_count, height, width = input_video.shape + input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype) + if infinitetalk: + image_start = image_ref.to(input_video) + control_pre_frames_count = 1 + control_video = image_start.unsqueeze(1) else: - preframes_count = control_pre_frames_count = 1 - any_end_frame = image_end is not None - add_frames_for_end_image = any_end_frame and model_type == "i2v" - if any_end_frame: - if add_frames_for_end_image: - frame_num +=1 - lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2) - trim_frames = 1 - - height, width = image_start.shape[1:] + image_start = input_video[:, -1] + control_pre_frames_count = preframes_count + control_video = input_video - lat_h = round( - height // self.vae_stride[1] // - self.patch_size[1] * self.patch_size[1]) - lat_w = round( - width // self.vae_stride[2] // - self.patch_size[2] * self.patch_size[2]) - height = lat_h * self.vae_stride[1] - width = lat_w * self.vae_stride[2] - image_start_frame = image_start.unsqueeze(1).to(self.device) - color_reference_frame = image_start_frame.clone() - if image_end is not None: - img_end_frame = image_end.unsqueeze(1).to(self.device) + color_reference_frame = image_start.unsqueeze(1).clone() - if hasattr(self, "clip"): - clip_image_size = self.clip.model.image_size - image_start = resize_lanczos(image_start, clip_image_size, clip_image_size) - if image_end is not None: image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) - if model_type == "flf2v_720p": - clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end is not None else image_start[:, None, :, :]]) - else: - clip_context = self.clip.visual([image_start[:, None, :, :]]) - else: - clip_context = None + any_end_frame = image_end is not None + add_frames_for_end_image = any_end_frame and model_type == "i2v" + if any_end_frame: + color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot + if add_frames_for_end_image: + frame_num +=1 + lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2) + trim_frames = 1 - if any_end_frame: - enc= torch.concat([ - image_start_frame, - torch.zeros( (3, frame_num-2, height, width), device=self.device, dtype= self.VAE_dtype), - img_end_frame, - ], dim=1).to(self.device) - else: - 
enc= torch.concat([ - image_start_frame, - torch.zeros( (3, frame_num-1, height, width), device=self.device, dtype= self.VAE_dtype) - ], dim=1).to(self.device) + lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] - image_start = image_end = image_start_frame = img_end_frame = image_for_clip = image_ref = None + if image_end is not None: + img_end_frame = image_end.unsqueeze(1).to(self.device) + clip_image_start, clip_image_end = image_start, image_end + + if any_end_frame: + enc= torch.concat([ + control_video, + torch.zeros( (3, frame_num-control_pre_frames_count-1, height, width), device=self.device, dtype= self.VAE_dtype), + img_end_frame, + ], dim=1).to(self.device) + else: + enc= torch.concat([ + control_video, + torch.zeros( (3, frame_num-control_pre_frames_count, height, width), device=self.device, dtype= self.VAE_dtype) + ], dim=1).to(self.device) + + image_start = image_end = img_end_frame = image_ref = control_video = None msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device) if any_end_frame: @@ -643,42 +523,85 @@ class WanAny2V: if infinitetalk: lat_y = self.vae.encode([input_video], VAE_tile_size)[0] extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone().unsqueeze(0) - # if control_pre_frames_count != pre_frames_count: lat_y = input_video = None kwargs.update({ 'y': y}) - if not clip_context is None: - kwargs.update({'clip_fea': clip_context}) - # Recam Master + # Animate + if animate: + pose_pixels = input_frames * input_masks + input_masks = 1. - input_masks + pose_pixels -= input_masks + pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0) + input_frames = input_frames * input_masks + if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames + if prefix_frames_count > 0: + input_frames[:, :prefix_frames_count] = input_video + input_masks[:, :prefix_frames_count] = 1 + # save_video(pose_pixels, "pose.mp4") + # save_video(input_frames, "input_frames.mp4") + # save_video(input_masks, "input_masks.mp4", value_range=(0,1)) + lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] + msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device) + msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device) + msk = torch.concat([msk_ref, msk_control], dim=1) + image_ref = input_ref_images[0].to(self.device) + clip_image_start = image_ref.squeeze(1) + lat_y = torch.concat(self.vae.encode([image_ref, input_frames.to(self.device)], VAE_tile_size), dim=1) + y = torch.concat([msk, lat_y]) + kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)}) + lat_y = msk = msk_control = msk_ref = pose_pixels = None + ref_images_before = True + ref_images_count = 1 + lat_frames = int((input_frames.shape[1] - 1) // self.vae_stride[0]) + 1 + + # Clip image + if hasattr(self, "clip") and clip_image_start is not None: + clip_image_size = self.clip.model.image_size + clip_image_start = resize_lanczos(clip_image_start, clip_image_size, clip_image_size) + clip_image_end = resize_lanczos(clip_image_end, clip_image_size, clip_image_size) if clip_image_end is not None else clip_image_start + if model_type == "flf2v_720p": + clip_context = self.clip.visual([clip_image_start[:, None, :, :], clip_image_end[:, None, :, :] if clip_image_end is not None else clip_image_start[:, None, :, :]]) + else: + clip_context = 
self.clip.visual([clip_image_start[:, None, :, :]]) + clip_image_start = clip_image_end = None + kwargs.update({'clip_fea': clip_context}) + + # Recam Master & Lucy Edit + if recam or lucy_edit: + frame_num, height,width = input_frames.shape[-3:] + lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1 + frame_num = (lat_frames -1) * self.vae_stride[0] + 1 + input_frames = input_frames[:, :frame_num].to(dtype=self.dtype , device=self.device) + extended_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device) + extended_input_dim = 2 if recam else 1 + del input_frames + if recam: - # should be be in fact in input_frames since it is control video not a video to be extended - target_camera = model_mode - height,width = input_video.shape[-2:] - input_video = input_video.to(dtype=self.dtype , device=self.device) - source_latents = self.vae.encode([input_video])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device) - del input_video # Process target camera (recammaster) + target_camera = model_mode from shared.utils.cammmaster_tools import get_camera_embedding cam_emb = get_camera_embedding(target_camera) cam_emb = cam_emb.to(dtype=self.dtype, device=self.device) kwargs['cam_emb'] = cam_emb # Video 2 Video - if denoising_strength < 1. and input_frames != None: + if "G" in video_prompt_type and input_frames != None: height, width = input_frames.shape[-2:] source_latents = self.vae.encode([input_frames])[0].unsqueeze(0) injection_denoising_step = 0 inject_from_start = False if input_frames != None and denoising_strength < 1 : color_reference_frame = input_frames[:, -1:].clone() - if overlapped_latents != None: - overlapped_latents_frames_num = overlapped_latents.shape[2] - overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1 + if prefix_frames_count > 0: + overlapped_frames_num = prefix_frames_count + overlapped_latents_frames_num = (overlapped_latents_frames_num -1 // 4) + 1 + # overlapped_latents_frames_num = overlapped_latents.shape[2] + # overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1 else: overlapped_latents_frames_num = overlapped_frames_num = 0 if len(keep_frames_parsed) == 0 or image_outputs or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = [] - injection_denoising_step = int(sampling_steps * (1. - denoising_strength) ) + injection_denoising_step = int( round(sampling_steps * (1. - denoising_strength),4) ) latent_keep_frames = [] if source_latents.shape[2] < lat_frames or len(keep_frames_parsed) > 0: inject_from_start = True @@ -694,14 +617,21 @@ class WanAny2V: if hasattr(sample_scheduler, "sigmas"): sample_scheduler.sigmas= sample_scheduler.sigmas[injection_denoising_step:] injection_denoising_step = 0 + if input_masks is not None and not "U" in video_prompt_type: + image_mask_latents = torch.nn.functional.interpolate(input_masks, size= source_latents.shape[-2:], mode="nearest").unsqueeze(0) + if image_mask_latents.shape[2] !=1: + image_mask_latents = torch.cat([ image_mask_latents[:,:, :1], torch.nn.functional.interpolate(image_mask_latents, size= (source_latents.shape[-3]-1, *source_latents.shape[-2:]), mode="nearest") ], dim=2) + image_mask_latents = torch.where(image_mask_latents>=0.5, 1., 0. 
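# (Editorial sketch, not part of the patch.) The masked video-to-video path above shrinks the
# pixel-space mask to the latent grid with nearest-neighbour interpolation and binarizes it at
# 0.5, so each latent cell is either regenerated (1) or preserved (0); the first latent frame is
# handled separately because it packs several pixel frames. Simplified spatial-only version,
# with assumed shapes:
import torch
import torch.nn.functional as F

def mask_to_latent(mask_pixels, lat_h, lat_w):
    # mask_pixels: [C, T, H, W] in [0, 1] -> [C, T, lat_h, lat_w] binary mask
    lat_mask = F.interpolate(mask_pixels, size=(lat_h, lat_w), mode="nearest")
    return torch.where(lat_mask >= 0.5, 1.0, 0.0)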
)[:1].to(self.device) + # save_video(image_mask_latents.squeeze(0), "mama.mp4", value_range=(0,1) ) + # image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0) + # Phantom if phantom: - input_ref_images_neg = None - if input_ref_images != None: # Phantom Ref images - input_ref_images = self.get_vae_latents(input_ref_images, self.device) - input_ref_images_neg = torch.zeros_like(input_ref_images) - ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0 - trim_frames = input_ref_images.shape[1] + lat_input_ref_images_neg = None + if input_ref_images is not None: # Phantom Ref images + lat_input_ref_images = self.get_vae_latents(input_ref_images, self.device) + lat_input_ref_images_neg = torch.zeros_like(lat_input_ref_images) + ref_images_count = trim_frames = lat_input_ref_images.shape[1] if ti2v: if input_video is None: @@ -710,28 +640,29 @@ class WanAny2V: height, width = input_video.shape[-2:] source_latents = self.vae.encode([input_video], tile_size = VAE_tile_size)[0].unsqueeze(0) timestep_injection = True + if extended_input_dim > 0: + extended_latents[:, :, :source_latents.shape[2]] = source_latents # Vace if vace : # vace context encode - input_frames = [u.to(self.device) for u in input_frames] - input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] - input_masks = [u.to(self.device) for u in input_masks] - if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask] + input_frames = [input_frames.to(self.device)] +([] if input_frames2 is None else [input_frames2.to(self.device)]) + input_masks = [input_masks.to(self.device)] + ([] if input_masks2 is None else [input_masks2.to(self.device)]) + input_ref_images = None if input_ref_images is None else [ u.to(self.device) for u in input_ref_images] + input_ref_masks = None if input_ref_masks is None else [ None if u is None else u.to(self.device) for u in input_ref_masks] + ref_images_before = True z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents ) m0 = self.vace_encode_masks(input_masks, input_ref_images) - if self.background_mask != None: - color_reference_frame = input_ref_images[0][0].clone() - zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size ) - mbg = self.vace_encode_masks(self.background_mask, None) + if input_ref_masks is not None and len(input_ref_masks) > 0 and input_ref_masks[0] is not None: + color_reference_frame = input_ref_images[0].clone() + zbg = self.vace_encode_frames( input_ref_images[:1] * len(input_frames), None, masks=input_ref_masks[0], tile_size = VAE_tile_size ) + mbg = self.vace_encode_masks(input_ref_masks[:1] * len(input_frames), None) for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg): zz0[:, 0:1] = zzbg mm0[:, 0:1] = mmbg - - self.background_mask = zz0 = mm0 = zzbg = mmbg = None - z = self.vace_latent(z0, m0) - - ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0 + zz0 = mm0 = zzbg = mmbg = None + z = [torch.cat([zz, mm], dim=0) for zz, mm in zip(z0, m0)] + ref_images_count = len(input_ref_images) if input_ref_images is not None and input_ref_images is not None else 0 context_scale = context_scale if context_scale != None else [1.0] * len(z) kwargs.update({'vace_context' : z, 'vace_context_scale' 
: context_scale, "ref_images_count": ref_images_count }) if overlapped_latents != None : @@ -739,17 +670,12 @@ class WanAny2V: extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0) if prefix_frames_count > 0: color_reference_frame = input_frames[0][:, prefix_frames_count -1:prefix_frames_count].clone() + lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] + target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, lat_h, lat_w) - target_shape = list(z0[0].shape) - target_shape[0] = int(target_shape[0] / 2) - lat_h, lat_w = target_shape[-2:] - height = self.vae_stride[1] * lat_h - width = self.vae_stride[2] * lat_w - - else: - target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2]) - - if multitalk and audio_proj != None: + if multitalk: + if audio_proj is None: + audio_proj = [ torch.zeros( (1, 1, 5, 12, 768 ), dtype=self.dtype, device=self.device), torch.zeros( (1, (frame_num - 1) // 4, 8, 12, 768 ), dtype=self.dtype, device=self.device) ] from .multitalk.multitalk import get_target_masks audio_proj = [audio.to(self.dtype) for audio in audio_proj] human_no = len(audio_proj[0]) @@ -764,9 +690,9 @@ class WanAny2V: expand_shape = [batch_size] + [-1] * len(target_shape) # Ropes - if target_camera != None: + if extended_input_dim>=2: shape = list(target_shape[1:]) - shape[0] *= 2 + shape[extended_input_dim-2] *= 2 freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False) else: freqs = get_rotary_pos_embed(target_shape[1:], enable_RIFLEx= enable_RIFLEx) @@ -777,19 +703,20 @@ class WanAny2V: if standin: from preprocessing.face_preprocessor import FaceProcessor standin_ref_pos = 1 if "K" in video_prompt_type else 0 - if len(original_input_ref_images) < standin_ref_pos + 1: raise Exception("Missing Standin ref image") - standin_ref_pos = -1 - image_ref = original_input_ref_images[standin_ref_pos] - image_ref.save("si.png") - # face_processor = FaceProcessor(antelopv2_path="ckpts/antelopev2") - face_processor = FaceProcessor() - standin_ref = face_processor.process(image_ref, remove_bg = model_type in ["vace_standin_14B"]) - face_processor = None - gc.collect() - torch.cuda.empty_cache() - standin_freqs = get_nd_rotary_pos_embed((-1, int(target_shape[-2]/2), int(target_shape[-1]/2) ), (-1, int(target_shape[-2]/2 + standin_ref.height/16), int(target_shape[-1]/2 + standin_ref.width/16) )) - standin_ref = self.vae.encode([ convert_image_to_tensor(standin_ref).unsqueeze(1) ], VAE_tile_size)[0].unsqueeze(0) - kwargs.update({ "standin_freqs": standin_freqs, "standin_ref": standin_ref, }) + if len(original_input_ref_images) < standin_ref_pos + 1: + if "I" in video_prompt_type and model_type in ["vace_standin_14B"]: + print("Warning: Missing Standin ref image, make sure 'Inject only People / Objects' is selected, or that there are at least two ref images if 'Landscape and then People or Objects' is selected.") + else: + standin_ref_pos = -1 + image_ref = original_input_ref_images[standin_ref_pos] + face_processor = FaceProcessor() + standin_ref = face_processor.process(image_ref, remove_bg = model_type in ["vace_standin_14B"]) + face_processor = None + gc.collect() + torch.cuda.empty_cache() + standin_freqs = get_nd_rotary_pos_embed((-1, int(target_shape[-2]/2), int(target_shape[-1]/2) ), (-1, int(target_shape[-2]/2 + standin_ref.height/16), int(target_shape[-1]/2 + standin_ref.width/16) )) + standin_ref = self.vae.encode([ 
convert_image_to_tensor(standin_ref).unsqueeze(1) ], VAE_tile_size)[0].unsqueeze(0) + kwargs.update({ "standin_freqs": standin_freqs, "standin_ref": standin_ref, }) # Steps Skipping @@ -819,7 +746,7 @@ class WanAny2V: denoising_extra = "" from shared.utils.loras_mutipliers import update_loras_slists, get_model_switch_steps - phase_switch_step, phase_switch_step2, phases_description = get_model_switch_steps(timesteps, updated_num_steps, guide_phases, 0 if self.model2 is None else model_switch_phase, switch_threshold, switch2_threshold ) + phase_switch_step, phase_switch_step2, phases_description = get_model_switch_steps(original_timesteps,guide_phases, 0 if self.model2 is None else model_switch_phase, switch_threshold, switch2_threshold ) if len(phases_description) > 0: set_header_text(phases_description) guidance_switch_done = guidance_switch2_done = False if guide_phases > 1: denoising_extra = f"Phase 1/{guide_phases} High Noise" if self.model2 is not None else f"Phase 1/{guide_phases}" @@ -830,8 +757,8 @@ class WanAny2V: denoising_extra = f"Phase {phase_no}/{guide_phases} {'Low Noise' if trans == self.model2 else 'High Noise'}" if self.model2 is not None else f"Phase {phase_no}/{guide_phases}" callback(step_no-1, denoising_extra = denoising_extra) return guide_scale, guidance_switch_done, trans, denoising_extra - update_loras_slists(self.model, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2) - if self.model2 is not None: update_loras_slists(self.model2, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2) + update_loras_slists(self.model, loras_slists, len(original_timesteps), phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2) + if self.model2 is not None: update_loras_slists(self.model2, loras_slists, len(original_timesteps), phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2) callback(-1, None, True, override_num_inference_steps = updated_num_steps, denoising_extra = denoising_extra) def clear(): @@ -844,19 +771,22 @@ class WanAny2V: scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g} # b, c, lat_f, lat_h, lat_w latents = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) + if "G" in video_prompt_type: randn = latents if apg_switch != 0: apg_momentum = -0.75 apg_norm_threshold = 55 text_momentumbuffer = MomentumBuffer(apg_momentum) audio_momentumbuffer = MomentumBuffer(apg_momentum) - + input_frames = input_frames2 = input_masks =input_masks2 = input_video = input_ref_images = input_ref_masks = pre_video_frame = None + gc.collect() + torch.cuda.empty_cache() # denoising trans = self.model for i, t in enumerate(tqdm(timesteps)): guide_scale, guidance_switch_done, trans, denoising_extra = update_guidance(i, t, guide_scale, guide2_scale, guidance_switch_done, switch_threshold, trans, 2, denoising_extra) guide_scale, guidance_switch2_done, trans, denoising_extra = update_guidance(i, t, guide_scale, guide3_scale, guidance_switch2_done, switch2_threshold, trans, 3, denoising_extra) - offload.set_step_no_for_lora(trans, i) + offload.set_step_no_for_lora(trans, start_step_no + i) timestep = torch.stack([t]) if timestep_injection: @@ -864,44 +794,43 @@ class WanAny2V: timestep = torch.full((target_shape[-3],), t, dtype=torch.int64, device=latents.device) timestep[:source_latents.shape[2]] = 0 - kwargs.update({"t": timestep, 
"current_step": start_step_no + i}) + kwargs.update({"t": timestep, "current_step_no": i, "real_step_no": start_step_no + i }) kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None - if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step: + if denoising_strength < 1 and i <= injection_denoising_step: sigma = t / 1000 - noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) if inject_from_start: - new_latents = latents.clone() - new_latents[:,:, :source_latents.shape[2] ] = noise[:, :, :source_latents.shape[2] ] * sigma + (1 - sigma) * source_latents + noisy_image = latents.clone() + noisy_image[:,:, :source_latents.shape[2] ] = randn[:, :, :source_latents.shape[2] ] * sigma + (1 - sigma) * source_latents for latent_no, keep_latent in enumerate(latent_keep_frames): if not keep_latent: - new_latents[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1] - latents = new_latents - new_latents = None + noisy_image[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1] + latents = noisy_image + noisy_image = None else: - latents = noise * sigma + (1 - sigma) * source_latents - noise = None + latents = randn * sigma + (1 - sigma) * source_latents if extended_overlapped_latents != None: if no_noise_latents_injection: latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents else: latent_noise_factor = t / 1000 + latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor if vace: overlap_noise_factor = overlap_noise / 1000 for zz in z: zz[0:16, ref_images_count:extended_overlapped_latents.shape[2] ] = extended_overlapped_latents[0, :, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[0, :, ref_images_count:] ) * overlap_noise_factor - if target_camera != None: - latent_model_input = torch.cat([latents, source_latents.expand(*expand_shape)], dim=2) + if extended_input_dim > 0: + latent_model_input = torch.cat([latents, extended_latents.expand(*expand_shape)], dim=extended_input_dim) else: latent_model_input = latents any_guidance = guide_scale != 1 if phantom: gen_args = { - "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 + - [ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]), + "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 + + [ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]), "context": [context, context_null, context_null] , } elif fantasy: @@ -1010,8 +939,8 @@ class WanAny2V: if sample_solver == "euler": dt = timesteps[i] if i == len(timesteps)-1 else (timesteps[i] - timesteps[i + 1]) - dt = dt / self.num_timesteps - latents = latents - noise_pred * dt[:, None, None, None, None] + dt = dt.item() / self.num_timesteps + latents = latents - noise_pred * dt else: latents = sample_scheduler.step( noise_pred[:, :, :target_shape[1]], @@ -1019,9 +948,16 @@ class WanAny2V: latents, **scheduler_kwargs)[0] + + if image_mask_latents is not None: + sigma = 0 if i == len(timesteps)-1 else timesteps[i+1]/1000 + noisy_image = 
randn * sigma + (1 - sigma) * source_latents + latents = noisy_image * (1-image_mask_latents) + image_mask_latents * latents + + if callback is not None: latents_preview = latents - if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ] + if ref_images_before and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ] if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames] if image_outputs: latents_preview= latents_preview[:, :,:1] if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2) @@ -1032,7 +968,7 @@ class WanAny2V: if timestep_injection: latents[:, :, :source_latents.shape[2]] = source_latents - if vace and ref_images_count > 0: latents = latents[:, :, ref_images_count:] + if ref_images_before and ref_images_count > 0: latents = latents[:, :, ref_images_count:] if trim_frames > 0: latents= latents[:, :,:-trim_frames] if return_latent_slice != None: latent_slice = latents[:, :, return_latent_slice].clone() @@ -1069,4 +1005,20 @@ class WanAny2V: delattr(model, "vace_blocks") + def adapt_animate_model(self, model): + modules_dict= { k: m for k, m in model.named_modules()} + for animate_layer in range(8): + module = modules_dict[f"face_adapter.fuser_blocks.{animate_layer}"] + model_layer = animate_layer * 5 + target = modules_dict[f"blocks.{model_layer}"] + setattr(target, "face_adapter_fuser_blocks", module ) + delattr(model, "face_adapter") + + def get_loras_transformer(self, get_model_recursive_prop, base_model_type, model_type, video_prompt_type, model_mode, **kwargs): + if base_model_type == "animate": + if "1" in video_prompt_type: + preloadURLs = get_model_recursive_prop(model_type, "preload_URLs") + if len(preloadURLs) > 0: + return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))] , [1] + return [], [] diff --git a/models/wan/df_handler.py b/models/wan/df_handler.py index bc79e2e..31d5da6 100644 --- a/models/wan/df_handler.py +++ b/models/wan/df_handler.py @@ -21,11 +21,24 @@ class family_handler(): extra_model_def["fps"] =fps extra_model_def["frames_minimum"] = 17 extra_model_def["frames_steps"] = 20 + extra_model_def["latent_size"] = 4 extra_model_def["sliding_window"] = True extra_model_def["skip_layer_guidance"] = True extra_model_def["tea_cache"] = True extra_model_def["guidance_max_phases"] = 1 + extra_model_def["model_modes"] = { + "choices": [ + ("Synchronous", 0), + ("Asynchronous (better quality but around 50% extra steps added)", 5), + ], + "default": 0, + "label" : "Generation Type" + } + + extra_model_def["image_prompt_types_allowed"] = "TSV" + + return extra_model_def @staticmethod @@ -54,7 +67,11 @@ class family_handler(): def query_family_infos(): return {} - + @staticmethod + def get_rgb_factors(base_model_type ): + from shared.RGB_factors import get_rgb_factors + latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type) + return latent_rgb_factors, latent_rgb_factors_bias @staticmethod def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization): @@ -62,7 +79,7 @@ class family_handler(): return family_handler.query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization) @staticmethod - def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): + def load_model(model_filename, 
model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False, submodel_no_list = None): from .configs import WAN_CONFIGS from .wan_handler import family_handler cfg = WAN_CONFIGS['t2v-14B'] diff --git a/models/wan/diffusion_forcing copy.py b/models/wan/diffusion_forcing copy.py deleted file mode 100644 index 753fd45..0000000 --- a/models/wan/diffusion_forcing copy.py +++ /dev/null @@ -1,479 +0,0 @@ -import math -import os -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union -import logging -import numpy as np -import torch -from diffusers.image_processor import PipelineImageInput -from diffusers.utils.torch_utils import randn_tensor -from diffusers.video_processor import VideoProcessor -from tqdm import tqdm -from .modules.model import WanModel -from .modules.t5 import T5EncoderModel -from .modules.vae import WanVAE -from wan.modules.posemb_layers import get_rotary_pos_embed -from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, retrieve_timesteps) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler - -class DTT2V: - - - def __init__( - self, - config, - checkpoint_dir, - rank=0, - model_filename = None, - text_encoder_filename = None, - quantizeTransformer = False, - dtype = torch.bfloat16, - ): - self.device = torch.device(f"cuda") - self.config = config - self.rank = rank - self.dtype = dtype - self.num_train_timesteps = config.num_train_timesteps - self.param_dtype = config.param_dtype - - self.text_encoder = T5EncoderModel( - text_len=config.text_len, - dtype=config.t5_dtype, - device=torch.device('cpu'), - checkpoint_path=text_encoder_filename, - tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), - shard_fn= None) - - self.vae_stride = config.vae_stride - self.patch_size = config.patch_size - - - self.vae = WanVAE( - vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), - device=self.device) - - logging.info(f"Creating WanModel from {model_filename}") - from mmgp import offload - - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False, forcedConfigPath="config.json") - # offload.load_model_data(self.model, "recam.ckpt") - # self.model.cpu() - # offload.save_model(self.model, "recam.safetensors") - if self.dtype == torch.float16 and not "fp16" in model_filename: - self.model.to(self.dtype) - # offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True) - if self.dtype == torch.float16: - self.vae.model.to(self.dtype) - self.model.eval().requires_grad_(False) - - self.scheduler = FlowUniPCMultistepScheduler() - - @property - def do_classifier_free_guidance(self) -> bool: - return self._guidance_scale > 1 - - def encode_image( - self, image: PipelineImageInput, height: int, width: int, num_frames: int, tile_size = 0, causal_block_size = 0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # prefix_video - prefix_video = np.array(image.resize((width, height))).transpose(2, 0, 1) - prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1) - if prefix_video.dtype == torch.uint8: - prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0 - prefix_video = prefix_video.to(self.device) - prefix_video = [self.vae.encode(prefix_video.unsqueeze(0), tile_size = tile_size)[0]] # [(c, 
f, h, w)] - if prefix_video[0].shape[1] % causal_block_size != 0: - truncate_len = prefix_video[0].shape[1] % causal_block_size - print("the length of prefix video is truncated for the casual block size alignment.") - prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] - predix_video_latent_length = prefix_video[0].shape[1] - return prefix_video, predix_video_latent_length - - def prepare_latents( - self, - shape: Tuple[int], - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - ) -> torch.Tensor: - return randn_tensor(shape, generator, device=device, dtype=dtype) - - def generate_timestep_matrix( - self, - num_frames, - step_template, - base_num_frames, - ar_step=5, - num_pre_ready=0, - casual_block_size=1, - shrink_interval_with_mask=False, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: - step_matrix, step_index = [], [] - update_mask, valid_interval = [], [] - num_iterations = len(step_template) + 1 - num_frames_block = num_frames // casual_block_size - base_num_frames_block = base_num_frames // casual_block_size - if base_num_frames_block < num_frames_block: - infer_step_num = len(step_template) - gen_block = base_num_frames_block - min_ar_step = infer_step_num / gen_block - assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting" - # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block) - step_template = torch.cat( - [ - torch.tensor([999], dtype=torch.int64, device=step_template.device), - step_template.long(), - torch.tensor([0], dtype=torch.int64, device=step_template.device), - ] - ) # to handle the counter in row works starting from 1 - pre_row = torch.zeros(num_frames_block, dtype=torch.long) - if num_pre_ready > 0: - pre_row[: num_pre_ready // casual_block_size] = num_iterations - - while torch.all(pre_row >= (num_iterations - 1)) == False: - new_row = torch.zeros(num_frames_block, dtype=torch.long) - for i in range(num_frames_block): - if i == 0 or pre_row[i - 1] >= ( - num_iterations - 1 - ): # the first frame or the last frame is completely denoised - new_row[i] = pre_row[i] + 1 - else: - new_row[i] = new_row[i - 1] - ar_step - new_row = new_row.clamp(0, num_iterations) - - update_mask.append( - (new_row != pre_row) & (new_row != num_iterations) - ) # False: no need to update, True: need to update - step_index.append(new_row) - step_matrix.append(step_template[new_row]) - pre_row = new_row - - # for long video we split into several sequences, base_num_frames is set to the model max length (for training) - terminal_flag = base_num_frames_block - if shrink_interval_with_mask: - idx_sequence = torch.arange(num_frames_block, dtype=torch.int64) - update_mask = update_mask[0] - update_mask_idx = idx_sequence[update_mask] - last_update_idx = update_mask_idx[-1].item() - terminal_flag = last_update_idx + 1 - # for i in range(0, len(update_mask)): - for curr_mask in update_mask: - if terminal_flag < num_frames_block and curr_mask[terminal_flag]: - terminal_flag += 1 - valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag)) - - step_update_mask = torch.stack(update_mask, dim=0) - step_index = torch.stack(step_index, dim=0) - step_matrix = torch.stack(step_matrix, dim=0) - - if casual_block_size > 1: - step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, 
casual_block_size).flatten(1).contiguous() - step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval] - - return step_matrix, step_index, step_update_mask, valid_interval - - @torch.no_grad() - def generate( - self, - prompt: Union[str, List[str]], - negative_prompt: Union[str, List[str]] = "", - image: PipelineImageInput = None, - height: int = 480, - width: int = 832, - num_frames: int = 97, - num_inference_steps: int = 50, - shift: float = 1.0, - guidance_scale: float = 5.0, - seed: float = 0.0, - overlap_history: int = 17, - addnoise_condition: int = 0, - base_num_frames: int = 97, - ar_step: int = 5, - causal_block_size: int = 1, - causal_attention: bool = False, - fps: int = 24, - VAE_tile_size = 0, - joint_pass = False, - callback = None, - ): - generator = torch.Generator(device=self.device) - generator.manual_seed(seed) - # if base_num_frames > base_num_frames: - # causal_block_size = 0 - self._guidance_scale = guidance_scale - - i2v_extra_kwrags = {} - prefix_video = None - predix_video_latent_length = 0 - if image: - frame_width, frame_height = image.size - scale = min(height / frame_height, width / frame_width) - height = (int(frame_height * scale) // 16) * 16 - width = (int(frame_width * scale) // 16) * 16 - - prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames, tile_size=VAE_tile_size, causal_block_size=causal_block_size) - - latent_length = (num_frames - 1) // 4 + 1 - latent_height = height // 8 - latent_width = width // 8 - - prompt_embeds = self.text_encoder([prompt], self.device) - prompt_embeds = [u.to(self.dtype).to(self.device) for u in prompt_embeds] - if self.do_classifier_free_guidance: - negative_prompt_embeds = self.text_encoder([negative_prompt], self.device) - negative_prompt_embeds = [u.to(self.dtype).to(self.device) for u in negative_prompt_embeds] - - - - self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift) - init_timesteps = self.scheduler.timesteps - fps_embeds = [fps] * prompt_embeds[0].shape[0] - fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] - transformer_dtype = self.dtype - # with torch.cuda.amp.autocast(dtype=self.dtype), torch.no_grad(): - if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: - # short video generation - latent_shape = [16, latent_length, latent_height, latent_width] - latents = self.prepare_latents( - latent_shape, dtype=torch.float32, device=self.device, generator=generator - ) - latents = [latents] - if prefix_video is not None: - latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32) - base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length - step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - latent_length, init_timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size - ) - sample_schedulers = [] - for _ in range(latent_length): - sample_scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=1000, shift=1, use_dynamic_shifting=False - ) - sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift) - sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * latent_length - - if callback != None: - callback(-1, None, 
True) - - freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False) - for i, timestep_i in enumerate(tqdm(step_matrix)): - update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: - noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor) - + torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length]) - * noise_factor - ) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition - kwrags = { - "x" : torch.stack([latent_model_input[0]]), - "t" : timestep, - "freqs" :freqs, - "fps" : fps_embeds, - # "causal_block_size" : causal_block_size, - "callback" : callback, - "pipeline" : self - } - kwrags.update(i2v_extra_kwrags) - - - if not self.do_classifier_free_guidance: - noise_pred = self.model( - context=prompt_embeds, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred= noise_pred.to(torch.float32) - else: - if joint_pass: - noise_pred_cond, noise_pred_uncond = self.model( - context=prompt_embeds, - context2=negative_prompt_embeds, - **kwrags, - ) - if self._interrupt: - return None - else: - noise_pred_cond = self.model( - context=prompt_embeds, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred_uncond = self.model( - context=negative_prompt_embeds, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred_cond= noise_pred_cond.to(torch.float32) - noise_pred_uncond= noise_pred_uncond.to(torch.float32) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) - del noise_pred_cond, noise_pred_uncond - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[0][:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], - timestep_i[idx], - latents[0][:, idx], - return_dict=False, - generator=generator, - )[0] - sample_schedulers_counter[idx] += 1 - if callback is not None: - callback(i, latents[0], False) - - x0 = latents[0].unsqueeze(0) - videos = self.vae.decode(x0, tile_size= VAE_tile_size) - videos = (videos / 2 + 0.5).clamp(0, 1) - videos = [video for video in videos] - videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] - videos = [video.cpu().numpy().astype(np.uint8) for video in videos] - return videos - else: - # long video generation - base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length - overlap_history_frames = (overlap_history - 1) // 4 + 1 - n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 - print(f"n_iter:{n_iter}") - output_video = None - for i in range(n_iter): - if output_video is not None: # i !=0 - prefix_video = output_video[:, -overlap_history:].to(self.device) - prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] - if prefix_video[0].shape[1] % causal_block_size != 0: - truncate_len = prefix_video[0].shape[1] % causal_block_size - print("the length of prefix video is 
truncated for the casual block size alignment.") - prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] - predix_video_latent_length = prefix_video[0].shape[1] - finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames - left_frame_num = latent_length - finished_frame_num - base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) - else: # i == 0 - base_num_frames_iter = base_num_frames - latent_shape = [16, base_num_frames_iter, latent_height, latent_width] - latents = self.prepare_latents( - latent_shape, dtype=torch.float32, device=self.device, generator=generator - ) - latents = [latents] - if prefix_video is not None: - latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32) - step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - base_num_frames_iter, - init_timesteps, - base_num_frames_iter, - ar_step, - predix_video_latent_length, - causal_block_size, - ) - sample_schedulers = [] - for _ in range(base_num_frames_iter): - sample_scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=1000, shift=1, use_dynamic_shifting=False - ) - sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift) - sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * base_num_frames_iter - if callback != None: - callback(-1, None, True) - - freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False) - for i, timestep_i in enumerate(tqdm(step_matrix)): - update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: - noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] - * (1.0 - noise_factor) - + torch.randn_like( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] - ) - * noise_factor - ) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition - kwrags = { - "x" : torch.stack([latent_model_input[0]]), - "t" : timestep, - "freqs" :freqs, - "fps" : fps_embeds, - "causal_block_size" : causal_block_size, - "causal_attention" : causal_attention, - "callback" : callback, - "pipeline" : self - } - kwrags.update(i2v_extra_kwrags) - - if not self.do_classifier_free_guidance: - noise_pred = self.model( - context=prompt_embeds, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred= noise_pred.to(torch.float32) - else: - if joint_pass: - noise_pred_cond, noise_pred_uncond = self.model( - context=prompt_embeds, - context2=negative_prompt_embeds, - **kwrags, - ) - if self._interrupt: - return None - else: - noise_pred_cond = self.model( - context=prompt_embeds, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred_uncond = self.model( - context=negative_prompt_embeds, - )[0] - if self._interrupt: - return None - noise_pred_cond= noise_pred_cond.to(torch.float32) - noise_pred_uncond= noise_pred_uncond.to(torch.float32) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - 
noise_pred_uncond) - del noise_pred_cond, noise_pred_uncond - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[0][:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], - timestep_i[idx], - latents[0][:, idx], - return_dict=False, - generator=generator, - )[0] - sample_schedulers_counter[idx] += 1 - if callback is not None: - callback(i, latents[0].squeeze(0), False) - - x0 = latents[0].unsqueeze(0) - videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]] - if output_video is None: - output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w - else: - output_video = torch.cat( - [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1 - ) # c, f, h, w - return output_video diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index 95faa4d..cd02470 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -12,29 +12,17 @@ from diffusers.models.modeling_utils import ModelMixin import numpy as np from typing import Union,Optional from mmgp import offload +from mmgp.offload import get_cache, clear_caches from shared.attention import pay_attention from torch.backends.cuda import sdp_kernel from ..multitalk.multitalk_utils import get_attn_map_with_target +from ..animate.motion_encoder import Generator +from ..animate.face_blocks import FaceAdapter, FaceEncoder +from ..animate.model_animate import after_patch_embedding __all__ = ['WanModel'] -def get_cache(cache_name): - all_cache = offload.shared_state.get("_cache", None) - if all_cache is None: - all_cache = {} - offload.shared_state["_cache"]= all_cache - cache = offload.shared_state.get(cache_name, None) - if cache is None: - cache = {} - offload.shared_state[cache_name] = cache - return cache - -def clear_caches(): - all_cache = offload.shared_state.get("_cache", None) - if all_cache is not None: - all_cache.clear() - def sinusoidal_embedding_1d(dim, position): # preprocess assert dim % 2 == 0 @@ -514,6 +502,7 @@ class WanAttentionBlock(nn.Module): multitalk_masks=None, ref_images_count=0, standin_phase=-1, + motion_vec = None, ): r""" Args: @@ -579,19 +568,23 @@ class WanAttentionBlock(nn.Module): y = self.norm_x(x) y = y.to(attention_dtype) if ref_images_count == 0: - x += self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map) + ylist= [y] + del y + x += self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map) else: y_shape = y.shape y = y.reshape(y_shape[0], grid_sizes[0], -1) y = y[:, ref_images_count:] y = y.reshape(y_shape[0], -1, y_shape[-1]) grid_sizes_alt = [grid_sizes[0]-ref_images_count, *grid_sizes[1:]] - y = self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map) + ylist= [y] + y = None + y = self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map) y = y.reshape(y_shape[0], grid_sizes[0]-ref_images_count, -1) x = x.reshape(y_shape[0], grid_sizes[0], -1) x[:, ref_images_count:] += y x = x.reshape(y_shape[0], -1, y_shape[-1]) - del y + del y y = self.norm2(x) @@ -627,6 +620,10 @@ class WanAttentionBlock(nn.Module): x.add_(hint) else: x.add_(hint, alpha= scale) + + if motion_vec is not None and self.block_no % 5 == 0: + x += self.face_adapter_fuser_blocks(x.to(self.face_adapter_fuser_blocks.linear1_kv.weight.dtype), motion_vec, None, False) + return x class 
AudioProjModel(ModelMixin, ConfigMixin): @@ -909,6 +906,7 @@ class WanModel(ModelMixin, ConfigMixin): norm_input_visual=True, norm_output_audio=True, standin= False, + motion_encoder_dim=0, ): super().__init__() @@ -933,14 +931,15 @@ class WanModel(ModelMixin, ConfigMixin): self.flag_causal_attention = False self.block_mask = None self.inject_sample_info = inject_sample_info - + self.motion_encoder_dim = motion_encoder_dim self.norm_output_audio = norm_output_audio self.audio_window = audio_window self.intermediate_dim = intermediate_dim self.vae_scale = vae_scale multitalk = multitalk_output_dim > 0 - self.multitalk = multitalk + self.multitalk = multitalk + animate = motion_encoder_dim > 0 # embeddings self.patch_embedding = nn.Conv3d( @@ -1038,6 +1037,25 @@ class WanModel(ModelMixin, ConfigMixin): block.self_attn.k_loras = LoRALinearLayer(dim, dim, rank=128) block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128) + if animate: + self.pose_patch_embedding = nn.Conv3d( + 16, dim, kernel_size=patch_size, stride=patch_size + ) + + self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20) + self.face_adapter = FaceAdapter( + heads_num=self.num_heads, + hidden_dim=self.dim, + num_adapter_layers=self.num_layers // 5, + ) + + self.face_encoder = FaceEncoder( + in_dim=motion_encoder_dim, + hidden_dim=self.dim, + num_heads=4, + ) + + def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32): layer_list = [self.head, self.head.head, self.patch_embedding] target_dype= dtype @@ -1202,7 +1220,8 @@ class WanModel(ModelMixin, ConfigMixin): y=None, freqs = None, pipeline = None, - current_step = 0, + current_step_no = 0, + real_step_no = 0, x_id= 0, max_steps = 0, slg_layers=None, @@ -1219,6 +1238,9 @@ class WanModel(ModelMixin, ConfigMixin): ref_images_count = 0, standin_freqs = None, standin_ref = None, + pose_latents=None, + face_pixel_values=None, + ): # patch_dtype = self.patch_embedding.weight.dtype modulation_dtype = self.time_projection[1].weight.dtype @@ -1251,9 +1273,18 @@ class WanModel(ModelMixin, ConfigMixin): if bz > 1: y = y.expand(bz, -1, -1, -1, -1) x = torch.cat([x, y], dim=1) # embeddings - # x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype) x = self.patch_embedding(x).to(modulation_dtype) grid_sizes = x.shape[2:] + x_list[i] = x + y = None + + motion_vec_list = [] + for i, x in enumerate(x_list): + # animate embeddings + motion_vec = None + if pose_latents is not None: + x, motion_vec = after_patch_embedding(self, x, pose_latents, face_pixel_values) + motion_vec_list.append(motion_vec) if chipmunk: x = x.unsqueeze(-1) x_og_shape = x.shape @@ -1261,7 +1292,7 @@ class WanModel(ModelMixin, ConfigMixin): else: x = x.flatten(2).transpose(1, 2) x_list[i] = x - x, y = None, None + x = None block_mask = None @@ -1280,9 +1311,9 @@ class WanModel(ModelMixin, ConfigMixin): del causal_mask offload.shared_state["embed_sizes"] = grid_sizes - offload.shared_state["step_no"] = current_step + offload.shared_state["step_no"] = real_step_no offload.shared_state["max_steps"] = max_steps - if current_step == 0 and x_id == 0: clear_caches() + if current_step_no == 0 and x_id == 0: clear_caches() # arguments kwargs = dict( @@ -1306,7 +1337,7 @@ class WanModel(ModelMixin, ConfigMixin): if standin_ref is not None: standin_cache_enabled = False kwargs["standin_phase"] = 2 - if (current_step == 0 or not standin_cache_enabled) and x_id == 0: + if current_step_no == 0 or not standin_cache_enabled : standin_x = 
self.patch_embedding(standin_ref).to(modulation_dtype).flatten(2).transpose(1, 2) standin_e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, torch.zeros_like(t)).to(modulation_dtype) ) standin_e0 = self.time_projection(standin_e).unflatten(1, (6, self.dim)).to(e.dtype) @@ -1371,7 +1402,7 @@ class WanModel(ModelMixin, ConfigMixin): skips_steps_cache = self.cache if skips_steps_cache != None: if skips_steps_cache.cache_type == "mag": - if current_step <= skips_steps_cache.start_step: + if real_step_no <= skips_steps_cache.start_step: should_calc = True elif skips_steps_cache.one_for_all and x_id != 0: # not joint pass, not main pas, one for all assert len(x_list) == 1 @@ -1380,7 +1411,7 @@ class WanModel(ModelMixin, ConfigMixin): x_should_calc = [] for i in range(1 if skips_steps_cache.one_for_all else len(x_list)): cur_x_id = i if joint_pass else x_id - cur_mag_ratio = skips_steps_cache.mag_ratios[current_step * 2 + cur_x_id] # conditional and unconditional in one list + cur_mag_ratio = skips_steps_cache.mag_ratios[real_step_no * 2 + cur_x_id] # conditional and unconditional in one list skips_steps_cache.accumulated_ratio[cur_x_id] *= cur_mag_ratio # magnitude ratio between current step and the cached step skips_steps_cache.accumulated_steps[cur_x_id] += 1 # skip steps plus 1 cur_skip_err = np.abs(1-skips_steps_cache.accumulated_ratio[cur_x_id]) # skip error of current steps @@ -1400,7 +1431,7 @@ class WanModel(ModelMixin, ConfigMixin): if x_id != 0: should_calc = skips_steps_cache.should_calc else: - if current_step <= skips_steps_cache.start_step or current_step == skips_steps_cache.num_steps-1: + if real_step_no <= skips_steps_cache.start_step or real_step_no == skips_steps_cache.num_steps-1: should_calc = True skips_steps_cache.accumulated_rel_l1_distance = 0 else: @@ -1453,7 +1484,7 @@ class WanModel(ModelMixin, ConfigMixin): return [None] * len(x_list) if standin_x is not None: - if not standin_cache_enabled and x_id ==0 : get_cache("standin").clear() + if not standin_cache_enabled: get_cache("standin").clear() standin_x = block(standin_x, context = None, grid_sizes = None, e= standin_e0, freqs = standin_freqs, standin_phase = 1) if slg_layers is not None and block_idx in slg_layers: @@ -1461,9 +1492,9 @@ class WanModel(ModelMixin, ConfigMixin): continue x_list[0] = block(x_list[0], context = context_list[0], audio_scale= audio_scale_list[0], e= e0, **kwargs) else: - for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc)): + for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list)): if should_calc: - x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, **kwargs) + x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec,**kwargs) del x context = hints = audio_embedding = None diff --git a/models/wan/multitalk/attention.py b/models/wan/multitalk/attention.py index 27d488f..669b5c1 100644 --- a/models/wan/multitalk/attention.py +++ b/models/wan/multitalk/attention.py @@ -221,13 +221,16 @@ class 
SingleStreamAttention(nn.Module): self.add_q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.add_k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor: + def forward(self, xlist: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor: N_t, N_h, N_w = shape + x = xlist[0] + xlist.clear() x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t) # get q for hidden_state B, N, C = x.shape q = self.q_linear(x) + del x q_shape = (B, N, self.num_heads, self.head_dim) q = q.view(q_shape).permute((0, 2, 1, 3)) @@ -247,9 +250,6 @@ class SingleStreamAttention(nn.Module): q = rearrange(q, "B H M K -> B M H K") encoder_k = rearrange(encoder_k, "B H M K -> B M H K") encoder_v = rearrange(encoder_v, "B H M K -> B M H K") - - attn_bias = None - # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,) qkv_list = [q, encoder_k, encoder_v] q = encoder_k = encoder_v = None x = pay_attention(qkv_list) @@ -302,7 +302,7 @@ class SingleStreamMutiAttention(SingleStreamAttention): self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim) def forward(self, - x: torch.Tensor, + xlist: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, x_ref_attn_map=None, @@ -310,14 +310,17 @@ class SingleStreamMutiAttention(SingleStreamAttention): encoder_hidden_states = encoder_hidden_states.squeeze(0) if x_ref_attn_map == None: - return super().forward(x, encoder_hidden_states, shape) + return super().forward(xlist, encoder_hidden_states, shape) N_t, _, _ = shape + x = xlist[0] + xlist.clear() x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t) # get q for hidden_state B, N, C = x.shape q = self.q_linear(x) + del x q_shape = (B, N, self.num_heads, self.head_dim) q = q.view(q_shape).permute((0, 2, 1, 3)) @@ -339,7 +342,9 @@ class SingleStreamMutiAttention(SingleStreamAttention): normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) - q = self.rope_1d(q, normalized_pos) + qlist = [q] + del q + q = self.rope_1d(qlist, normalized_pos, "q") q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) _, N_a, _ = encoder_hidden_states.shape @@ -347,7 +352,7 @@ class SingleStreamMutiAttention(SingleStreamAttention): encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim) encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4)) encoder_k, encoder_v = encoder_kv.unbind(0) - + del encoder_kv if self.qk_norm: encoder_k = self.add_k_norm(encoder_k) @@ -356,13 +361,14 @@ class SingleStreamMutiAttention(SingleStreamAttention): per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2 encoder_pos = torch.concat([per_frame]*N_t, dim=0) encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) - encoder_k = self.rope_1d(encoder_k, encoder_pos) + enclist = [encoder_k] + del encoder_k + encoder_k = self.rope_1d(enclist, encoder_pos, "encoder_k") encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) q = rearrange(q, "B H M K -> B M H K") encoder_k = rearrange(encoder_k, "B H M K -> B M H K") encoder_v = rearrange(encoder_v, "B H M K -> B M H K") - # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=None, op=None,) qkv_list = [q, encoder_k, encoder_v] q = encoder_k = encoder_v = None x = 
pay_attention(qkv_list) diff --git a/models/wan/multitalk/multitalk.py b/models/wan/multitalk/multitalk.py index fbf9175..52d2dd9 100644 --- a/models/wan/multitalk/multitalk.py +++ b/models/wan/multitalk/multitalk.py @@ -59,7 +59,30 @@ def get_embedding(speech_array, wav2vec_feature_extractor, audio_encoder, sr=160 audio_emb = audio_emb.cpu().detach() return audio_emb - + +def extract_audio_from_video(filename, sample_rate): + raw_audio_path = filename.split('/')[-1].split('.')[0]+'.wav' + ffmpeg_command = [ + "ffmpeg", + "-y", + "-i", + str(filename), + "-vn", + "-acodec", + "pcm_s16le", + "-ar", + "16000", + "-ac", + "2", + str(raw_audio_path), + ] + subprocess.run(ffmpeg_command, check=True) + human_speech_array, sr = librosa.load(raw_audio_path, sr=sample_rate) + human_speech_array = loudness_norm(human_speech_array, sr) + os.remove(raw_audio_path) + + return human_speech_array + def audio_prepare_single(audio_path, sample_rate=16000, duration = 0): ext = os.path.splitext(audio_path)[1].lower() if ext in ['.mp4', '.mov', '.avi', '.mkv']: @@ -191,18 +214,20 @@ def process_tts_multi(text, save_dir, voice1, voice2): return s1, s2, save_path_sum -def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000, padded_frames_for_embeddings = 0, min_audio_duration = 0): +def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000, padded_frames_for_embeddings = 0, min_audio_duration = 0, return_sum_only = False): wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/chinese-wav2vec2-base") # wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/wav2vec") pad = int(padded_frames_for_embeddings/ fps * sr) new_human_speech1, new_human_speech2, sum_human_speechs, duration_changed = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps, pad = pad, min_audio_duration = min_audio_duration ) - audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) - audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) - full_audio_embs = [] - if audio_guide1 != None: full_audio_embs.append(audio_embedding_1) - # if audio_guide1 != None: full_audio_embs.append(audio_embedding_1) - if audio_guide2 != None: full_audio_embs.append(audio_embedding_2) - if audio_guide2 == None and not duration_changed: sum_human_speechs = None + if return_sum_only: + full_audio_embs = None + else: + audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) + audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) + full_audio_embs = [] + if audio_guide1 != None: full_audio_embs.append(audio_embedding_1) + if audio_guide2 != None: full_audio_embs.append(audio_embedding_2) + if audio_guide2 == None and not duration_changed: sum_human_speechs = None return full_audio_embs, sum_human_speechs diff --git a/models/wan/multitalk/multitalk_utils.py b/models/wan/multitalk/multitalk_utils.py index 6e2b2c3..7851c45 100644 --- a/models/wan/multitalk/multitalk_utils.py +++ b/models/wan/multitalk/multitalk_utils.py @@ -16,7 +16,7 @@ import torchvision import binascii import os.path as osp from skimage import color - +from mmgp.offload import get_cache, clear_caches VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") 
ASPECT_RATIO_627 = { @@ -73,42 +73,70 @@ def normalize_and_scale(column, source_range, target_range, epsilon=1e-8): # @torch.compile -def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count, mode='mean', attn_bias=None): - - ref_k = ref_k.to(visual_q.dtype).to(visual_q.device) +def calculate_x_ref_attn_map_per_head(visual_q, ref_k, ref_target_masks, ref_images_count, attn_bias=None): + dtype = visual_q.dtype + ref_k = ref_k.to(dtype).to(visual_q.device) + scale = 1.0 / visual_q.shape[-1] ** 0.5 + visual_q = visual_q * scale + visual_q = visual_q.transpose(1, 2) + ref_k = ref_k.transpose(1, 2) + visual_q_shape = visual_q.shape + visual_q = visual_q.view(-1, visual_q_shape[-1] ) + number_chunks = visual_q_shape[-2]*ref_k.shape[-2] / 53090100 * 2 + chunk_size = int(visual_q_shape[-2] / number_chunks) + chunks =torch.split(visual_q, chunk_size) + maps_lists = [ [] for _ in ref_target_masks] + for q_chunk in chunks: + attn = q_chunk @ ref_k.transpose(-2, -1) + x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens + del attn + ref_target_masks = ref_target_masks.to(dtype) + x_ref_attn_map_source = x_ref_attn_map_source.to(dtype) + + for class_idx, ref_target_mask in enumerate(ref_target_masks): + ref_target_mask = ref_target_mask[None, None, None, ...] + x_ref_attnmap = x_ref_attn_map_source * ref_target_mask + x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens + maps_lists[class_idx].append(x_ref_attnmap) + + del x_ref_attn_map_source + + x_ref_attn_maps = [] + for class_idx, maps_list in enumerate(maps_lists): + attn_map_fuse = torch.concat(maps_list, dim= -1) + attn_map_fuse = attn_map_fuse.view(1, visual_q_shape[1], -1).squeeze(1) + x_ref_attn_maps.append( attn_map_fuse ) + + + return torch.concat(x_ref_attn_maps, dim=0) + +def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count): + dtype = visual_q.dtype + ref_k = ref_k.to(dtype).to(visual_q.device) scale = 1.0 / visual_q.shape[-1] ** 0.5 visual_q = visual_q * scale visual_q = visual_q.transpose(1, 2) ref_k = ref_k.transpose(1, 2) attn = visual_q @ ref_k.transpose(-2, -1) - if attn_bias is not None: attn += attn_bias - x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens - + del attn x_ref_attn_maps = [] - ref_target_masks = ref_target_masks.to(visual_q.dtype) - x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype) + ref_target_masks = ref_target_masks.to(dtype) + x_ref_attn_map_source = x_ref_attn_map_source.to(dtype) for class_idx, ref_target_mask in enumerate(ref_target_masks): ref_target_mask = ref_target_mask[None, None, None, ...] 
x_ref_attnmap = x_ref_attn_map_source * ref_target_mask x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1) # B, x_seqlens, H - - if mode == 'mean': - x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens - elif mode == 'max': - x_ref_attnmap = x_ref_attnmap.max(-1) # B, x_seqlens - + x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens (mean of heads) x_ref_attn_maps.append(x_ref_attnmap) - del attn del x_ref_attn_map_source return torch.concat(x_ref_attn_maps, dim=0) - def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count = 0): """Args: query (torch.tensor): B M H K @@ -120,6 +148,11 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli N_t, N_h, N_w = shape x_seqlens = N_h * N_w + if x_seqlens <= 1508: + split_num = 10 # 540p + else: + split_num = 20 if x_seqlens <= 3600 else 40 # 720p / 1080p + ref_k = ref_k[:, :x_seqlens] if ref_images_count > 0 : visual_q_shape = visual_q.shape @@ -133,9 +166,14 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli split_chunk = heads // split_num - for i in range(split_num): - x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count) - x_ref_attn_maps += x_ref_attn_maps_perhead + if split_chunk == 1: + for i in range(split_num): + x_ref_attn_maps_perhead = calculate_x_ref_attn_map_per_head(visual_q[:, :, i:(i+1), :], ref_k[:, :, i:(i+1), :], ref_target_masks, ref_images_count) + x_ref_attn_maps += x_ref_attn_maps_perhead + else: + for i in range(split_num): + x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count) + x_ref_attn_maps += x_ref_attn_maps_perhead x_ref_attn_maps /= split_num return x_ref_attn_maps @@ -158,7 +196,6 @@ class RotaryPositionalEmbedding1D(nn.Module): self.base = 10000 - @lru_cache(maxsize=32) def precompute_freqs_cis_1d(self, pos_indices): freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim)) @@ -167,7 +204,7 @@ class RotaryPositionalEmbedding1D(nn.Module): freqs = repeat(freqs, "... n -> ... (n r)", r=2) return freqs - def forward(self, x, pos_indices): + def forward(self, qlist, pos_indices, cache_entry = None): """1D RoPE. Args: @@ -176,16 +213,26 @@ class RotaryPositionalEmbedding1D(nn.Module): Returns: query with the same shape as input. 
""" - freqs_cis = self.precompute_freqs_cis_1d(pos_indices) - - x_ = x.float() - - freqs_cis = freqs_cis.float().to(x.device) - cos, sin = freqs_cis.cos(), freqs_cis.sin() - cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') - x_ = (x_ * cos) + (rotate_half(x_) * sin) - - return x_.type_as(x) + xq= qlist[0] + qlist.clear() + cache = get_cache("multitalk_rope") + freqs_cis= cache.get(cache_entry, None) + if freqs_cis is None: + freqs_cis = cache[cache_entry] = self.precompute_freqs_cis_1d(pos_indices) + cos, sin = freqs_cis.cos().unsqueeze(0).unsqueeze(0), freqs_cis.sin().unsqueeze(0).unsqueeze(0) + # cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') + # real * cos - imag * sin + # imag * cos + real * sin + xq_dtype = xq.dtype + xq_out = xq.to(torch.float) + xq = None + xq_rot = rotate_half(xq_out) + xq_out *= cos + xq_rot *= sin + xq_out += xq_rot + del xq_rot + xq_out = xq_out.to(xq_dtype) + return xq_out diff --git a/models/wan/text2video fuse attempt.py b/models/wan/text2video fuse attempt.py deleted file mode 100644 index 8af9458..0000000 --- a/models/wan/text2video fuse attempt.py +++ /dev/null @@ -1,698 +0,0 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. -import gc -import logging -import math -import os -import random -import sys -import types -from contextlib import contextmanager -from functools import partial -from mmgp import offload -import torch -import torch.nn as nn -import torch.cuda.amp as amp -import torch.distributed as dist -from tqdm import tqdm -from PIL import Image -import torchvision.transforms.functional as TF -import torch.nn.functional as F -from .distributed.fsdp import shard_model -from .modules.model import WanModel -from .modules.t5 import T5EncoderModel -from .modules.vae import WanVAE -from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, retrieve_timesteps) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler -from wan.modules.posemb_layers import get_rotary_pos_embed -from .utils.vace_preprocessor import VaceVideoProcessor - - -def optimized_scale(positive_flat, negative_flat): - - # Calculate dot production - dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) - - # Squared norm of uncondition - squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 - - # st_star = v_cond^T * v_uncond / ||v_uncond||^2 - st_star = dot_product / squared_norm - - return st_star - - -class WanT2V: - - def __init__( - self, - config, - checkpoint_dir, - rank=0, - model_filename = None, - text_encoder_filename = None, - quantizeTransformer = False, - dtype = torch.bfloat16 - ): - self.device = torch.device(f"cuda") - self.config = config - self.rank = rank - self.dtype = dtype - self.num_train_timesteps = config.num_train_timesteps - self.param_dtype = config.param_dtype - - self.text_encoder = T5EncoderModel( - text_len=config.text_len, - dtype=config.t5_dtype, - device=torch.device('cpu'), - checkpoint_path=text_encoder_filename, - tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), - shard_fn= None) - - self.vae_stride = config.vae_stride - self.patch_size = config.patch_size - - - self.vae = WanVAE( - vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), - device=self.device) - - logging.info(f"Creating WanModel from {model_filename}") - from mmgp import offload - - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= 
quantizeTransformer, writable_tensors= False) - # offload.load_model_data(self.model, "recam.ckpt") - # self.model.cpu() - # offload.save_model(self.model, "recam.safetensors") - if self.dtype == torch.float16 and not "fp16" in model_filename: - self.model.to(self.dtype) - # offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True) - if self.dtype == torch.float16: - self.vae.model.to(self.dtype) - self.model.eval().requires_grad_(False) - - - self.sample_neg_prompt = config.sample_neg_prompt - - if "Vace" in model_filename: - self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]), - min_area=480*832, - max_area=480*832, - min_fps=config.sample_fps, - max_fps=config.sample_fps, - zero_start=True, - seq_len=32760, - keep_last=True) - - self.adapt_vace_model() - - self.scheduler = FlowUniPCMultistepScheduler() - - def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0): - if ref_images is None: - ref_images = [None] * len(frames) - else: - assert len(frames) == len(ref_images) - - if masks is None: - latents = self.vae.encode(frames, tile_size = tile_size) - else: - inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] - reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] - inactive = self.vae.encode(inactive, tile_size = tile_size) - reactive = self.vae.encode(reactive, tile_size = tile_size) - latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] - - cat_latents = [] - for latent, refs in zip(latents, ref_images): - if refs is not None: - if masks is None: - ref_latent = self.vae.encode(refs, tile_size = tile_size) - else: - ref_latent = self.vae.encode(refs, tile_size = tile_size) - ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] - assert all([x.shape[1] == 1 for x in ref_latent]) - latent = torch.cat([*ref_latent, latent], dim=1) - cat_latents.append(latent) - return cat_latents - - def vace_encode_masks(self, masks, ref_images=None): - if ref_images is None: - ref_images = [None] * len(masks) - else: - assert len(masks) == len(ref_images) - - result_masks = [] - for mask, refs in zip(masks, ref_images): - c, depth, height, width = mask.shape - new_depth = int((depth + 3) // self.vae_stride[0]) - height = 2 * (int(height) // (self.vae_stride[1] * 2)) - width = 2 * (int(width) // (self.vae_stride[2] * 2)) - - # reshape - mask = mask[0, :, :, :] - mask = mask.view( - depth, height, self.vae_stride[1], width, self.vae_stride[1] - ) # depth, height, 8, width, 8 - mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width - mask = mask.reshape( - self.vae_stride[1] * self.vae_stride[2], depth, height, width - ) # 8*8, depth, height, width - - # interpolation - mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) - - if refs is not None: - length = len(refs) - mask_pad = torch.zeros_like(mask[:, :length, :, :]) - mask = torch.cat((mask_pad, mask), dim=1) - result_masks.append(mask) - return result_masks - - def vace_latent(self, z, m): - return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] - - def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, pre_src_video = None): - image_sizes = [] - trim_video = len(keep_frames) - - for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)): - prepend_count = 0 if sub_pre_src_video 
== None else sub_pre_src_video.shape[1] - num_frames = total_frames - prepend_count - if sub_src_mask is not None and sub_src_video is not None: - src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame) - # src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255]) - # src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255]) - src_video[i] = src_video[i].to(device) - src_mask[i] = src_mask[i].to(device) - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) - src_video_shape = src_video[i].shape - if src_video_shape[1] != total_frames: - src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1) - image_sizes.append(src_video[i].shape[2:]) - elif sub_src_video is None: - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1) - else: - src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) - src_mask[i] = torch.ones_like(src_video[i], device=device) - image_sizes.append(image_size) - else: - src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame) - src_video[i] = src_video[i].to(device) - src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device) - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) - src_video_shape = src_video[i].shape - if src_video_shape[1] != total_frames: - src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - image_sizes.append(src_video[i].shape[2:]) - for k, keep in enumerate(keep_frames): - if not keep: - src_video[i][:, k:k+1] = 0 - src_mask[i][:, k:k+1] = 1 - - for i, ref_images in enumerate(src_ref_images): - if ref_images is not None: - image_size = image_sizes[i] - for j, ref_img in enumerate(ref_images): - if ref_img is not None: - ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) - if ref_img.shape[-2:] != image_size: - canvas_height, canvas_width = image_size - ref_height, ref_width = ref_img.shape[-2:] - white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] - scale = min(canvas_height / ref_height, canvas_width / ref_width) - new_height = int(ref_height * scale) - new_width = int(ref_width * scale) - resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, 
new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) - top = (canvas_height - new_height) // 2 - left = (canvas_width - new_width) // 2 - white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image - ref_img = white_canvas - src_ref_images[i][j] = ref_img.to(device) - return src_video, src_mask, src_ref_images - - def decode_latent(self, zs, ref_images=None, tile_size= 0 ): - if ref_images is None: - ref_images = [None] * len(zs) - else: - assert len(zs) == len(ref_images) - - trimed_zs = [] - for z, refs in zip(zs, ref_images): - if refs is not None: - z = z[:, len(refs):, :, :] - trimed_zs.append(z) - - return self.vae.decode(trimed_zs, tile_size= tile_size) - - def generate_timestep_matrix( - self, - num_frames, - step_template, - base_num_frames, - ar_step=5, - num_pre_ready=0, - casual_block_size=1, - shrink_interval_with_mask=False, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: - step_matrix, step_index = [], [] - update_mask, valid_interval = [], [] - num_iterations = len(step_template) + 1 - num_frames_block = num_frames // casual_block_size - base_num_frames_block = base_num_frames // casual_block_size - if base_num_frames_block < num_frames_block: - infer_step_num = len(step_template) - gen_block = base_num_frames_block - min_ar_step = infer_step_num / gen_block - assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting" - # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block) - step_template = torch.cat( - [ - torch.tensor([999], dtype=torch.int64, device=step_template.device), - step_template.long(), - torch.tensor([0], dtype=torch.int64, device=step_template.device), - ] - ) # to handle the counter in row works starting from 1 - pre_row = torch.zeros(num_frames_block, dtype=torch.long) - if num_pre_ready > 0: - pre_row[: num_pre_ready // casual_block_size] = num_iterations - - while torch.all(pre_row >= (num_iterations - 1)) == False: - new_row = torch.zeros(num_frames_block, dtype=torch.long) - for i in range(num_frames_block): - if i == 0 or pre_row[i - 1] >= ( - num_iterations - 1 - ): # the first frame or the last frame is completely denoised - new_row[i] = pre_row[i] + 1 - else: - new_row[i] = new_row[i - 1] - ar_step - new_row = new_row.clamp(0, num_iterations) - - update_mask.append( - (new_row != pre_row) & (new_row != num_iterations) - ) # False: no need to update, True: need to update - step_index.append(new_row) - step_matrix.append(step_template[new_row]) - pre_row = new_row - - # for long video we split into several sequences, base_num_frames is set to the model max length (for training) - terminal_flag = base_num_frames_block - if shrink_interval_with_mask: - idx_sequence = torch.arange(num_frames_block, dtype=torch.int64) - update_mask = update_mask[0] - update_mask_idx = idx_sequence[update_mask] - last_update_idx = update_mask_idx[-1].item() - terminal_flag = last_update_idx + 1 - # for i in range(0, len(update_mask)): - for curr_mask in update_mask: - if terminal_flag < num_frames_block and curr_mask[terminal_flag]: - terminal_flag += 1 - valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag)) - - step_update_mask = torch.stack(update_mask, dim=0) - step_index = torch.stack(step_index, dim=0) - step_matrix = torch.stack(step_matrix, dim=0) - - if casual_block_size > 1: - step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, 
casual_block_size).flatten(1).contiguous() - step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval] - - return step_matrix, step_index, step_update_mask, valid_interval - - def generate(self, - input_prompt, - input_frames= None, - input_masks = None, - input_ref_images = None, - source_video=None, - target_camera=None, - context_scale=1.0, - size=(1280, 720), - frame_num=81, - shift=5.0, - sample_solver='unipc', - sampling_steps=50, - guide_scale=5.0, - n_prompt="", - seed=-1, - offload_model=True, - callback = None, - enable_RIFLEx = None, - VAE_tile_size = 0, - joint_pass = False, - slg_layers = None, - slg_start = 0.0, - slg_end = 1.0, - cfg_star_switch = True, - cfg_zero_step = 5, - ): - r""" - Generates video frames from text prompt using diffusion process. - - Args: - input_prompt (`str`): - Text prompt for content generation - size (tupele[`int`], *optional*, defaults to (1280,720)): - Controls video resolution, (width,height). - frame_num (`int`, *optional*, defaults to 81): - How many frames to sample from a video. The number should be 4n+1 - shift (`float`, *optional*, defaults to 5.0): - Noise schedule shift parameter. Affects temporal dynamics - sample_solver (`str`, *optional*, defaults to 'unipc'): - Solver used to sample the video. - sampling_steps (`int`, *optional*, defaults to 40): - Number of diffusion sampling steps. Higher values improve quality but slow generation - guide_scale (`float`, *optional*, defaults 5.0): - Classifier-free guidance scale. Controls prompt adherence vs. creativity - n_prompt (`str`, *optional*, defaults to ""): - Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` - seed (`int`, *optional*, defaults to -1): - Random seed for noise generation. If -1, use random seed. - offload_model (`bool`, *optional*, defaults to True): - If True, offloads models to CPU during generation to save VRAM - - Returns: - torch.Tensor: - Generated video frames tensor. Dimensions: (C, N H, W) where: - - C: Color channels (3 for RGB) - - N: Number of frames (81) - - H: Frame height (from size) - - W: Frame width from size) - """ - # preprocess - - if n_prompt == "": - n_prompt = self.sample_neg_prompt - seed = seed if seed >= 0 else random.randint(0, sys.maxsize) - seed_g = torch.Generator(device=self.device) - seed_g.manual_seed(seed) - - frame_num = max(17, frame_num) # must match causal_block_size for value of 5 - frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 ) - num_frames = frame_num - addnoise_condition = 20 - causal_attention = True - fps = 16 - ar_step = 5 - - - - context = self.text_encoder([input_prompt], self.device) - context_null = self.text_encoder([n_prompt], self.device) - if target_camera != None: - size = (source_video.shape[2], source_video.shape[1]) - source_video = source_video.to(dtype=self.dtype , device=self.device) - source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.) 
- source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device) - del source_video - # Process target camera (recammaster) - from wan.utils.cammmaster_tools import get_camera_embedding - cam_emb = get_camera_embedding(target_camera) - cam_emb = cam_emb.to(dtype=self.dtype, device=self.device) - - if input_frames != None: - # vace context encode - input_frames = [u.to(self.device) for u in input_frames] - input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] - input_masks = [u.to(self.device) for u in input_masks] - - z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size) - m0 = self.vace_encode_masks(input_masks, input_ref_images) - z = self.vace_latent(z0, m0) - - target_shape = list(z0[0].shape) - target_shape[0] = int(target_shape[0] / 2) - else: - F = frame_num - target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, - size[1] // self.vae_stride[1], - size[0] // self.vae_stride[2]) - - seq_len = math.ceil((target_shape[2] * target_shape[3]) / - (self.patch_size[1] * self.patch_size[2]) * - target_shape[1]) - - context = [u.to(self.dtype) for u in context] - context_null = [u.to(self.dtype) for u in context_null] - - noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ] - - # evaluation mode - - # if sample_solver == 'unipc': - # sample_scheduler = FlowUniPCMultistepScheduler( - # num_train_timesteps=self.num_train_timesteps, - # shift=1, - # use_dynamic_shifting=False) - # sample_scheduler.set_timesteps( - # sampling_steps, device=self.device, shift=shift) - # timesteps = sample_scheduler.timesteps - # elif sample_solver == 'dpm++': - # sample_scheduler = FlowDPMSolverMultistepScheduler( - # num_train_timesteps=self.num_train_timesteps, - # shift=1, - # use_dynamic_shifting=False) - # sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) - # timesteps, _ = retrieve_timesteps( - # sample_scheduler, - # device=self.device, - # sigmas=sampling_sigmas) - # else: - # raise NotImplementedError("Unsupported solver.") - - # sample videos - latents = noise - del noise - batch_size =len(latents) - if target_camera != None: - shape = list(latents[0].shape[1:]) - shape[0] *= 2 - freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False) - else: - freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx) - # arg_c = {'context': context, 'freqs': freqs, 'pipeline': self, 'callback': callback} - # arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback} - # arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback} - - i2v_extra_kwrags = {} - - if target_camera != None: - recam_dict = {'cam_emb': cam_emb} - i2v_extra_kwrags.update(recam_dict) - - if input_frames != None: - vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale} - i2v_extra_kwrags.update(vace_dict) - - - latent_length = (num_frames - 1) // 4 + 1 - latent_height = height // 8 - latent_width = width // 8 - if ar_step == 0: - causal_block_size = 1 - fps_embeds = [fps] #* prompt_embeds[0].shape[0] - fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] - - self.scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift) - init_timesteps = self.scheduler.timesteps - base_num_frames_iter = latent_length - latent_shape = [16, base_num_frames_iter, latent_height, latent_width] - - prefix_video = None - 
predix_video_latent_length = 0 - - if prefix_video is not None: - latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32) - step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - base_num_frames_iter, - init_timesteps, - base_num_frames_iter, - ar_step, - predix_video_latent_length, - causal_block_size, - ) - sample_schedulers = [] - for _ in range(base_num_frames_iter): - sample_scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=1000, shift=1, use_dynamic_shifting=False - ) - sample_scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift) - sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * base_num_frames_iter - - updated_num_steps= len(step_matrix) - - if callback != None: - callback(-1, None, True, override_num_inference_steps = updated_num_steps) - if self.model.enable_teacache: - self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier) - # if callback != None: - # callback(-1, None, True) - - for i, timestep_i in enumerate(tqdm(step_matrix)): - update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: - noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] - * (1.0 - noise_factor) - + torch.randn_like( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] - ) - * noise_factor - ) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition - kwrags = { - "x" : torch.stack([latent_model_input[0]]), - "t" : timestep, - "freqs" :freqs, - "fps" : fps_embeds, - "causal_block_size" : causal_block_size, - "causal_attention" : causal_attention, - "callback" : callback, - "pipeline" : self, - "current_step" : i, - } - kwrags.update(i2v_extra_kwrags) - - if not self.do_classifier_free_guidance: - noise_pred = self.model( - context=context, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred= noise_pred.to(torch.float32) - else: - if joint_pass: - noise_pred_cond, noise_pred_uncond = self.model( - context=context, - context2=context_null, - **kwrags, - ) - if self._interrupt: - return None - else: - noise_pred_cond = self.model( - context=context, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred_uncond = self.model( - context=context_null, - )[0] - if self._interrupt: - return None - noise_pred_cond= noise_pred_cond.to(torch.float32) - noise_pred_uncond= noise_pred_uncond.to(torch.float32) - noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond) - del noise_pred_cond, noise_pred_uncond - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[0][:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], - timestep_i[idx], - latents[0][:, idx], - return_dict=False, - generator=seed_g, - )[0] - sample_schedulers_counter[idx] += 1 - if callback is not None: - callback(i, latents[0].squeeze(0), False) - - # for i, t in 
enumerate(tqdm(timesteps)): - # if target_camera != None: - # latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )] - # else: - # latent_model_input = latents - # slg_layers_local = None - # if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps): - # slg_layers_local = slg_layers - # timestep = [t] - # offload.set_step_no_for_lora(self.model, i) - # timestep = torch.stack(timestep) - - # if joint_pass: - # noise_pred_cond, noise_pred_uncond = self.model( - # latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both) - # if self._interrupt: - # return None - # else: - # noise_pred_cond = self.model( - # latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0] - # if self._interrupt: - # return None - # noise_pred_uncond = self.model( - # latent_model_input, t=timestep,current_step=i, is_uncond = True, slg_layers=slg_layers_local, **arg_null)[0] - # if self._interrupt: - # return None - - # # del latent_model_input - - # # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/ - # noise_pred_text = noise_pred_cond - # if cfg_star_switch: - # positive_flat = noise_pred_text.view(batch_size, -1) - # negative_flat = noise_pred_uncond.view(batch_size, -1) - - # alpha = optimized_scale(positive_flat,negative_flat) - # alpha = alpha.view(batch_size, 1, 1, 1) - - # if (i <= cfg_zero_step): - # noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred... - # else: - # noise_pred_uncond *= alpha - # noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond) - # del noise_pred_uncond - - # temp_x0 = sample_scheduler.step( - # noise_pred[:, :target_shape[1]].unsqueeze(0), - # t, - # latents[0].unsqueeze(0), - # return_dict=False, - # generator=seed_g)[0] - # latents = [temp_x0.squeeze(0)] - # del temp_x0 - - # if callback is not None: - # callback(i, latents[0], False) - - x0 = latents - - if input_frames == None: - videos = self.vae.decode(x0, VAE_tile_size) - else: - videos = self.decode_latent(x0, input_ref_images, VAE_tile_size) - - del latents - del sample_scheduler - - return videos[0] if self.rank == 0 else None - - def adapt_vace_model(self): - model = self.model - modules_dict= { k: m for k, m in model.named_modules()} - for model_layer, vace_layer in model.vace_layers_mapping.items(): - module = modules_dict[f"vace_blocks.{vace_layer}"] - target = modules_dict[f"blocks.{model_layer}"] - setattr(target, "vace", module ) - delattr(model, "vace_blocks") - - \ No newline at end of file diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index c3ad012..c587544 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -1,8 +1,12 @@ import torch import numpy as np +import gradio as gr def test_class_i2v(base_model_type): - return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk" ] + return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "animate" ] + +def text_oneframe_overlap(base_model_type): + return test_class_i2v(base_model_type) and not (test_multitalk(base_model_type) or base_model_type in ["animate"]) or test_wan_5B(base_model_type) def test_class_1_3B(base_model_type): return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"] @@ -13,6 +17,8 @@ def test_multitalk(base_model_type): def 
test_standin(base_model_type): return base_model_type in ["standin", "vace_standin_14B"] +def test_wan_5B(base_model_type): + return base_model_type in ["ti2v_2_2", "lucy_edit"] class family_handler(): @staticmethod @@ -32,7 +38,7 @@ class family_handler(): def_mag_ratios = [1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181] elif base_model_type in ["i2v_2_2"]: def_mag_ratios = [0.99191, 0.99144, 0.99356, 0.99337, 0.99326, 0.99285, 0.99251, 0.99264, 0.99393, 0.99366, 0.9943, 0.9943, 0.99276, 0.99288, 0.99389, 0.99393, 0.99274, 0.99289, 0.99316, 0.9931, 0.99379, 0.99377, 0.99268, 0.99271, 0.99222, 0.99227, 0.99175, 0.9916, 0.91076, 0.91046, 0.98931, 0.98933, 0.99087, 0.99088, 0.98852, 0.98855, 0.98895, 0.98896, 0.98806, 0.98808, 0.9871, 0.98711, 0.98613, 0.98618, 0.98434, 0.98435, 0.983, 0.98307, 0.98185, 0.98187, 0.98131, 0.98131, 0.9783, 0.97835, 0.97619, 0.9762, 0.97264, 0.9727, 0.97088, 0.97098, 0.96568, 0.9658, 0.96045, 0.96055, 0.95322, 0.95335, 0.94579, 0.94594, 0.93297, 0.93311, 0.91699, 0.9172, 0.89174, 0.89202, 0.8541, 0.85446, 0.79823, 0.79902] - elif base_model_type in ["ti2v_2_2"]: + elif test_wan_5B(base_model_type): if inputs.get("image_start", None) is not None and inputs.get("video_source", None) is not None : # t2v def_mag_ratios = [0.99505, 0.99389, 0.99441, 0.9957, 0.99558, 0.99551, 0.99499, 0.9945, 0.99534, 0.99548, 0.99468, 0.9946, 0.99463, 0.99458, 0.9946, 0.99453, 0.99408, 0.99404, 0.9945, 0.99441, 0.99409, 0.99398, 0.99403, 0.99397, 0.99382, 0.99377, 0.99349, 0.99343, 0.99377, 0.99378, 0.9933, 0.99328, 0.99303, 0.99301, 0.99217, 0.99216, 0.992, 0.99201, 0.99201, 0.99202, 0.99133, 0.99132, 0.99112, 0.9911, 0.99155, 0.99155, 0.98958, 0.98957, 0.98959, 0.98958, 0.98838, 0.98835, 0.98826, 0.98825, 0.9883, 0.98828, 0.98711, 0.98709, 0.98562, 0.98561, 0.98511, 0.9851, 0.98414, 0.98412, 0.98284, 0.98282, 0.98104, 0.98101, 0.97981, 0.97979, 0.97849, 0.97849, 0.97557, 0.97554, 0.97398, 0.97395, 0.97171, 0.97166, 0.96917, 0.96913, 0.96511, 0.96507, 0.96263, 0.96257, 0.95839, 0.95835, 0.95483, 0.95475, 0.94942, 0.94936, 0.9468, 0.94678, 0.94583, 0.94594, 0.94843, 0.94872, 0.96949, 0.97015] else: # i2v @@ -79,11 +85,13 @@ class family_handler(): vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"] extra_model_def["vace_class"] = vace_class - if test_multitalk(base_model_type): + if base_model_type in ["animate"]: + fps = 30 + elif test_multitalk(base_model_type): fps = 25 elif base_model_type in ["fantasy"]: fps = 23 - elif base_model_type in ["ti2v_2_2"]: + elif test_wan_5B(base_model_type): fps = 24 else: fps = 16 @@ -96,14 +104,14 @@ class family_handler(): extra_model_def.update({ "frames_minimum" : frames_minimum, "frames_steps" : frames_steps, - "sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy"] or test_class_i2v(base_model_type) or vace_class, #"ti2v_2_2", + 
"sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy", "animate"] or test_class_i2v(base_model_type) or test_wan_5B(base_model_type) or vace_class, #"ti2v_2_2", "multiple_submodels" : multiple_submodels, "guidance_max_phases" : 3, "skip_layer_guidance" : True, "cfg_zero" : True, "cfg_star" : True, "adaptive_projected_guidance" : True, - "tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels), + "tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels), "mag_cache" : True, "keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"], "sample_solvers":[ @@ -112,9 +120,169 @@ class family_handler(): ("dpm++", "dpm++"), ("flowmatch causvid", "causvid"), ] }) + + + if base_model_type in ["t2v"]: + extra_model_def["guide_custom_choices"] = { + "choices":[("Use Text Prompt Only", ""), + ("Video to Video guided by Text Prompt", "GUV"), + ("Video to Video guided by Text Prompt and Restricted to the Area of the Video Mask", "GVA")], + "default": "", + "letters_filter": "GUVA", + "label": "Video to Video" + } + + extra_model_def["mask_preprocessing"] = { + "selection":[ "", "A"], + "visible": False + } + + if base_model_type in ["infinitetalk"]: extra_model_def["no_background_removal"] = True + extra_model_def["all_image_refs_are_background_ref"] = True + extra_model_def["guide_custom_choices"] = { + "choices":[ + ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"), + ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"), + ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QRUV"), + ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "RUV"), + ("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "GQUV"), + ("Video to Video, amount of motion transferred depends on Denoising Strength - Smooth Transitions", "GUV"), + ], + "default": "KI", + "letters_filter": "RGUVQKI", + "label": "Video to Video", + "show_label" : False, + } + # extra_model_def["at_least_one_image_ref_needed"] = True + if base_model_type in ["lucy_edit"]: + extra_model_def["keep_frames_video_guide_not_supported"] = True + extra_model_def["guide_preprocessing"] = { + "selection": ["UV"], + "labels" : { "UV": "Control Video"}, + "visible": False, + } + + if base_model_type in ["animate"]: + extra_model_def["guide_custom_choices"] = { + "choices":[ + ("Animate Person in Reference Image using Motion of Whole Control Video", "PVBKI"), + ("Animate Person in Reference Image using Motion of Targeted Person in Control Video", "PVBXAKI"), + ("Replace Person in Control Video by Person in Reference Image", "PVBAI"), + ("Replace Person in Control Video by Person in Reference Image and Apply Relighting Process", "PVBAI1"), + ], + "default": "PVBKI", + "letters_filter": "PVBXAKI1", + "label": "Type of Process", + "show_label" : False, + } + + extra_model_def["mask_preprocessing"] = { + "selection":[ "", "A", "XA"], + "visible": False + } + + extra_model_def["video_guide_outpainting"] = [0,1] + extra_model_def["keep_frames_video_guide_not_supported"] = True + extra_model_def["extract_guide_from_window_start"] = True + extra_model_def["forced_guide_mask_inputs"] = True + extra_model_def["background_removal_label"]= "Remove Backgrounds 
behind People (Animate Mode Only)" + extra_model_def["background_ref_outpainted"] = False + extra_model_def["return_image_refs_tensor"] = True + extra_model_def["guide_inpaint_color"] = 0 + + + + if vace_class: + extra_model_def["guide_preprocessing"] = { + "selection": ["", "UV", "PV", "DV", "SV", "LV", "CV", "MV", "V", "PDV", "PSV", "PLV" , "DSV", "DLV", "SLV"], + "labels" : { "V": "Use Vace raw format"} + } + extra_model_def["mask_preprocessing"] = { + "selection": ["", "A", "NA", "XA", "XNA", "YA", "YNA", "WA", "WNA", "ZA", "ZNA"], + } + + extra_model_def["image_ref_choices"] = { + "choices": [("None", ""), + ("People / Objects", "I"), + ("Landscape followed by People / Objects (if any)", "KI"), + ("Positioned Frames followed by People / Objects (if any)", "FI"), + ], + "letters_filter": "KFI", + } + + extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames" + extra_model_def["video_guide_outpainting"] = [0,1] + extra_model_def["pad_guide_video"] = True + extra_model_def["guide_inpaint_color"] = 127.5 + extra_model_def["forced_guide_mask_inputs"] = True + extra_model_def["return_image_refs_tensor"] = True + + if base_model_type in ["standin"]: + extra_model_def["fit_into_canvas_image_refs"] = 0 + extra_model_def["image_ref_choices"] = { + "choices": [ + ("No Reference Image", ""), + ("Reference Image is a Person Face", "I"), + ], + "letters_filter":"I", + } + + if base_model_type in ["phantom_1.3B", "phantom_14B"]: + extra_model_def["image_ref_choices"] = { + "choices": [("Reference Image", "I")], + "letters_filter":"I", + "visible": False, + } + + if base_model_type in ["recam_1.3B"]: + extra_model_def["keep_frames_video_guide_not_supported"] = True + extra_model_def["model_modes"] = { + "choices": [ + ("Pan Right", 1), + ("Pan Left", 2), + ("Tilt Up", 3), + ("Tilt Down", 4), + ("Zoom In", 5), + ("Zoom Out", 6), + ("Translate Up (with rotation)", 7), + ("Translate Down (with rotation)", 8), + ("Arc Left (with rotation)", 9), + ("Arc Right (with rotation)", 10), + ], + "default": 1, + "label" : "Camera Movement Type" + } + extra_model_def["guide_preprocessing"] = { + "selection": ["UV"], + "labels" : { "UV": "Control Video"}, + "visible" : False, + } + + if vace_class or base_model_type in ["animate"]: + image_prompt_types_allowed = "TVL" + elif base_model_type in ["infinitetalk"]: + image_prompt_types_allowed = "TSVL" + elif base_model_type in ["ti2v_2_2"]: + image_prompt_types_allowed = "TSVL" + elif base_model_type in ["lucy_edit"]: + image_prompt_types_allowed = "TVL" + elif test_multitalk(base_model_type) or base_model_type in ["fantasy"]: + image_prompt_types_allowed = "SVL" + elif i2v: + image_prompt_types_allowed = "SEVL" + else: + image_prompt_types_allowed = "" + extra_model_def["image_prompt_types_allowed"] = image_prompt_types_allowed + + if text_oneframe_overlap(base_model_type): + extra_model_def["sliding_window_defaults"] = { "overlap_min" : 1, "overlap_max" : 1, "overlap_step": 0, "overlap_default": 1} + + # if base_model_type in ["phantom_1.3B", "phantom_14B"]: + # extra_model_def["one_image_ref_needed"] = True + return extra_model_def @@ -122,8 +290,8 @@ class family_handler(): def query_supported_types(): return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B", "t2v_1.3B", "standin", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B", - "recam_1.3B", - "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"] + 
"recam_1.3B", "animate", + "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "lucy_edit", "flf2v_720p", "fun_inp_1.3B", "fun_inp"] @staticmethod @@ -153,11 +321,12 @@ class family_handler(): @staticmethod def get_vae_block_size(base_model_type): - return 32 if base_model_type == "ti2v_2_2" else 16 + return 32 if test_wan_5B(base_model_type) else 16 @staticmethod def get_rgb_factors(base_model_type ): from shared.RGB_factors import get_rgb_factors + if test_wan_5B(base_model_type): base_model_type = "ti2v_2_2" latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type) return latent_rgb_factors, latent_rgb_factors_bias @@ -171,7 +340,7 @@ class family_handler(): "fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(model_filename) ] }] - if base_model_type == "ti2v_2_2": + if test_wan_5B(base_model_type): download_def += [ { "repoId" : "DeepBeepMeep/Wan2.2", "sourceFolderList" : [""], @@ -182,7 +351,7 @@ class family_handler(): @staticmethod - def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False, submodel_no_list = None): from .configs import WAN_CONFIGS if test_class_i2v(base_model_type): @@ -195,6 +364,7 @@ class family_handler(): config=cfg, checkpoint_dir="ckpts", model_filename=model_filename, + submodel_no_list = submodel_no_list, model_type = model_type, model_def = model_def, base_model_type=base_model_type, @@ -235,15 +405,42 @@ class family_handler(): if "I" in video_prompt_type: video_prompt_type = video_prompt_type.replace("KI", "QKI") ui_defaults["video_prompt_type"] = video_prompt_type + + if settings_version < 2.28: + if base_model_type in "infinitetalk": + video_prompt_type = ui_defaults.get("video_prompt_type", "") + if "U" in video_prompt_type: + video_prompt_type = video_prompt_type.replace("U", "RU") + ui_defaults["video_prompt_type"] = video_prompt_type + + if settings_version < 2.31: + if base_model_type in "recam_1.3B": + video_prompt_type = ui_defaults.get("video_prompt_type", "") + if not "V" in video_prompt_type: + video_prompt_type += "UV" + ui_defaults["video_prompt_type"] = video_prompt_type + ui_defaults["image_prompt_type"] = "" + + if text_oneframe_overlap(base_model_type): + ui_defaults["sliding_window_overlap"] = 1 + + if settings_version < 2.32: + image_prompt_type = ui_defaults.get("image_prompt_type", "") + if test_class_i2v(base_model_type) and len(image_prompt_type) == 0: + ui_defaults["image_prompt_type"] = "S" + @staticmethod def update_default_settings(base_model_type, model_def, ui_defaults): ui_defaults.update({ "sample_solver": "unipc", }) + if test_class_i2v(base_model_type) and "S" in model_def["image_prompt_types_allowed"]: + ui_defaults["image_prompt_type"] = "S" + if base_model_type in ["fantasy"]: ui_defaults.update({ "audio_guidance_scale": 5.0, - 
"sliding_window_size": 1, + "sliding_window_overlap" : 1, }) elif base_model_type in ["multitalk"]: @@ -260,6 +457,7 @@ class family_handler(): "guidance_scale": 5.0, "flow_shift": 7, # 11 for 720p "sliding_window_overlap" : 9, + "sliding_window_size": 81, "sample_solver" : "euler", "video_prompt_type": "QKI", "remove_background_images_ref" : 0, @@ -293,6 +491,21 @@ class family_handler(): "image_prompt_type": "T", }) + if base_model_type in ["recam_1.3B", "lucy_edit"]: + ui_defaults.update({ + "video_prompt_type": "UV", + }) + elif base_model_type in ["animate"]: + ui_defaults.update({ + "video_prompt_type": "PVBKI", + "mask_expand": 20, + "audio_prompt_type": "R", + }) + + if text_oneframe_overlap(base_model_type): + ui_defaults["sliding_window_overlap"] = 1 + ui_defaults["color_correction_strength"]= 0 + if test_multitalk(base_model_type): ui_defaults["audio_guidance_scale"] = 4 @@ -309,3 +522,11 @@ class family_handler(): if ("V" in image_prompt_type or "L" in image_prompt_type) and image_refs is None: video_prompt_type = video_prompt_type.replace("I", "").replace("K","") inputs["video_prompt_type"] = video_prompt_type + + + if base_model_type in ["vace_standin_14B"]: + image_refs = inputs["image_refs"] + video_prompt_type = inputs["video_prompt_type"] + if image_refs is not None and len(image_refs) == 1 and "K" in video_prompt_type: + gr.Info("Warning, Ref Image for Standin Missing: if 'Landscape and then People or Objects' is selected beside the Landscape Image Ref there should be another Image Ref that contains a Face.") + diff --git a/preprocessing/extract_vocals.py b/preprocessing/extract_vocals.py new file mode 100644 index 0000000..6564026 --- /dev/null +++ b/preprocessing/extract_vocals.py @@ -0,0 +1,69 @@ +from pathlib import Path +import os, tempfile +import numpy as np +import soundfile as sf +import librosa +import torch +import gc + +from audio_separator.separator import Separator + +def get_vocals(src_path: str, dst_path: str, min_seconds: float = 8) -> str: + """ + If the source audio is shorter than `min_seconds`, pad with trailing silence + in a temporary file, then run separation and save only the vocals to dst_path. + Returns the full path to the vocals file. 
+ """ + + default_device = torch.get_default_device() + torch.set_default_device('cpu') + + dst = Path(dst_path) + dst.parent.mkdir(parents=True, exist_ok=True) + + # Quick duration check + duration = librosa.get_duration(path=src_path) + + use_path = src_path + temp_path = None + try: + if duration < min_seconds: + # Load (resample) and pad in memory + y, sr = librosa.load(src_path, sr=None, mono=False) + if y.ndim == 1: # ensure shape (channels, samples) + y = y[np.newaxis, :] + target_len = int(min_seconds * sr) + pad = max(0, target_len - y.shape[1]) + if pad: + y = np.pad(y, ((0, 0), (0, pad)), mode="constant") + + # Write a temp WAV for the separator + fd, temp_path = tempfile.mkstemp(suffix=".wav") + os.close(fd) + sf.write(temp_path, y.T, sr) # soundfile expects (frames, channels) + use_path = temp_path + + # Run separation: emit only the vocals, with your exact filename + sep = Separator( + output_dir=str(dst.parent), + output_format=(dst.suffix.lstrip(".") or "wav"), + output_single_stem="Vocals", + model_file_dir="ckpts/roformer/" #model_bs_roformer_ep_317_sdr_12.9755.ckpt" + ) + sep.load_model() + out_files = sep.separate(use_path, {"Vocals": dst.stem}) + + out = Path(out_files[0]) + return str(out if out.is_absolute() else (dst.parent / out)) + finally: + if temp_path and os.path.exists(temp_path): + os.remove(temp_path) + + torch.cuda.empty_cache() + gc.collect() + torch.set_default_device(default_device) + +# Example: +# final = extract_vocals("in/clip.mp3", "out/vocals.wav") +# print(final) + diff --git a/preprocessing/matanyone/app.py b/preprocessing/matanyone/app.py index 151d9be..a5d5570 100644 --- a/preprocessing/matanyone/app.py +++ b/preprocessing/matanyone/app.py @@ -7,7 +7,6 @@ import psutil # import ffmpeg import imageio from PIL import Image - import cv2 import torch import torch.nn.functional as F @@ -22,6 +21,7 @@ from .utils.get_default_model import get_matanyone_model from .matanyone.inference.inference_core import InferenceCore from .matanyone_wrapper import matanyone from shared.utils.audio_video import save_video, save_image +from mmgp import offload arg_device = "cuda" arg_sam_model_type="vit_h" @@ -33,6 +33,8 @@ model_in_GPU = False matanyone_in_GPU = False bfloat16_supported = False # SAM generator +import copy + class MaskGenerator(): def __init__(self, sam_checkpoint, device): global args_device @@ -89,6 +91,7 @@ def get_frames_from_image(image_input, image_state): "last_frame_numer": 0, "fps": None } + image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size) set_image_encoder_patch() select_SAM() @@ -537,7 +540,7 @@ def video_matting(video_state,video_input, end_slider, matting_type, interactive file_name = ".".join(file_name.split(".")[:-1]) from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files - source_audio_tracks, audio_metadata = extract_audio_tracks(video_input) + source_audio_tracks, audio_metadata = extract_audio_tracks(video_input, verbose= offload.default_verboseLevel ) output_fg_path = f"./mask_outputs/{file_name}_fg.mp4" output_fg_temp_path = f"./mask_outputs/{file_name}_fg_tmp.mp4" if len(source_audio_tracks) == 0: @@ -677,7 +680,6 @@ def load_unload_models(selected): } # os.path.join('.') - from mmgp import offload # sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[arg_sam_model_type], ".") sam_checkpoint = None @@ -695,7 +697,8 @@ def load_unload_models(selected): 
model.samcontroler.sam_controler.model.to("cpu").to(torch.bfloat16).to(arg_device) model_in_GPU = True from .matanyone.model.matanyone import MatAnyone - matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone") + # matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone") + matanyone_model = MatAnyone.from_pretrained("ckpts/mask") # pipe ={"mat" : matanyone_model, "sam" :model.samcontroler.sam_controler.model } # offload.profile(pipe) matanyone_model = matanyone_model.to("cpu").eval() @@ -717,27 +720,33 @@ def load_unload_models(selected): def get_vmc_event_handler(): return load_unload_models -def export_to_vace_video_input(foreground_video_output): - gr.Info("Masked Video Input transferred to Vace For Inpainting") - return "V#" + str(time.time()), foreground_video_output - -def export_image(image_refs, image_output): - gr.Info("Masked Image transferred to Current Video") +def export_image(state, image_output): + ui_settings = get_current_model_settings(state) + image_refs = ui_settings["image_refs"] if image_refs == None: image_refs =[] image_refs.append( image_output) - return image_refs + ui_settings["image_refs"] = image_refs + gr.Info("Masked Image transferred to Current Image Generator") + return time.time() -def export_image_mask(image_input, image_mask): - gr.Info("Input Image & Mask transferred to Current Video") - return Image.fromarray(image_input), image_mask +def export_image_mask(state, image_input, image_mask): + ui_settings = get_current_model_settings(state) + ui_settings["image_guide"] = Image.fromarray(image_input) + ui_settings["image_mask"] = image_mask + + gr.Info("Input Image & Mask transferred to Current Image Generator") + return time.time() -def export_to_current_video_engine( foreground_video_output, alpha_video_output): +def export_to_current_video_engine(state, foreground_video_output, alpha_video_output): + ui_settings = get_current_model_settings(state) + ui_settings["video_guide"] = foreground_video_output + ui_settings["video_mask"] = alpha_video_output + gr.Info("Original Video and Full Mask have been transferred") - # return "MV#" + str(time.time()), foreground_video_output, alpha_video_output - return foreground_video_output, alpha_video_output + return time.time() def teleport_to_video_tab(tab_state): @@ -746,15 +755,29 @@ def teleport_to_video_tab(tab_state): return gr.Tabs(selected="video_gen") -def display(tabs, tab_state, server_config, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs): +def display(tabs, tab_state, state, refresh_form_trigger, server_config, get_current_model_settings_fn): #, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs): # my_tab.select(fn=load_unload_models, inputs=[], outputs=[]) - global image_output_codec, video_output_codec + global image_output_codec, video_output_codec, get_current_model_settings + get_current_model_settings = get_current_model_settings_fn image_output_codec = server_config.get("image_output_codec", None) video_output_codec = server_config.get("video_output_codec", None) media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/" + click_brush_js = """ + () => { + setTimeout(() => { + const brushButton = document.querySelector('button[aria-label="Brush"]'); + if (brushButton) { + brushButton.click(); + console.log('Brush button clicked'); + } else { + console.log('Brush button not found'); + } + }, 1000); + } """ + # download assets gr.Markdown("Mast Edition is provided by MatAnyone and VRAM 
optimized by DeepBeepMeep") @@ -871,7 +894,7 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input, template_frame = gr.Image(label="Start Frame", type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image") with gr.Row(): clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False, min_width=100) - add_mask_button = gr.Button(value="Set Mask", interactive=True, visible=False, min_width=100) + add_mask_button = gr.Button(value="Add Mask", interactive=True, visible=False, min_width=100) remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False, min_width=100) # no use matting_button = gr.Button(value="Generate Video Matting", interactive=True, visible=False, min_width=100) with gr.Row(): @@ -892,7 +915,7 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input, with gr.Row(visible= True): export_to_current_video_engine_btn = gr.Button("Export to Control Video Input and Video Mask Input", visible= False) - export_to_current_video_engine_btn.click( fn=export_to_current_video_engine, inputs= [foreground_video_output, alpha_video_output], outputs= [vace_video_input, vace_video_mask]).then( #video_prompt_video_guide_trigger, + export_to_current_video_engine_btn.click( fn=export_to_current_video_engine, inputs= [state, foreground_video_output, alpha_video_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger, fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs]) @@ -1089,10 +1112,10 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input, # with gr.Column(scale=2, visible= True): export_image_mask_btn = gr.Button(value="Set to Control Image & Mask", visible=False, elem_classes="new_button") - export_image_btn.click( fn=export_image, inputs= [vace_image_refs, foreground_image_output], outputs= [vace_image_refs]).then( #video_prompt_video_guide_trigger, - fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs]) - export_image_mask_btn.click( fn=export_image_mask, inputs= [image_input, alpha_image_output], outputs= [vace_image_input, vace_image_mask]).then( #video_prompt_video_guide_trigger, + export_image_btn.click( fn=export_image, inputs= [state, foreground_image_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger, fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs]) + export_image_mask_btn.click( fn=export_image_mask, inputs= [state, image_input, alpha_image_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger, + fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs]).then(fn=None, inputs=None, outputs=None, js=click_brush_js) # first step: get the image information extract_frames_button.click( @@ -1148,5 +1171,21 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input, outputs=[foreground_image_output, alpha_image_output,foreground_image_output, alpha_image_output,bbox_info, export_image_btn, export_image_mask_btn] ) - - + nada = gr.State({}) + # clear input + gr.on( + triggers=[image_input.clear], #image_input.change, + fn=restart, + inputs=[], + outputs=[ + image_state, + interactive_state, + click_state, + foreground_image_output, alpha_image_output, + template_frame, + image_selection_slider, image_selection_slider, track_pause_number_slider,point_prompt, export_image_btn, export_image_mask_btn, bbox_info, clear_button_click, + add_mask_button, matting_button, template_frame, 
foreground_image_output, alpha_image_output, remove_mask_button, export_image_btn, export_image_mask_btn, mask_dropdown, nada, step2_title + ], + queue=False, + show_progress=False) + diff --git a/preprocessing/matanyone/matanyone/model/utils/memory_utils.py b/preprocessing/matanyone/matanyone/model/utils/memory_utils.py index 8c857ea..8166a9f 100644 --- a/preprocessing/matanyone/matanyone/model/utils/memory_utils.py +++ b/preprocessing/matanyone/matanyone/model/utils/memory_utils.py @@ -2,7 +2,6 @@ import math import torch from typing import Optional, Union, Tuple - # @torch.jit.script def get_similarity(mk: torch.Tensor, ms: torch.Tensor, @@ -59,6 +58,7 @@ def get_similarity(mk: torch.Tensor, del two_ab # similarity = (-a_sq + two_ab) + similarity =similarity.float() if ms is not None: similarity *= ms similarity /= math.sqrt(CK) diff --git a/preprocessing/matanyone/matanyone_wrapper.py b/preprocessing/matanyone/matanyone_wrapper.py index 292465a..7a729ab 100644 --- a/preprocessing/matanyone/matanyone_wrapper.py +++ b/preprocessing/matanyone/matanyone_wrapper.py @@ -73,5 +73,5 @@ def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10): if ti > (n_warmup-1): frames.append((com_np*255).astype(np.uint8)) phas.append((pha*255).astype(np.uint8)) - + # phas.append(np.clip(pha * 255, 0, 255).astype(np.uint8)) return frames, phas \ No newline at end of file diff --git a/preprocessing/speakers_separator.py b/preprocessing/speakers_separator.py index 79cde9b..d5f2563 100644 --- a/preprocessing/speakers_separator.py +++ b/preprocessing/speakers_separator.py @@ -100,7 +100,7 @@ class OptimizedPyannote31SpeakerSeparator: self.hf_token = hf_token self._overlap_pipeline = None - def separate_audio(self, audio_path: str, output1, output2 ) -> Dict[str, str]: + def separate_audio(self, audio_path: str, output1, output2, audio_original_path: str = None ) -> Dict[str, str]: """Optimized main separation function with memory management.""" xprint("Starting optimized audio separation...") self._current_audio_path = os.path.abspath(audio_path) @@ -128,7 +128,11 @@ class OptimizedPyannote31SpeakerSeparator: gc.collect() # Save outputs efficiently - output_paths = self._save_outputs_optimized(waveform, final_masks, sample_rate, audio_path, output1, output2) + if audio_original_path is None: + waveform_original = waveform + else: + waveform_original, sample_rate = self.load_audio(audio_original_path) + output_paths = self._save_outputs_optimized(waveform_original, final_masks, sample_rate, audio_path, output1, output2) return output_paths @@ -835,7 +839,7 @@ class OptimizedPyannote31SpeakerSeparator: for turn, _, speaker in diarization.itertracks(yield_label=True): xprint(f"{speaker}: {turn.start:.1f}s - {turn.end:.1f}s") -def extract_dual_audio(audio, output1, output2, verbose = False): +def extract_dual_audio(audio, output1, output2, verbose = False, audio_original = None): global verbose_output verbose_output = verbose separator = OptimizedPyannote31SpeakerSeparator( @@ -848,7 +852,7 @@ def extract_dual_audio(audio, output1, output2, verbose = False): import time start_time = time.time() - outputs = separator.separate_audio(audio, output1, output2) + outputs = separator.separate_audio(audio, output1, output2, audio_original) elapsed_time = time.time() - start_time xprint(f"\n=== SUCCESS (completed in {elapsed_time:.2f}s) ===") diff --git a/requirements.txt b/requirements.txt index 767a68d..d6f75f7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,14 +20,16 @@ soundfile mutagen 
pyloudnorm librosa==0.11.0 +speechbrain==1.0.3 +audio-separator==0.36.1 # UI & interaction -gradio==5.23.0 +gradio==5.29.0 dashscope loguru # Vision & segmentation -opencv-python>=4.9.0.80 +opencv-python>=4.12.0.88 segment-anything rembg[gpu]==2.0.65 onnxruntime-gpu @@ -43,14 +45,14 @@ pydantic==2.10.6 # Math & modeling torchdiffeq>=0.2.5 tensordict>=0.6.1 -mmgp==3.5.10 -peft==0.17.0 +mmgp==3.6.0 +peft==0.15.0 matplotlib # Utilities ftfy piexif -pynvml +nvidia-ml-py misaki # Optional / commented out diff --git a/shared/RGB_factors.py b/shared/RGB_factors.py index 6e865fa..8a870b4 100644 --- a/shared/RGB_factors.py +++ b/shared/RGB_factors.py @@ -1,6 +1,6 @@ # thanks Comfyui for the rgb factors (https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py) def get_rgb_factors(model_family, model_type = None): - if model_family == "wan": + if model_family in ["wan", "qwen"]: if model_type =="ti2v_2_2": latent_channels = 48 latent_dimensions = 3 @@ -261,7 +261,7 @@ def get_rgb_factors(model_family, model_type = None): [ 0.0249, -0.0469, -0.1703] ] - latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761] + latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761] else: latent_rgb_factors_bias = latent_rgb_factors = None return latent_rgb_factors, latent_rgb_factors_bias \ No newline at end of file diff --git a/shared/attention.py b/shared/attention.py index a95332d..cc6ece0 100644 --- a/shared/attention.py +++ b/shared/attention.py @@ -3,6 +3,7 @@ import torch from importlib.metadata import version from mmgp import offload import torch.nn.functional as F +import warnings major, minor = torch.cuda.get_device_capability(None) bfloat16_supported = major >= 8 @@ -42,34 +43,51 @@ except ImportError: sageattn_varlen_wrapper = None -import warnings - try: - from sageattention import sageattn - from .sage2_core import sageattn as alt_sageattn, is_sage2_supported + from .sage2_core import sageattn as sageattn2, is_sage2_supported sage2_supported = is_sage2_supported() except ImportError: - sageattn = None - alt_sageattn = None + sageattn2 = None sage2_supported = False -# @torch.compiler.disable() -def sageattn_wrapper( +@torch.compiler.disable() +def sageattn2_wrapper( qkv_list, attention_length ): q,k, v = qkv_list - if True: - qkv_list = [q,k,v] - del q, k ,v - o = alt_sageattn(qkv_list, tensor_layout="NHD") - else: - o = sageattn(q, k, v, tensor_layout="NHD") - del q, k ,v - + qkv_list = [q,k,v] + del q, k ,v + o = sageattn2(qkv_list, tensor_layout="NHD") qkv_list.clear() return o +try: + from sageattn import sageattn_blackwell as sageattn3 +except ImportError: + sageattn3 = None + +@torch.compiler.disable() +def sageattn3_wrapper( + qkv_list, + attention_length + ): + q,k, v = qkv_list + # qkv_list = [q,k,v] + # del q, k ,v + # o = sageattn3(qkv_list, tensor_layout="NHD") + q = q.transpose(1,2) + k = k.transpose(1,2) + v = v.transpose(1,2) + o = sageattn3(q, k, v) + o = o.transpose(1,2) + qkv_list.clear() + + return o + + + + # try: # if True: # from .sage2_core import sageattn_qk_int8_pv_fp8_window_cuda @@ -94,7 +112,7 @@ def sageattn_wrapper( # return o # except ImportError: -# sageattn = sageattn_qk_int8_pv_fp8_window_cuda +# sageattn2 = sageattn_qk_int8_pv_fp8_window_cuda @torch.compiler.disable() def sdpa_wrapper( @@ -124,21 +142,28 @@ def get_attention_modes(): ret.append("xformers") if sageattn_varlen_wrapper != None: ret.append("sage") - if sageattn != None and version("sageattention").startswith("2") : + if sageattn2 != None and version("sageattention").startswith("2") : 
ret.append("sage2") + if sageattn3 != None: # and version("sageattention").startswith("3") : + ret.append("sage3") return ret def get_supported_attention_modes(): ret = get_attention_modes() + major, minor = torch.cuda.get_device_capability() + if major < 10: + if "sage3" in ret: + ret.remove("sage3") + if not sage2_supported: if "sage2" in ret: ret.remove("sage2") - major, minor = torch.cuda.get_device_capability() if major < 7: if "sage" in ret: ret.remove("sage") + return ret __all__ = [ @@ -201,7 +226,7 @@ def pay_attention( from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG - if b > 1 and k_lens != None and attn in ("sage2", "sdpa"): + if b > 1 and k_lens != None and attn in ("sage2", "sage3", "sdpa"): assert attention_mask == None # Poor's man var k len attention assert q_lens == None @@ -234,7 +259,7 @@ def pay_attention( q_chunks, k_chunks, v_chunks = None, None, None o = torch.cat(o, dim = 0) return o - elif (q_lens != None or k_lens != None) and attn in ("sage2", "sdpa"): + elif (q_lens != None or k_lens != None) and attn in ("sage2", "sage3", "sdpa"): assert b == 1 szq = q_lens[0].item() if q_lens != None else lq szk = k_lens[0].item() if k_lens != None else lk @@ -284,13 +309,19 @@ def pay_attention( max_seqlen_q=lq, max_seqlen_kv=lk, ).unflatten(0, (b, lq)) + elif attn=="sage3": + import math + if cross_attn or True: + qkv_list = [q,k,v] + del q,k,v + x = sageattn3_wrapper(qkv_list, lq) elif attn=="sage2": import math if cross_attn or True: qkv_list = [q,k,v] del q,k,v - x = sageattn_wrapper(qkv_list, lq) #.unsqueeze(0) + x = sageattn2_wrapper(qkv_list, lq) #.unsqueeze(0) # else: # layer = offload.shared_state["layer"] # embed_sizes = offload.shared_state["embed_sizes"] diff --git a/shared/convert/convert_diffusers_to_flux.py b/shared/convert/convert_diffusers_to_flux.py new file mode 100644 index 0000000..608b176 --- /dev/null +++ b/shared/convert/convert_diffusers_to_flux.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python3 +""" +Convert a Flux model from Diffusers (folder or single-file) into the original +single-file Flux transformer checkpoint used by Black Forest Labs / ComfyUI. 
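+
+Unless --fp8 or --fp8-scaled is requested, weights are written out in bfloat16.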
+ +Input : /path/to/diffusers (root or .../transformer) OR /path/to/*.safetensors (single file) +Output : /path/to/flux1-your-model.safetensors (transformer only) + +Usage: + python diffusers_to_flux_transformer.py /path/to/diffusers /out/flux1-dev.safetensors + python diffusers_to_flux_transformer.py /path/to/diffusion_pytorch_model.safetensors /out/flux1-dev.safetensors + # optional quantization: + # --fp8 (float8_e4m3fn, simple) + # --fp8-scaled (scaled float8 for 2D weights; adds .scale_weight tensors) +""" + +import argparse +import json +from pathlib import Path +from collections import OrderedDict + +import torch +from safetensors import safe_open +import safetensors.torch +from tqdm import tqdm + + +def parse_args(): + ap = argparse.ArgumentParser() + ap.add_argument("diffusers_path", type=str, + help="Path to Diffusers checkpoint folder OR a single .safetensors file.") + ap.add_argument("output_path", type=str, + help="Output .safetensors path for the Flux transformer.") + ap.add_argument("--fp8", action="store_true", + help="Experimental: write weights as float8_e4m3fn via stochastic rounding (transformer only).") + ap.add_argument("--fp8-scaled", action="store_true", + help="Experimental: scaled float8_e4m3fn for 2D weight tensors; adds .scale_weight tensors.") + return ap.parse_args() + + +# Mapping from original Flux keys -> list of Diffusers keys (per block where applicable). +DIFFUSERS_MAP = { + # global embeds + "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], + "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], + "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], + "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], + + "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], + "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], + "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], + "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], + + "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], + "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], + "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], + "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], + + "txt_in.weight": ["context_embedder.weight"], + "txt_in.bias": ["context_embedder.bias"], + "img_in.weight": ["x_embedder.weight"], + "img_in.bias": ["x_embedder.bias"], + + # dual-stream (image/text) blocks + "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], + "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], + "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], + "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], + + "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], + "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], + "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], + "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], + + "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], + "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], + 
"double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], + "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], + + "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], + "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], + "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], + "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], + + "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], + "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], + "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], + "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], + + "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], + "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], + "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], + "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], + + # single-stream blocks + "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], + "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], + "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], + "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], + "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], + "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().linear2.bias": ["proj_out.bias"], + + # final + "final_layer.linear.weight": ["proj_out.weight"], + "final_layer.linear.bias": ["proj_out.bias"], + # these two are built from norm_out.linear.{weight,bias} by swapping [shift,scale] -> [scale,shift] + "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], + "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], +} + + +class DiffusersSource: + """ + Uniform interface over: + 1) Folder with index JSON + shards + 2) Folder with exactly one .safetensors (no index) + 3) Single .safetensors file + Provides .has(key), .get(key)->Tensor, .base_keys (keys with 'model.' stripped for scanning) + """ + + POSSIBLE_PREFIXES = ["", "model."] # try in this order + + def __init__(self, path: Path): + p = Path(path) + if p.is_dir(): + # use 'transformer' subfolder if present + if (p / "transformer").is_dir(): + p = p / "transformer" + self._init_from_dir(p) + elif p.is_file() and p.suffix == ".safetensors": + self._init_from_single_file(p) + else: + raise FileNotFoundError(f"Invalid path: {p}") + + # ---------- common helpers ---------- + + @staticmethod + def _strip_prefix(k: str) -> str: + return k[6:] if k.startswith("model.") else k + + def _resolve(self, want: str): + """ + Return the actual stored key matching `want` by trying known prefixes. + """ + for pref in self.POSSIBLE_PREFIXES: + k = pref + want + if k in self._all_keys: + return k + return None + + def has(self, want: str) -> bool: + return self._resolve(want) is not None + + def get(self, want: str) -> torch.Tensor: + real_key = self._resolve(want) + if real_key is None: + raise KeyError(f"Missing key: {want}") + return self._get_by_real_key(real_key).to("cpu") + + @property + def base_keys(self): + # keys without 'model.' 
prefix for scanning + return [self._strip_prefix(k) for k in self._all_keys] + + # ---------- modes ---------- + + def _init_from_single_file(self, file_path: Path): + self._mode = "single" + self._file = file_path + self._handle = safe_open(file_path, framework="pt", device="cpu") + self._all_keys = list(self._handle.keys()) + + def _get_by_real_key(real_key: str): + return self._handle.get_tensor(real_key) + + self._get_by_real_key = _get_by_real_key + + def _init_from_dir(self, dpath: Path): + index_json = dpath / "diffusion_pytorch_model.safetensors.index.json" + if index_json.exists(): + with open(index_json, "r", encoding="utf-8") as f: + index = json.load(f) + weight_map = index["weight_map"] # full mapping + self._mode = "sharded" + self._dpath = dpath + self._weight_map = {k: dpath / v for k, v in weight_map.items()} + self._all_keys = list(self._weight_map.keys()) + self._open_handles = {} + + def _get_by_real_key(real_key: str): + fpath = self._weight_map[real_key] + h = self._open_handles.get(fpath) + if h is None: + h = safe_open(fpath, framework="pt", device="cpu") + self._open_handles[fpath] = h + return h.get_tensor(real_key) + + self._get_by_real_key = _get_by_real_key + return + + # no index: try exactly one safetensors in folder + files = sorted(dpath.glob("*.safetensors")) + if len(files) != 1: + raise FileNotFoundError( + f"No index found and {dpath} does not contain exactly one .safetensors file." + ) + self._init_from_single_file(files[0]) + + +def main(): + args = parse_args() + src = DiffusersSource(Path(args.diffusers_path)) + + # Count blocks by scanning base keys (with any 'model.' prefix removed) + num_dual = 0 + num_single = 0 + for k in src.base_keys: + if k.startswith("transformer_blocks."): + try: + i = int(k.split(".")[1]) + num_dual = max(num_dual, i + 1) + except Exception: + pass + elif k.startswith("single_transformer_blocks."): + try: + i = int(k.split(".")[1]) + num_single = max(num_single, i + 1) + except Exception: + pass + print(f"Found {num_dual} dual-stream blocks, {num_single} single-stream blocks") + + # Swap [shift, scale] -> [scale, shift] (weights are concatenated along dim=0) + def swap_scale_shift(vec: torch.Tensor) -> torch.Tensor: + shift, scale = vec.chunk(2, dim=0) + return torch.cat([scale, shift], dim=0) + + orig = {} + + # Per-block (dual) + for b in range(num_dual): + prefix = f"transformer_blocks.{b}." + for okey, dvals in DIFFUSERS_MAP.items(): + if not okey.startswith("double_blocks."): + continue + dkeys = [prefix + v for v in dvals] + if not all(src.has(k) for k in dkeys): + continue + if len(dkeys) == 1: + orig[okey.replace("()", str(b))] = src.get(dkeys[0]) + else: + orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0) + + # Per-block (single) + for b in range(num_single): + prefix = f"single_transformer_blocks.{b}." 
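+        # single_blocks.().linear1.{weight,bias} are rebuilt below by concatenating the
+        # Diffusers attn.to_q / attn.to_k / attn.to_v / proj_mlp tensors along dim 0,
+        # mirroring the fused attention+MLP input projection of the original layout
+        # (handled by the multi-key concat branch inside the loop).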
+ for okey, dvals in DIFFUSERS_MAP.items(): + if not okey.startswith("single_blocks."): + continue + dkeys = [prefix + v for v in dvals] + if not all(src.has(k) for k in dkeys): + continue + if len(dkeys) == 1: + orig[okey.replace("()", str(b))] = src.get(dkeys[0]) + else: + orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0) + + # Globals (non-block) + for okey, dvals in DIFFUSERS_MAP.items(): + if okey.startswith(("double_blocks.", "single_blocks.")): + continue + dkeys = dvals + if not all(src.has(k) for k in dkeys): + continue + if len(dkeys) == 1: + orig[okey] = src.get(dkeys[0]) + else: + orig[okey] = torch.cat([src.get(k) for k in dkeys], dim=0) + + # Fix final_layer.adaLN_modulation.1.{weight,bias} by swapping scale/shift halves + if "final_layer.adaLN_modulation.1.weight" in orig: + orig["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift( + orig["final_layer.adaLN_modulation.1.weight"] + ) + if "final_layer.adaLN_modulation.1.bias" in orig: + orig["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift( + orig["final_layer.adaLN_modulation.1.bias"] + ) + + # Optional FP8 variants (experimental; not required for ComfyUI/BFL) + if args.fp8 or args.fp8_scaled: + dtype = torch.float8_e4m3fn # noqa + minv, maxv = torch.finfo(dtype).min, torch.finfo(dtype).max + + def stochastic_round_to(t): + t = t.float().clamp(minv, maxv) + lower = torch.floor(t * 256) / 256 + upper = torch.ceil(t * 256) / 256 + prob = torch.where(upper != lower, (t - lower) / (upper - lower), torch.zeros_like(t)) + rnd = torch.rand_like(t) + out = torch.where(rnd < prob, upper, lower) + return out.to(dtype) + + def scale_to_8bit(weight, target_max=416.0): + absmax = weight.abs().max() + scale = absmax / target_max if absmax > 0 else torch.tensor(1.0) + scaled = (weight / scale).clamp(minv, maxv).to(dtype) + return scaled, scale + + scales = {} + for k in tqdm(list(orig.keys()), desc="Quantizing to fp8"): + t = orig[k] + if args.fp8: + orig[k] = stochastic_round_to(t) + else: + if k.endswith(".weight") and t.dim() == 2: + qt, s = scale_to_8bit(t) + orig[k] = qt + scales[k[:-len(".weight")] + ".scale_weight"] = s + else: + orig[k] = t.clamp(minv, maxv).to(dtype) + if args.fp8_scaled: + orig.update(scales) + orig["scaled_fp8"] = torch.tensor([], dtype=dtype) + else: + # Default: save in bfloat16 + for k in list(orig.keys()): + orig[k] = orig[k].to(torch.bfloat16).cpu() + + out_path = Path(args.output_path) + out_path.parent.mkdir(parents=True, exist_ok=True) + meta = OrderedDict() + meta["format"] = "pt" + meta["modelspec.date"] = __import__("datetime").date.today().strftime("%Y-%m-%d") + print(f"Saving transformer to: {out_path}") + safetensors.torch.save_file(orig, str(out_path), metadata=meta) + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/shared/gradio/gallery.py b/shared/gradio/gallery.py new file mode 100644 index 0000000..757da34 --- /dev/null +++ b/shared/gradio/gallery.py @@ -0,0 +1,532 @@ +from __future__ import annotations +import os, io, tempfile, mimetypes +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Literal + +import gradio as gr +import PIL +import time +from PIL import Image as PILImage + +FilePath = str +ImageLike = Union["PIL.Image.Image", Any] + +IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"} +VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".m4v", ".mpeg", ".mpg", ".ogv"} + +def get_state(state): + return state if isinstance(state, dict) else 
state.value + +def get_list( objs): + if objs is None: + return [] + return [ obj[0] if isinstance(obj, tuple) else obj for obj in objs] + +def record_last_action(st, last_action): + st["last_action"] = last_action + st["last_time"] = time.time() +class AdvancedMediaGallery: + def __init__( + self, + label: str = "Media", + *, + media_mode: Literal["image", "video"] = "image", + height = None, + columns: Union[int, Tuple[int, ...]] = 6, + show_label: bool = True, + initial: Optional[Sequence[Union[FilePath, ImageLike]]] = None, + elem_id: Optional[str] = None, + elem_classes: Optional[Sequence[str]] = ("adv-media-gallery",), + accept_filter: bool = True, # restrict Add-button dialog to allowed extensions + single_image_mode: bool = False, # start in single-image mode (Add replaces) + ): + assert media_mode in ("image", "video") + self.label = label + self.media_mode = media_mode + self.height = height + self.columns = columns + self.show_label = show_label + self.elem_id = elem_id + self.elem_classes = list(elem_classes) if elem_classes else None + self.accept_filter = accept_filter + + items = self._normalize_initial(initial or [], media_mode) + + # Components (filled on mount) + self.container: Optional[gr.Column] = None + self.gallery: Optional[gr.Gallery] = None + self.upload_btn: Optional[gr.UploadButton] = None + self.btn_remove: Optional[gr.Button] = None + self.btn_left: Optional[gr.Button] = None + self.btn_right: Optional[gr.Button] = None + self.btn_clear: Optional[gr.Button] = None + + # Single dict state + self.state: Optional[gr.State] = None + self._initial_state: Dict[str, Any] = { + "items": items, + "selected": (len(items) - 1) if items else 0, # None, + "single": bool(single_image_mode), + "mode": self.media_mode, + "last_action": "", + } + + # ---------------- helpers ---------------- + + def _normalize_initial(self, items: Sequence[Union[FilePath, ImageLike]], mode: str) -> List[Any]: + out: List[Any] = [] + if mode == "image": + for it in items: + p = self._ensure_image_item(it) + if p is not None: + out.append(p) + else: + for it in items: + if isinstance(item, tuple): item = item[0] + if isinstance(it, str) and self._is_video_path(it): + out.append(os.path.abspath(it)) + return out + + def _ensure_image_item(self, item: Union[FilePath, ImageLike]) -> Optional[Any]: + # Accept a path to an image, or a PIL.Image/np.ndarray -> save temp PNG and return its path + if isinstance(item, tuple): item = item[0] + if isinstance(item, str): + return os.path.abspath(item) if self._is_image_path(item) else None + if PILImage is None: + return None + try: + if isinstance(item, PILImage.Image): + img = item + else: + import numpy as np # type: ignore + if isinstance(item, np.ndarray): + img = PILImage.fromarray(item) + elif hasattr(item, "read"): + data = item.read() + img = PILImage.open(io.BytesIO(data)).convert("RGBA") + else: + return None + tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + img.save(tmp.name) + return tmp.name + except Exception: + return None + + @staticmethod + def _extract_path(obj: Any) -> Optional[str]: + # Try to get a filesystem path (for mode filtering); otherwise None. 
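+        # Gradio can hand back plain strings, pathlib.Path objects, dicts carrying
+        # "path"/"name" keys, or upload wrappers exposing .path / .name attributes;
+        # each case is probed in turn below, and anything else yields None.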
+ if isinstance(obj, str): + return obj + try: + import pathlib + if isinstance(obj, pathlib.Path): # type: ignore + return str(obj) + except Exception: + pass + if isinstance(obj, dict): + return obj.get("path") or obj.get("name") + for attr in ("path", "name"): + if hasattr(obj, attr): + try: + val = getattr(obj, attr) + if isinstance(val, str): + return val + except Exception: + pass + return None + + @staticmethod + def _is_image_path(p: str) -> bool: + ext = os.path.splitext(p)[1].lower() + if ext in IMAGE_EXTS: + return True + mt, _ = mimetypes.guess_type(p) + return bool(mt and mt.startswith("image/")) + + @staticmethod + def _is_video_path(p: str) -> bool: + ext = os.path.splitext(p)[1].lower() + if ext in VIDEO_EXTS: + return True + mt, _ = mimetypes.guess_type(p) + return bool(mt and mt.startswith("video/")) + + def _filter_items_by_mode(self, items: List[Any]) -> List[Any]: + # Enforce image-only or video-only collection regardless of how files were added. + out: List[Any] = [] + if self.media_mode == "image": + for it in items: + p = self._extract_path(it) + if p is None: + # No path: likely an image object added programmatically => keep + out.append(it) + elif self._is_image_path(p): + out.append(os.path.abspath(p)) + else: + for it in items: + p = self._extract_path(it) + if p is not None and self._is_video_path(p): + out.append(os.path.abspath(p)) + return out + + @staticmethod + def _concat_and_optionally_dedupe(cur: List[Any], add: List[Any]) -> List[Any]: + # Keep it simple: dedupe by path when available, else allow duplicates. + seen_paths = set() + def key(x: Any) -> Optional[str]: + if isinstance(x, str): return os.path.abspath(x) + try: + import pathlib + if isinstance(x, pathlib.Path): # type: ignore + return os.path.abspath(str(x)) + except Exception: + pass + if isinstance(x, dict): + p = x.get("path") or x.get("name") + return os.path.abspath(p) if isinstance(p, str) else None + for attr in ("path", "name"): + if hasattr(x, attr): + try: + v = getattr(x, attr) + return os.path.abspath(v) if isinstance(v, str) else None + except Exception: + pass + return None + + out: List[Any] = [] + for lst in (cur, add): + for it in lst: + k = key(it) + if k is None or k not in seen_paths: + out.append(it) + if k is not None: + seen_paths.add(k) + return out + + @staticmethod + def _paths_from_payload(payload: Any) -> List[Any]: + # Return as raw objects (paths/dicts/UploadedFile) to feed Gallery directly. 
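+        # e.g. None -> [], "a.png" -> ["a.png"], ["a.png", "b.png"] -> ["a.png", "b.png"]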
+ if payload is None: + return [] + if isinstance(payload, (list, tuple, set)): + return list(payload) + return [payload] + + # ---------------- event handlers ---------------- + + def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) : + # Mirror the selected index into state and the gallery (server-side selected_index) + + st = get_state(state) + last_time = st.get("last_time", None) + if last_time is not None and abs(time.time()- last_time)< 0.5: # crappy trick to detect if onselect is unwanted (buggy gallery) + # print(f"ignored:{time.time()}, real {st['selected']}") + return gr.update(selected_index=st["selected"]), st + + idx = None + if evt is not None and hasattr(evt, "index"): + ix = evt.index + if isinstance(ix, int): + idx = ix + elif isinstance(ix, (tuple, list)) and ix and isinstance(ix[0], int): + if isinstance(self.columns, int) and len(ix) >= 2: + idx = ix[0] * max(1, int(self.columns)) + ix[1] + else: + idx = ix[0] + n = len(get_list(gallery)) + sel = idx if (idx is not None and 0 <= idx < n) else None + # print(f"image selected evt index:{sel}/{evt.selected}") + st["selected"] = sel + return gr.update(), st + + def _on_upload(self, value: List[Any], state: Dict[str, Any]) : + # Fires when users upload via the Gallery itself. + # items_filtered = self._filter_items_by_mode(list(value or [])) + items_filtered = list(value or []) + st = get_state(state) + new_items = self._paths_from_payload(items_filtered) + st["items"] = new_items + new_sel = len(new_items) - 1 + st["selected"] = new_sel + record_last_action(st,"add") + return gr.update(selected_index=new_sel), st + + def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) : + # Fires when users add/drag/drop/delete via the Gallery itself. + # items_filtered = self._filter_items_by_mode(list(value or [])) + items_filtered = list(value or []) + st = get_state(state) + st["items"] = items_filtered + # Keep selection if still valid, else default to last + old_sel = st.get("selected", None) + if old_sel is None or not (0 <= old_sel < len(items_filtered)): + new_sel = (len(items_filtered) - 1) if items_filtered else None + else: + new_sel = old_sel + st["selected"] = new_sel + st["last_action"] ="gallery_change" + # print(f"gallery change: set sel {new_sel}") + return gr.update(selected_index=new_sel), st + + def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery): + """ + Insert added items right AFTER the currently selected index. + Keeps the same ordering as chosen in the file picker, dedupes by path, + and re-selects the last inserted item. 
+ """ + # New items (respect image/video mode) + # new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload)) + new_items = self._paths_from_payload(files_payload) + + st = get_state(state) + cur: List[Any] = get_list(gallery) + sel = st.get("selected", None) + if sel is None: + sel = (len(cur) -1) if len(cur)>0 else 0 + single = bool(st.get("single", False)) + + # Nothing to add: keep as-is + if not new_items: + return gr.update(value=cur, selected_index=st.get("selected")), st + + # Single-image mode: replace + if single: + st["items"] = [new_items[-1]] + st["selected"] = 0 + return gr.update(value=st["items"], selected_index=0), st + + # ---------- helpers ---------- + def key_of(it: Any) -> Optional[str]: + # Prefer class helper if present + if hasattr(self, "_extract_path"): + p = self._extract_path(it) # type: ignore + else: + p = it if isinstance(it, str) else None + if p is None and isinstance(it, dict): + p = it.get("path") or it.get("name") + if p is None and hasattr(it, "path"): + try: p = getattr(it, "path") + except Exception: p = None + if p is None and hasattr(it, "name"): + try: p = getattr(it, "name") + except Exception: p = None + return os.path.abspath(p) if isinstance(p, str) else None + + # Dedupe the incoming batch by path, preserve order + seen_new = set() + incoming: List[Any] = [] + for it in new_items: + k = key_of(it) + if k is None or k not in seen_new: + incoming.append(it) + if k is not None: + seen_new.add(k) + + insert_pos = min(sel, len(cur) -1) + cur_clean = cur + # Build final list and selection + merged = cur_clean[:insert_pos+1] + incoming + cur_clean[insert_pos+1:] + new_sel = insert_pos + len(incoming) # select the last inserted item + + st["items"] = merged + st["selected"] = new_sel + record_last_action(st,"add") + # print(f"gallery add: set sel {new_sel}") + return gr.update(value=merged, selected_index=new_sel), st + + def _on_remove(self, state: Dict[str, Any], gallery) : + st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None) + if sel is None or not (0 <= sel < len(items)): + return gr.update(value=items, selected_index=st.get("selected")), st + items.pop(sel) + if not items: + st["items"] = []; st["selected"] = None + return gr.update(value=[], selected_index=None), st + new_sel = min(sel, len(items) - 1) + st["items"] = items; st["selected"] = new_sel + record_last_action(st,"remove") + # print(f"gallery del: new sel {new_sel}") + return gr.update(value=items, selected_index=new_sel), st + + def _on_move(self, delta: int, state: Dict[str, Any], gallery) : + st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None) + if sel is None or not (0 <= sel < len(items)): + return gr.update(value=items, selected_index=sel), st + j = sel + delta + if j < 0 or j >= len(items): + return gr.update(value=items, selected_index=sel), st + items[sel], items[j] = items[j], items[sel] + st["items"] = items; st["selected"] = j + record_last_action(st,"move") + # print(f"gallery move: set sel {j}") + return gr.update(value=items, selected_index=j), st + + def _on_clear(self, state: Dict[str, Any]) : + st = {"items": [], "selected": None, "single": get_state(state).get("single", False), "mode": self.media_mode} + record_last_action(st,"clear") + # print(f"Clear all") + return gr.update(value=[], selected_index=None), st + + def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) : + st = get_state(state); st["single"] = bool(to_single) + items: List[Any] = 
list(st["items"]); sel = st.get("selected", None) + if st["single"]: + keep = items[sel] if (sel is not None and 0 <= sel < len(items)) else (items[-1] if items else None) + items = [keep] if keep is not None else [] + sel = 0 if items else None + st["items"] = items; st["selected"] = sel + + upload_update = gr.update(file_count=("single" if st["single"] else "multiple")) + left_update = gr.update(visible=not st["single"]) + right_update = gr.update(visible=not st["single"]) + clear_update = gr.update(visible=not st["single"]) + gallery_update= gr.update(value=items, selected_index=sel) + + return upload_update, left_update, right_update, clear_update, gallery_update, st + + # ---------------- build & wire ---------------- + + def mount(self, parent: Optional[gr.Blocks | gr.Group | gr.Row | gr.Column] = None, update_form = False): + if parent is not None: + with parent: + col = self._build_ui(update_form) + else: + col = self._build_ui(update_form) + if not update_form: + self._wire_events() + return col + + def _build_ui(self, update = False) -> gr.Column: + with gr.Column(elem_id=self.elem_id, elem_classes=self.elem_classes) as col: + self.container = col + + self.state = gr.State(dict(self._initial_state)) + + if update: + self.gallery = gr.update( + value=self._initial_state["items"], + selected_index=self._initial_state["selected"], # server-side selection + label=self.label, + show_label=self.show_label, + ) + else: + self.gallery = gr.Gallery( + value=self._initial_state["items"], + label=self.label, + height=self.height, + columns=self.columns, + show_label=self.show_label, + preview= True, + # type="pil", # very slow + file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS), + selected_index=self._initial_state["selected"], # server-side selection + ) + + # One-line controls + exts = sorted(IMAGE_EXTS if self.media_mode == "image" else VIDEO_EXTS) if self.accept_filter else None + with gr.Row(equal_height=True, elem_classes=["amg-controls"]): + self.upload_btn = gr.UploadButton( + "Set" if self._initial_state["single"] else "Add", + file_types=exts, + file_count=("single" if self._initial_state["single"] else "multiple"), + variant="primary", + size="sm", + min_width=1, + ) + self.btn_remove = gr.Button(" Remove ", size="sm", min_width=1) + self.btn_left = gr.Button("◀ Left", size="sm", visible=not self._initial_state["single"], min_width=1) + self.btn_right = gr.Button("Right ▶", size="sm", visible=not self._initial_state["single"], min_width=1) + self.btn_clear = gr.Button(" Clear ", variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1) + + return col + + def _wire_events(self): + # Selection: mirror into state and keep gallery.selected_index in sync + self.gallery.select( + self._on_select, + inputs=[self.state, self.gallery], + outputs=[self.gallery, self.state], + trigger_mode="always_last", + ) + + # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.) + self.gallery.upload( + self._on_upload, + inputs=[self.gallery, self.state], + outputs=[self.gallery, self.state], + trigger_mode="always_last", + ) + + # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.) 
+ self.gallery.upload( + self._on_gallery_change, + inputs=[self.gallery, self.state], + outputs=[self.gallery, self.state], + trigger_mode="always_last", + ) + + # Add via UploadButton + self.upload_btn.upload( + self._on_add, + inputs=[self.upload_btn, self.state, self.gallery], + outputs=[self.gallery, self.state], + trigger_mode="always_last", + ) + + # Remove selected + self.btn_remove.click( + self._on_remove, + inputs=[self.state, self.gallery], + outputs=[self.gallery, self.state], + trigger_mode="always_last", + ) + + # Reorder using selected index, keep same item selected + self.btn_left.click( + lambda st, gallery: self._on_move(-1, st, gallery), + inputs=[self.state, self.gallery], + outputs=[self.gallery, self.state], + trigger_mode="always_last", + ) + self.btn_right.click( + lambda st, gallery: self._on_move(+1, st, gallery), + inputs=[self.state, self.gallery], + outputs=[self.gallery, self.state], + trigger_mode="always_last", + ) + + # Clear all + self.btn_clear.click( + self._on_clear, + inputs=[self.state], + outputs=[self.gallery, self.state], + trigger_mode="always_last", + ) + + # ---------------- public API ---------------- + + def set_one_image_mode(self, enabled: bool = True): + """Toggle single-image mode at runtime.""" + return ( + self._on_toggle_single, + [gr.State(enabled), self.state], + [self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state], + ) + + def get_toggable_elements(self): + return [self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state] + +# import gradio as gr + +# with gr.Blocks() as demo: +# amg = AdvancedMediaGallery(media_mode="image", height=190, columns=8) +# amg.mount() +# g = amg.gallery +# # buttons to switch modes live (optional) +# def process(g): +# pass +# with gr.Row(): +# gr.Button("toto").click(process, g) +# gr.Button("ONE image").click(*amg.set_one_image_mode(True)) +# gr.Button("MULTI image").click(*amg.set_one_image_mode(False)) + +# demo.launch() diff --git a/shared/inpainting/__init__.py b/shared/inpainting/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/shared/inpainting/lanpaint.py b/shared/inpainting/lanpaint.py new file mode 100644 index 0000000..3165e7b --- /dev/null +++ b/shared/inpainting/lanpaint.py @@ -0,0 +1,240 @@ +import torch +from .utils import * +from functools import partial + +# Many thanks to the LanPaint team for this implementation (https://github.com/scraed/LanPaint/) + +def _pack_latents(latents): + batch_size, num_channels_latents, _, height, width = latents.shape + + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + +def _unpack_latents(latents, height, width, vae_scale_factor=8): + batch_size, num_patches, channels = latents.shape + + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + +class LanPaint(): + def __init__(self, NSteps = 5, Friction = 15, Lambda = 8, Beta = 1, StepSize = 0.15, IS_FLUX = True, IS_FLOW = False): + self.n_steps = NSteps + self.chara_lamb = Lambda + self.IS_FLUX = IS_FLUX + self.IS_FLOW = IS_FLOW + 
self.step_size = StepSize + self.friction = Friction + self.chara_beta = Beta + self.img_dim_size = None + def add_none_dims(self, array): + # Create a tuple with ':' for the first dimension and 'None' repeated num_nones times + index = (slice(None),) + (None,) * (self.img_dim_size-1) + return array[index] + def remove_none_dims(self, array): + # Create a tuple with ':' for the first dimension and 'None' repeated num_nones times + index = (slice(None),) + (0,) * (self.img_dim_size-1) + return array[index] + def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_image, noise, sigma, latent_mask, n_steps=None, height =720, width = 1280, vae_scale_factor = 8): + latent_image = _unpack_latents(latent_image, height=height, width=width, vae_scale_factor=vae_scale_factor) + noise = _unpack_latents(noise, height=height, width=width, vae_scale_factor=vae_scale_factor) + x = _unpack_latents(x, height=height, width=width, vae_scale_factor=vae_scale_factor) + latent_mask = _unpack_latents(latent_mask, height=height, width=width, vae_scale_factor=vae_scale_factor) + self.height = height + self.width = width + self.vae_scale_factor = vae_scale_factor + self.img_dim_size = len(x.shape) + self.latent_image = latent_image + self.noise = noise + if n_steps is None: + n_steps = self.n_steps + out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, self.IS_FLUX, self.IS_FLOW) + out = _pack_latents(out) + return out + def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, IS_FLUX, IS_FLOW): + if IS_FLUX: + cfg_BIG = 1.0 + + def double_denoise(latents, t): + latents = _pack_latents(latents) + noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale) + if noise_pred == None: return None, None + predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t) + predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor) + if true_cfg_scale == cfg_BIG: + predict_big = predict_std + else: + predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t) + predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor) + return predict_std, predict_big + + if len(sigma.shape) == 0: + sigma = torch.tensor([sigma.item()]) + latent_mask = 1 - latent_mask + if IS_FLUX or IS_FLOW: + Flow_t = sigma + abt = (1 - Flow_t)**2 / ((1 - Flow_t)**2 + Flow_t**2 ) + VE_Sigma = Flow_t / (1 - Flow_t) + #print("t", torch.mean( sigma ).item(), "VE_Sigma", torch.mean( VE_Sigma ).item()) + else: + VE_Sigma = sigma + abt = 1/( 1+VE_Sigma**2 ) + Flow_t = (1-abt)**0.5 / ( (1-abt)**0.5 + abt**0.5 ) + # VE_Sigma, abt, Flow_t = current_times + current_times = (VE_Sigma, abt, Flow_t) + + step_size = self.step_size * (1 - abt) + step_size = self.add_none_dims(step_size) + # self.inner_model.inner_model.scale_latent_inpaint returns variance exploding x_t values + # This is the replace step + # x = x * (1 - latent_mask) + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image)* latent_mask + + noisy_image = self.latent_image * (1.0 - sigma) + self.noise * sigma + x = x * (1 - latent_mask) + noisy_image * latent_mask + + if IS_FLUX or IS_FLOW: + x_t = x * ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 ) + else: + x_t = x / ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values + + ############ LanPaint Iterations Start ############### + # 
after noise_scaling, noise = latent_image + noise * sigma, which is x_t in the variance exploding diffusion model notation for the known region. + args = None + for i in range(n_steps): + score_func = partial( self.score_model, y = self.latent_image, mask = latent_mask, abt = self.add_none_dims(abt), sigma = self.add_none_dims(VE_Sigma), tflow = self.add_none_dims(Flow_t), denoise_func = double_denoise ) + if score_func is None: return None + x_t, args = self.langevin_dynamics(x_t, score_func , latent_mask, step_size , current_times, sigma_x = self.add_none_dims(self.sigma_x(abt)), sigma_y = self.add_none_dims(self.sigma_y(abt)), args = args) + if IS_FLUX or IS_FLOW: + x = x_t / ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 ) + else: + x = x_t * ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values + ############ LanPaint Iterations End ############### + # out is x_0 + # out, _ = self.inner_model(x, sigma, model_options=model_options, seed=seed) + # out = out * (1-latent_mask) + self.latent_image * latent_mask + # return out + return x + + def score_model(self, x_t, y, mask, abt, sigma, tflow, denoise_func): + + lamb = self.chara_lamb + if self.IS_FLUX or self.IS_FLOW: + # compute t for flow model, with a small epsilon compensating for numerical error. + x = x_t / ( abt**0.5 + (1-abt)**0.5 ) # switch to Gaussian flow matching + x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(tflow)) + if x_0 is None: return None + else: + x = x_t * ( 1+sigma**2 )**0.5 # switch to variance exploding + x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(sigma)) + if x_0 is None: return None + + score_x = -(x_t - x_0) + score_y = - (1 + lamb) * ( x_t - y ) + lamb * (x_t - x_0_BIG) + return score_x * (1 - mask) + score_y * mask + def sigma_x(self, abt): + # the time scale for the x_t update + return abt**0 + def sigma_y(self, abt): + beta = self.chara_beta * abt ** 0 + return beta + + def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=1, sigma_y=0, args=None): + # prepare the step size and time parameters + with torch.autocast(device_type=x_t.device.type, dtype=torch.float32): + step_sizes = self.prepare_step_size(current_times, step_size, sigma_x, sigma_y) + sigma, abt, dtx, dty, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y = step_sizes + # print('mask',mask.device) + if torch.mean(dtx) <= 0.: + return x_t, args + # ------------------------------------------------------------------------- + # Compute the Langevin dynamics update in variance perserving notation + # ------------------------------------------------------------------------- + #x0 = self.x0_evalutation(x_t, score, sigma, args) + #C = abt**0.5 * x0 / (1-abt) + A = A_x * (1-mask) + A_y * mask + D = D_x * (1-mask) + D_y * mask + dt = dtx * (1-mask) + dty * mask + Gamma = Gamma_x * (1-mask) + Gamma_y * mask + + + def Coef_C(x_t): + x0 = self.x0_evalutation(x_t, score, sigma, args) + C = (abt**0.5 * x0 - x_t )/ (1-abt) + A * x_t + return C + def advance_time(x_t, v, dt, Gamma, A, C, D): + dtype = x_t.dtype + with torch.autocast(device_type=x_t.device.type, dtype=torch.float32): + osc = StochasticHarmonicOscillator(Gamma, A, C, D ) + x_t, v = osc.dynamics(x_t, v, dt ) + x_t = x_t.to(dtype) + v = v.to(dtype) + return x_t, v + if args is None: + #v = torch.zeros_like(x_t) + v = None + C = Coef_C(x_t) + #print(torch.squeeze(dtx), torch.squeeze(dty)) + x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D) + else: + v, C = args + + x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D) + 
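+            # After the first half step, re-evaluate the score-dependent drift C at the
+            # new x_t, kick the velocity by Gamma**0.5 * (C_new - C) * dt, then run the
+            # second half step; (v, C_new) are returned so the next call reuses them via args.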
+ C_new = Coef_C(x_t) + v = v + Gamma**0.5 * ( C_new - C) *dt + + x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D) + + C = C_new + + return x_t, (v, C) + + def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y): + # ------------------------------------------------------------------------- + # Unpack current times parameters (sigma and abt) + sigma, abt, flow_t = current_times + sigma = self.add_none_dims(sigma) + abt = self.add_none_dims(abt) + # Compute time step (dtx, dty) for x and y branches. + dtx = 2 * step_size * sigma_x + dty = 2 * step_size * sigma_y + + # ------------------------------------------------------------------------- + # Define friction parameter Gamma_hat for each branch. + # Using dtx**0 provides a tensor of the proper device/dtype. + + Gamma_hat_x = self.friction **2 * self.step_size * sigma_x / 0.1 * sigma**0 + Gamma_hat_y = self.friction **2 * self.step_size * sigma_y / 0.1 * sigma**0 + #print("Gamma_hat_x", torch.mean(Gamma_hat_x).item(), "Gamma_hat_y", torch.mean(Gamma_hat_y).item()) + # adjust dt to match denoise-addnoise steps sizes + Gamma_hat_x /= 2. + Gamma_hat_y /= 2. + A_t_x = (1) / ( 1 - abt ) * dtx / 2 + A_t_y = (1+self.chara_lamb) / ( 1 - abt ) * dty / 2 + + + A_x = A_t_x / (dtx/2) + A_y = A_t_y / (dty/2) + Gamma_x = Gamma_hat_x / (dtx/2) + Gamma_y = Gamma_hat_y / (dty/2) + + #D_x = (2 * (1 + sigma**2) )**0.5 + #D_y = (2 * (1 + sigma**2) )**0.5 + D_x = (2 * abt**0 )**0.5 + D_y = (2 * abt**0 )**0.5 + return sigma, abt, dtx/2, dty/2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y + + + + def x0_evalutation(self, x_t, score, sigma, args): + x0 = x_t + score(x_t) + return x0 \ No newline at end of file diff --git a/shared/inpainting/utils.py b/shared/inpainting/utils.py new file mode 100644 index 0000000..c017ab0 --- /dev/null +++ b/shared/inpainting/utils.py @@ -0,0 +1,301 @@ +import torch +def epxm1_x(x): + # Compute the (exp(x) - 1) / x term with a small value to avoid division by zero. + result = torch.special.expm1(x) / x + # replace NaN or inf values with 0 + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + mask = torch.abs(x) < 1e-2 + result = torch.where(mask, 1 + x/2. + x**2 / 6., result) + return result +def epxm1mx_x2(x): + # Compute the (exp(x) - 1 - x) / x**2 term with a small value to avoid division by zero. + result = (torch.special.expm1(x) - x) / x**2 + # replace NaN or inf values with 0 + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + mask = torch.abs(x**2) < 1e-2 + result = torch.where(mask, 1/2. + x/6 + x**2 / 24 + x**3 / 120, result) + return result + +def expm1mxmhx2_x3(x): + # Compute the (exp(x) - 1 - x - x**2 / 2) / x**3 term with a small value to avoid division by zero. 
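+    # Same pattern as epxm1_x / epxm1mx_x2 above: evaluate directly, zero out any
+    # non-finite entries, then switch to a short Taylor expansion for small |x|.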
+ result = (torch.special.expm1(x) - x - x**2 / 2) / x**3 + # replace NaN or inf values with 0 + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + mask = torch.abs(x**3) < 1e-2 + result = torch.where(mask, 1/6 + x/24 + x**2 / 120 + x**3 / 720 + x**4 / 5040, result) + return result + +def exp_1mcosh_GD(gamma_t, delta): + """ + Compute e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ ) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + # Main computation + is_positive = delta > 0 + sqrt_abs_delta = torch.sqrt(torch.abs(delta)) + gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta + numerator_pos = torch.exp(-gamma_t) - (torch.exp(gamma_t * (sqrt_abs_delta - 1)) + torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2 + numerator_neg = torch.exp(-gamma_t) * ( 1 - torch.cos(gamma_t * sqrt_abs_delta ) ) + numerator = torch.where(is_positive, numerator_pos, numerator_neg) + result = numerator / (delta * gamma_t**2 ) + # Handle NaN/inf cases + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + # Handle numerical instability for small delta + mask = torch.abs(gamma_t_sqrt_delta**2) < 5e-2 + taylor = ( -0.5 - gamma_t**2 / 24 * delta - gamma_t**4 / 720 * delta**2 ) * torch.exp(-gamma_t) + result = torch.where(mask, taylor, result) + return result + +def exp_sinh_GsqrtD(gamma_t, delta): + """ + Compute e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + # Main computation + is_positive = delta > 0 + sqrt_abs_delta = torch.sqrt(torch.abs(delta)) + gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta + numerator_pos = (torch.exp(gamma_t * (sqrt_abs_delta - 1)) - torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2 + denominator_pos = gamma_t_sqrt_delta + result_pos = numerator_pos / gamma_t_sqrt_delta + result_pos = torch.where(torch.isfinite(result_pos), result_pos, torch.zeros_like(result_pos)) + + # Taylor expansion for small gamma_t_sqrt_delta + mask = torch.abs(gamma_t_sqrt_delta) < 1e-2 + taylor = ( 1 + gamma_t**2 / 6 * delta + gamma_t**4 / 120 * delta**2 ) * torch.exp(-gamma_t) + result_pos = torch.where(mask, taylor, result_pos) + + # Handle negative delta + result_neg = torch.exp(-gamma_t) * torch.special.sinc(gamma_t_sqrt_delta/torch.pi) + result = torch.where(is_positive, result_pos, result_neg) + return result + +def exp_cosh(gamma_t, delta): + """ + Compute e^(-Γt) * cosh(Γt√Δ) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + exp_1mcosh_GD_result = exp_1mcosh_GD(gamma_t, delta) # e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ ) + result = torch.exp(-gamma_t) - gamma_t**2 * delta * exp_1mcosh_GD_result + return result +def exp_sinh_sqrtD(gamma_t, delta): + """ + Compute e^(-Γt) * sinh(Γt√Δ) / √Δ + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + Returns: + Result of the computation with numerical stability handling + """ + exp_sinh_GsqrtD_result = exp_sinh_GsqrtD(gamma_t, delta) # e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ) + result = gamma_t * exp_sinh_GsqrtD_result + return result + + + +def zeta1(gamma_t, delta): + # Compute hyperbolic terms and exponential + 
half_gamma_t = gamma_t / 2 + exp_cosh_term = exp_cosh(half_gamma_t, delta) + exp_sinh_term = exp_sinh_sqrtD(half_gamma_t, delta) + + + # Main computation + numerator = 1 - (exp_cosh_term + exp_sinh_term) + denominator = gamma_t * (1 - delta) / 4 + result = 1 - numerator / denominator + + # Handle numerical instability + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + + # Taylor expansion for small x (similar to your epxm1Dx approach) + mask = torch.abs(denominator) < 5e-3 + term1 = epxm1_x(-gamma_t) + term2 = epxm1mx_x2(-gamma_t) + term3 = expm1mxmhx2_x3(-gamma_t) + taylor = term1 + (1/2.+ term1-3*term2)*denominator + (-1/6. + term1/2 - 4 * term2 + 10 * term3) * denominator**2 + result = torch.where(mask, taylor, result) + + return result + +def exp_cosh_minus_terms(gamma_t, delta): + """ + Compute E^(-tΓ) * (Cosh[tΓ] - 1 - (Cosh[tΓ√Δ] - 1)/Δ) / (tΓ(1 - Δ)) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + exp_term = torch.exp(-gamma_t) + # Compute individual terms + exp_cosh_term = exp_cosh(gamma_t, gamma_t**0) - exp_term # E^(-tΓ) (Cosh[tΓ] - 1) term + exp_cosh_delta_term = - gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) # E^(-tΓ) (Cosh[tΓ√Δ] - 1)/Δ term + + #exp_1mcosh_GD e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ ) + # Main computation + numerator = exp_cosh_term - exp_cosh_delta_term + denominator = gamma_t * (1 - delta) + + result = numerator / denominator + + # Handle numerical instability + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + + # Taylor expansion for small gamma_t and delta near 1 + mask = (torch.abs(denominator) < 1e-1) + exp_1mcosh_GD_term = exp_1mcosh_GD(gamma_t, delta**0) + taylor = ( + gamma_t*exp_1mcosh_GD_term + 0.5 * gamma_t * exp_sinh_GsqrtD(gamma_t, delta**0) + - denominator / 4 * ( 0.5 * exp_cosh(gamma_t, delta**0) - 4 * exp_1mcosh_GD_term - 5 /2 * exp_sinh_GsqrtD(gamma_t, delta**0) ) + ) + result = torch.where(mask, taylor, result) + + return result + + +def zeta2(gamma_t, delta): + half_gamma_t = gamma_t / 2 + return exp_sinh_GsqrtD(half_gamma_t, delta) + +def sig11(gamma_t, delta): + return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta) + + +def Zcoefs(gamma_t, delta): + Zeta1 = zeta1(gamma_t, delta) + Zeta2 = zeta2(gamma_t, delta) + + sq_total = 1 - Zeta1 + gamma_t * (delta - 1) * (Zeta1 - 1)**2 / 8 + amplitude = torch.sqrt(sq_total) + Zcoef1 = ( gamma_t**0.5 * Zeta2 / 2 **0.5 ) / amplitude + Zcoef2 = Zcoef1 * gamma_t *( - 2 * exp_1mcosh_GD(gamma_t, delta) / sig11(gamma_t, delta) ) ** 0.5 + #cterm = exp_cosh_minus_terms(gamma_t, delta) + #sterm = exp_sinh_sqrtD(gamma_t, delta**0) + exp_sinh_sqrtD(gamma_t, delta) + #Zcoef3 = 2 * torch.sqrt( cterm / ( gamma_t * (1 - delta) * cterm + sterm ) ) + Zcoef3 = torch.sqrt( torch.maximum(1 - Zcoef1**2 - Zcoef2**2, sq_total.new_zeros(sq_total.shape)) ) + + return Zcoef1 * amplitude, Zcoef2 * amplitude, Zcoef3 * amplitude, amplitude + +def Zcoefs_asymp(gamma_t, delta): + A_t = (gamma_t * (1 - delta) )/4 + return epxm1_x(- 2 * A_t) + +class StochasticHarmonicOscillator: + """ + Simulates a stochastic harmonic oscillator governed by the equations: + dy(t) = q(t) dt + dq(t) = -Γ A y(t) dt + Γ C dt + Γ D dw(t) - Γ q(t) dt + + Also define v(t) = q(t) / √Γ, which is numerically more stable. 
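+
+    Given y(0) and v(0) = q(0)/√Γ, dynamics() draws (y(t), v(t)) from the bivariate
+    Gaussian implied by this linear SDE: the means are built from the zeta terms and
+    the covariance is sampled through a hand-built, clamped Cholesky factor.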
+ + Where: + y(t) - Position variable + q(t) - Velocity variable + Γ - Damping coefficient + A - Harmonic potential strength + C - Constant force term + D - Noise amplitude + dw(t) - Wiener process (Brownian motion) + """ + def __init__(self, Gamma, A, C, D): + self.Gamma = Gamma + self.A = A + self.C = C + self.D = D + self.Delta = 1 - 4 * A / Gamma + def sig11(self, gamma_t, delta): + return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta) + def sig22(self, gamma_t, delta): + return 1- zeta1(2*gamma_t, delta) + 2 * gamma_t * exp_1mcosh_GD(gamma_t, delta) + def dynamics(self, y0, v0, t): + """ + Calculates the position and velocity variables at time t. + + Parameters: + y0 (float): Initial position + v0 (float): Initial velocity v(0) = q(0) / √Γ + t (float): Time at which to evaluate the dynamics + Returns: + tuple: (y(t), v(t)) + """ + + dummyzero = y0.new_zeros(1) # convert scalar to tensor with same device and dtype as y0 + Delta = self.Delta + dummyzero + Gamma_hat = self.Gamma * t + dummyzero + A = self.A + dummyzero + C = self.C + dummyzero + D = self.D + dummyzero + Gamma = self.Gamma + dummyzero + zeta_1 = zeta1( Gamma_hat, Delta) + zeta_2 = zeta2( Gamma_hat, Delta) + EE = 1 - Gamma_hat * zeta_2 + + if v0 is None: + v0 = torch.randn_like(y0) * D / 2 ** 0.5 + #v0 = (C - A * y0)/Gamma**0.5 + + # Calculate mean position and velocity + term1 = (1 - zeta_1) * (C * t - A * t * y0) + zeta_2 * (Gamma ** 0.5) * v0 * t + y_mean = term1 + y0 + v_mean = (1 - EE)*(C - A * y0) / (Gamma ** 0.5) + (EE - A * t * (1 - zeta_1)) * v0 + + cov_yy = D**2 * t * self.sig22(Gamma_hat, Delta) + cov_vv = D**2 * self.sig11(Gamma_hat, Delta) / 2 + cov_yv = (zeta2(Gamma_hat, Delta) * Gamma_hat * D ) **2 / 2 / (Gamma ** 0.5) + + # sample new position and velocity with multivariate normal distribution + + batch_shape = y0.shape + cov_matrix = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype) + cov_matrix[..., 0, 0] = cov_yy + cov_matrix[..., 0, 1] = cov_yv + cov_matrix[..., 1, 0] = cov_yv # symmetric + cov_matrix[..., 1, 1] = cov_vv + + + + # Compute the Cholesky decomposition to get scale_tril + #scale_tril = torch.linalg.cholesky(cov_matrix) + scale_tril = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype) + tol = 1e-8 + cov_yy = torch.clamp( cov_yy, min = tol ) + sd_yy = torch.sqrt( cov_yy ) + inv_sd_yy = 1/(sd_yy) + + scale_tril[..., 0, 0] = sd_yy + scale_tril[..., 0, 1] = 0. + scale_tril[..., 1, 0] = cov_yv * inv_sd_yy + scale_tril[..., 1, 1] = torch.clamp( cov_vv - cov_yv**2 / cov_yy, min = tol ) ** 0.5 + # check if it matches torch.linalg. 
+ #assert torch.allclose(torch.linalg.cholesky(cov_matrix), scale_tril, atol = 1e-4, rtol = 1e-4 ) + # Sample correlated noise from multivariate normal + mean = torch.zeros(*batch_shape, 2, device=y0.device, dtype=y0.dtype) + mean[..., 0] = y_mean + mean[..., 1] = v_mean + new_yv = torch.distributions.MultivariateNormal( + loc=mean, + scale_tril=scale_tril + ).sample() + + return new_yv[...,0], new_yv[...,1] \ No newline at end of file diff --git a/shared/utils/audio_video.py b/shared/utils/audio_video.py index b24530d..224cf34 100644 --- a/shared/utils/audio_video.py +++ b/shared/utils/audio_video.py @@ -232,6 +232,9 @@ def save_video(tensor, retry=5): """Save tensor as video with configurable codec and container options.""" + if torch.is_tensor(tensor) and len(tensor.shape) == 4: + tensor = tensor.unsqueeze(0) + suffix = f'.{container}' cache_file = osp.join('/tmp', rand_name(suffix=suffix)) if save_file is None else save_file if not cache_file.endswith(suffix): diff --git a/shared/utils/download.py b/shared/utils/download.py new file mode 100644 index 0000000..ed035c0 --- /dev/null +++ b/shared/utils/download.py @@ -0,0 +1,110 @@ +import sys, time + +# Global variables to track download progress +_start_time = None +_last_time = None +_last_downloaded = 0 +_speed_history = [] +_update_interval = 0.5 # Update speed every 0.5 seconds + +def progress_hook(block_num, block_size, total_size, filename=None): + """ + Simple progress bar hook for urlretrieve + + Args: + block_num: Number of blocks downloaded so far + block_size: Size of each block in bytes + total_size: Total size of the file in bytes + filename: Name of the file being downloaded (optional) + """ + global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval + + current_time = time.time() + downloaded = block_num * block_size + + # Initialize timing on first call + if _start_time is None or block_num == 0: + _start_time = current_time + _last_time = current_time + _last_downloaded = 0 + _speed_history = [] + + # Calculate download speed only at specified intervals + speed = 0 + if current_time - _last_time >= _update_interval: + if _last_time > 0: + current_speed = (downloaded - _last_downloaded) / (current_time - _last_time) + _speed_history.append(current_speed) + # Keep only last 5 speed measurements for smoothing + if len(_speed_history) > 5: + _speed_history.pop(0) + # Average the recent speeds for smoother display + speed = sum(_speed_history) / len(_speed_history) + + _last_time = current_time + _last_downloaded = downloaded + elif _speed_history: + # Use the last calculated average speed + speed = sum(_speed_history) / len(_speed_history) + # Format file sizes and speed + def format_bytes(bytes_val): + for unit in ['B', 'KB', 'MB', 'GB']: + if bytes_val < 1024: + return f"{bytes_val:.1f}{unit}" + bytes_val /= 1024 + return f"{bytes_val:.1f}TB" + + file_display = filename if filename else "Unknown file" + + if total_size <= 0: + # If total size is unknown, show downloaded bytes + speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" + line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}" + # Clear any trailing characters by padding with spaces + sys.stdout.write(line.ljust(80)) + sys.stdout.flush() + return + + downloaded = block_num * block_size + percent = min(100, (downloaded / total_size) * 100) + + # Create progress bar (40 characters wide to leave room for other info) + bar_length = 40 + filled = int(bar_length * percent / 100) + bar = '█' * filled + '░' * (bar_length - 
filled) + + # Format file sizes and speed + def format_bytes(bytes_val): + for unit in ['B', 'KB', 'MB', 'GB']: + if bytes_val < 1024: + return f"{bytes_val:.1f}{unit}" + bytes_val /= 1024 + return f"{bytes_val:.1f}TB" + + speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" + + # Display progress with filename first + line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}" + # Clear any trailing characters by padding with spaces + sys.stdout.write(line.ljust(100)) + sys.stdout.flush() + + # Print newline when complete + if percent >= 100: + print() + +# Wrapper function to include filename in progress hook +def create_progress_hook(filename): + """Creates a progress hook with the filename included""" + global _start_time, _last_time, _last_downloaded, _speed_history + # Reset timing variables for new download + _start_time = None + _last_time = None + _last_downloaded = 0 + _speed_history = [] + + def hook(block_num, block_size, total_size): + return progress_hook(block_num, block_size, total_size, filename) + return hook + + diff --git a/shared/utils/loras_mutipliers.py b/shared/utils/loras_mutipliers.py index 58cc9a9..e0cec6a 100644 --- a/shared/utils/loras_mutipliers.py +++ b/shared/utils/loras_mutipliers.py @@ -99,7 +99,7 @@ def parse_loras_multipliers(loras_multipliers, nb_loras, num_inference_steps, me return loras_list_mult_choices_nums, slists_dict, "" -def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_step = None, phase_switch_step2 = None ): +def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_step = None, phase_switch_step2 = None): from mmgp import offload sz = len(slists_dict["phase1"]) slists = [ expand_slist(slists_dict, i, num_inference_steps, phase_switch_step, phase_switch_step2 ) for i in range(sz) ] @@ -108,7 +108,8 @@ def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_st -def get_model_switch_steps(timesteps, total_num_steps, guide_phases, model_switch_phase, switch_threshold, switch2_threshold ): +def get_model_switch_steps(timesteps, guide_phases, model_switch_phase, switch_threshold, switch2_threshold ): + total_num_steps = len(timesteps) model_switch_step = model_switch_step2 = None for i, t in enumerate(timesteps): if guide_phases >=2 and model_switch_step is None and t <= switch_threshold: model_switch_step = i diff --git a/shared/utils/qwen_vl_utils.py b/shared/utils/qwen_vl_utils.py index 3c682e6..3fe02cc 100644 --- a/shared/utils/qwen_vl_utils.py +++ b/shared/utils/qwen_vl_utils.py @@ -9,7 +9,6 @@ import os import sys import time import warnings -from functools import lru_cache from io import BytesIO import requests @@ -257,7 +256,6 @@ VIDEO_READER_BACKENDS = { FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None) -@lru_cache(maxsize=1) def get_video_reader_backend() -> str: if FORCE_QWENVL_VIDEO_READER is not None: video_reader_backend = FORCE_QWENVL_VIDEO_READER diff --git a/shared/utils/utils.py b/shared/utils/utils.py index a55807a..00178aa 100644 --- a/shared/utils/utils.py +++ b/shared/utils/utils.py @@ -1,4 +1,3 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
import argparse import os import os.path as osp @@ -18,11 +17,12 @@ import os import tempfile import subprocess import json - +from functools import lru_cache +os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg") from PIL import Image - +video_info_cache = [] def seed_everything(seed: int): random.seed(seed) np.random.seed(seed) @@ -32,6 +32,14 @@ def seed_everything(seed: int): if torch.backends.mps.is_available(): torch.mps.manual_seed(seed) +def has_video_file_extension(filename): + extension = os.path.splitext(filename)[-1].lower() + return extension in [".mp4"] + +def has_image_file_extension(filename): + extension = os.path.splitext(filename)[-1].lower() + return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"] + def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ): import math @@ -77,7 +85,9 @@ def truncate_for_filesystem(s, max_bytes=255): else: r = m - 1 return s[:l] +@lru_cache(maxsize=100) def get_video_info(video_path): + global video_info_cache import cv2 cap = cv2.VideoCapture(video_path) @@ -92,7 +102,7 @@ def get_video_info(video_path): return fps, width, height, frame_count -def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, return_PIL = True) -> torch.Tensor: +def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, target_fps = None, return_PIL = True) -> torch.Tensor: """Extract nth frame from video as PyTorch tensor normalized to [-1, 1].""" cap = cv2.VideoCapture(file_name) @@ -100,7 +110,10 @@ def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool raise ValueError(f"Cannot open video: {file_name}") total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - + fps = round(cap.get(cv2.CAP_PROP_FPS)) + if target_fps is not None: + frame_no = round(target_fps * frame_no /fps) + # Handle out of bounds if frame_no >= total_frames or frame_no < 0: if return_last_if_missing: @@ -173,9 +186,15 @@ def remove_background(img, session=None): def convert_image_to_tensor(image): return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0) -def convert_tensor_to_image(t, frame_no = -1): - t = t[:, frame_no] if frame_no >= 0 else t - return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy()) +def convert_tensor_to_image(t, frame_no = 0, mask_levels = False): + if len(t.shape) == 4: + t = t[:, frame_no] + if t.shape[0]== 1: + t = t.expand(3,-1,-1) + if mask_levels: + return Image.fromarray(t.clone().mul_(255).permute(1,2,0).to(torch.uint8).cpu().numpy()) + else: + return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy()) def save_image(tensor_image, name, frame_no = -1): convert_tensor_to_image(tensor_image, frame_no).save(name) @@ -186,6 +205,14 @@ def get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_d frame_width = int(frame_width * (100 + outpainting_left + outpainting_right) / 100) return frame_height, frame_width +def rgb_bw_to_rgba_mask(img, thresh=127): + a = img.convert('L').point(lambda p: 255 if p > thresh else 0) # alpha + out = Image.new('RGBA', img.size, (255, 255, 255, 0)) # white, transparent + out.putalpha(a) # white where alpha=255 + return out + + + def get_outpainting_frame_location(final_height, final_width, outpainting_dims, block_size = 8): outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= 
outpainting_dims raw_height = int(final_height / ((100 + outpainting_top + outpainting_bottom) / 100)) @@ -205,30 +232,64 @@ def get_outpainting_frame_location(final_height, final_width, outpainting_dims if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width return height, width, margin_top, margin_left -def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16): - if fit_into_canvas == None: +def rescale_and_crop(img, w, h): + ow, oh = img.size + target_ratio = w / h + orig_ratio = ow / oh + + if orig_ratio > target_ratio: + # Crop width first + nw = int(oh * target_ratio) + img = img.crop(((ow - nw) // 2, 0, (ow + nw) // 2, oh)) + else: + # Crop height first + nh = int(ow / target_ratio) + img = img.crop((0, (oh - nh) // 2, ow, (oh + nh) // 2)) + + return img.resize((w, h), Image.LANCZOS) + +def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16): + if fit_into_canvas == None or fit_into_canvas == 2: # return image_height, image_width return canvas_height, canvas_width - if fit_into_canvas: + if fit_into_canvas == 1: scale1 = min(canvas_height / image_height, canvas_width / image_width) scale2 = min(canvas_width / image_height, canvas_height / image_width) scale = max(scale1, scale2) - else: + else: #0 or #2 (crop) scale = (canvas_height * canvas_width / (image_height * image_width))**(1/2) new_height = round( image_height * scale / block_size) * block_size new_width = round( image_width * scale / block_size) * block_size return new_height, new_width -def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, ignore_first, fit_into_canvas = False ): +def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fit_into_canvas, fit_crop, block_size = 16): + if fit_crop: + image = rescale_and_crop(image, canvas_width, canvas_height) + new_width, new_height = image.size + else: + image_width, image_height = image.size + new_height, new_width = calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = block_size ) + image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + return image, new_height, new_width + +def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5, return_tensor = False ): if rm_background: session = new_session() output_list =[] + output_mask_list =[] for i, img in enumerate(img_list): width, height = img.size - - if fit_into_canvas: + resized_mask = None + if any_background_ref == 1 and i==0 or any_background_ref == 2: + if outpainting_dims is not None and background_ref_outpainted: + resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True) + elif img.size != (budget_width, budget_height): + resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS) + else: + resized_image =img + elif fit_into_canvas == 1: white_canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255 scale = min(budget_height / height, budget_width / width) new_height = int(height * scale) @@ -240,152 +301,112 @@ def resize_and_remove_background(img_list, 
budget_width, budget_height, rm_backg resized_image = Image.fromarray(white_canvas) else: scale = (budget_height * budget_width / (height * width))**(1/2) - new_height = int( round(height * scale / 16) * 16) - new_width = int( round(width * scale / 16) * 16) + new_height = int( round(height * scale / block_size) * block_size) + new_width = int( round(width * scale / block_size) * block_size) resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS) - if rm_background and not (ignore_first and i == 0) : + if rm_background and not (any_background_ref and i==0 or any_background_ref == 2) : # resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') - output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200, - return output_list + if return_tensor: + output_list.append(convert_image_to_tensor(resized_image).unsqueeze(1)) + else: + output_list.append(resized_image) + output_mask_list.append(resized_mask) + return output_list, output_mask_list +def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu", full_frame = False, outpainting_dims = None, return_mask = False, return_image = False): + from shared.utils.utils import save_image + inpaint_color = canvas_tf_bg / 127.5 - 1 - - -def str2bool(v): - """ - Convert a string to a boolean. - - Supported true values: 'yes', 'true', 't', 'y', '1' - Supported false values: 'no', 'false', 'f', 'n', '0' - - Args: - v (str): String to convert. - - Returns: - bool: Converted boolean value. - - Raises: - argparse.ArgumentTypeError: If the value cannot be converted to boolean. 
- """ - if isinstance(v, bool): - return v - v_lower = v.lower() - if v_lower in ('yes', 'true', 't', 'y', '1'): - return True - elif v_lower in ('no', 'false', 'f', 'n', '0'): - return False + ref_width, ref_height = ref_img.size + if (ref_height, ref_width) == image_size and outpainting_dims == None: + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + canvas = torch.zeros_like(ref_img[:1]) if return_mask else None else: - raise argparse.ArgumentTypeError('Boolean value expected (True/False)') - - -import sys, time - -# Global variables to track download progress -_start_time = None -_last_time = None -_last_downloaded = 0 -_speed_history = [] -_update_interval = 0.5 # Update speed every 0.5 seconds - -def progress_hook(block_num, block_size, total_size, filename=None): - """ - Simple progress bar hook for urlretrieve - - Args: - block_num: Number of blocks downloaded so far - block_size: Size of each block in bytes - total_size: Total size of the file in bytes - filename: Name of the file being downloaded (optional) - """ - global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval - - current_time = time.time() - downloaded = block_num * block_size - - # Initialize timing on first call - if _start_time is None or block_num == 0: - _start_time = current_time - _last_time = current_time - _last_downloaded = 0 - _speed_history = [] - - # Calculate download speed only at specified intervals - speed = 0 - if current_time - _last_time >= _update_interval: - if _last_time > 0: - current_speed = (downloaded - _last_downloaded) / (current_time - _last_time) - _speed_history.append(current_speed) - # Keep only last 5 speed measurements for smoothing - if len(_speed_history) > 5: - _speed_history.pop(0) - # Average the recent speeds for smoother display - speed = sum(_speed_history) / len(_speed_history) - - _last_time = current_time - _last_downloaded = downloaded - elif _speed_history: - # Use the last calculated average speed - speed = sum(_speed_history) / len(_speed_history) - # Format file sizes and speed - def format_bytes(bytes_val): - for unit in ['B', 'KB', 'MB', 'GB']: - if bytes_val < 1024: - return f"{bytes_val:.1f}{unit}" - bytes_val /= 1024 - return f"{bytes_val:.1f}TB" - - file_display = filename if filename else "Unknown file" - - if total_size <= 0: - # If total size is unknown, show downloaded bytes - speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" - line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}" - # Clear any trailing characters by padding with spaces - sys.stdout.write(line.ljust(80)) - sys.stdout.flush() - return - - downloaded = block_num * block_size - percent = min(100, (downloaded / total_size) * 100) - - # Create progress bar (40 characters wide to leave room for other info) - bar_length = 40 - filled = int(bar_length * percent / 100) - bar = '█' * filled + '░' * (bar_length - filled) - - # Format file sizes and speed - def format_bytes(bytes_val): - for unit in ['B', 'KB', 'MB', 'GB']: - if bytes_val < 1024: - return f"{bytes_val:.1f}{unit}" - bytes_val /= 1024 - return f"{bytes_val:.1f}TB" - - speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" - - # Display progress with filename first - line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}" - # Clear any trailing characters by padding with spaces - sys.stdout.write(line.ljust(100)) - sys.stdout.flush() - - # Print newline when complete - if percent >= 100: - print() - -# 
Wrapper function to include filename in progress hook -def create_progress_hook(filename): - """Creates a progress hook with the filename included""" - global _start_time, _last_time, _last_downloaded, _speed_history - # Reset timing variables for new download - _start_time = None - _last_time = None - _last_downloaded = 0 - _speed_history = [] - - def hook(block_num, block_size, total_size): - return progress_hook(block_num, block_size, total_size, filename) - return hook + if outpainting_dims != None: + final_height, final_width = image_size + canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1) + else: + canvas_height, canvas_width = image_size + if full_frame: + new_height = canvas_height + new_width = canvas_width + top = left = 0 + else: + # if fill_max and (canvas_height - new_height) < 16: + # new_height = canvas_height + # if fill_max and (canvas_width - new_width) < 16: + # new_width = canvas_width + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + if outpainting_dims != None: + canvas = torch.full((3, 1, final_height, final_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img + else: + canvas = torch.full((3, 1, canvas_height, canvas_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, top:top + new_height, left:left + new_width] = ref_img + ref_img = canvas + canvas = None + if return_mask: + if outpainting_dims != None: + canvas = torch.ones((1, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0 + else: + canvas = torch.ones((1, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, top:top + new_height, left:left + new_width] = 0 + canvas = canvas.to(device) + if return_image: + return convert_tensor_to_image(ref_img), canvas + + return ref_img.to(device), canvas + +def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None, device ="cpu"): + src_videos, src_masks = [], [] + inpaint_color_compressed = guide_inpaint_color/127.5 - 1 + prepend_count = pre_video_guide.shape[1] if pre_video_guide is not None else 0 + for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)): + src_video, src_mask = cur_video_guide, cur_video_mask + if pre_video_guide is not None: + src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1) + if any_mask: + src_mask = torch.zeros_like(pre_video_guide[:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[:1]), src_mask], dim=1) + + if any_guide_padding: + if src_video is None: + src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color_compressed, dtype = 
torch.float, device= device) + elif src_video.shape[1] < current_video_length: + src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color_compressed, dtype = src_video.dtype, device= src_video.device) ], dim=1) + elif src_video is not None: + new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1 + src_video = src_video[:, :new_num_frames] + + if any_mask and src_video is not None: + if src_mask is None: + src_mask = torch.ones_like(src_video[:1]) + elif src_mask.shape[1] < src_video.shape[1]: + src_mask = torch.cat([src_mask, torch.full( (1, src_video.shape[1]- src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1) + else: + src_mask = src_mask[:, :src_video.shape[1]] + + if src_video is not None : + for k, keep in enumerate(keep_video_guide_frames): + if not keep: + pos = prepend_count + k + src_video[:, pos:pos+1] = inpaint_color_compressed + if any_mask: src_mask[:, pos:pos+1] = 1 + + for k, frame in enumerate(inject_frames): + if frame != None: + pos = prepend_count + k + src_video[:, pos:pos+1], msk = fit_image_into_canvas(frame, image_size, guide_inpaint_color, device, True, outpainting_dims, return_mask= any_mask) + if any_mask: src_mask[:, pos:pos+1] = msk + src_videos.append(src_video) + src_masks.append(src_mask) + return src_videos, src_masks diff --git a/wgp.py b/wgp.py index 8cca2ad..e00660c 100644 --- a/wgp.py +++ b/wgp.py @@ -1,4 +1,9 @@ import os +os.environ["GRADIO_LANG"] = "en" +# # os.environ.pop("TORCH_LOGS", None) # make sure no env var is suppressing/overriding +# os.environ["TORCH_LOGS"]= "recompiles" +import torch._logging as tlog +# tlog.set_logs(recompiles=True, guards=True, graph_breaks=True) import time import sys import threading @@ -17,7 +22,9 @@ import numpy as np import importlib from shared.utils import notification_sound from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers -from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, get_video_frame +from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, calculate_dimensions_and_resize_image, rescale_and_crop, get_video_frame, resize_and_remove_background, rgb_bw_to_rgba_mask +from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions +from shared.utils.utils import has_video_file_extension, has_image_file_extension from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image from shared.utils.audio_video import save_image_metadata, read_image_metadata from shared.match_archi import match_nvidia_architecture @@ -45,6 +52,7 @@ logging.set_verbosity_error from preprocessing.matanyone import app as matanyone_app from tqdm import tqdm import requests +from shared.gradio.gallery import AdvancedMediaGallery # import torch._dynamo as dynamo # dynamo.config.recompile_limit = 2000 # default is 256 @@ -54,9 +62,9 @@ global_queue_ref = [] AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 -target_mmgp_version = "3.5.10" -WanGP_version = "8.2" -settings_version = 2.27 +target_mmgp_version = 
"3.6.0" +WanGP_version = "8.74" +settings_version = 2.36 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -81,6 +89,7 @@ def clear_gen_cache(): def release_model(): global wan_model, offloadobj, reload_needed + wan_model = None clear_gen_cache() offload.shared_state if offloadobj is not None: @@ -180,10 +189,23 @@ def compute_sliding_window_no(current_video_length, sliding_window_size, discard return 1 + math.ceil(left_after_first_window / (sliding_window_size - discard_last_frames - reuse_frames)) +def clean_image_list(gradio_list): + if not isinstance(gradio_list, list): gradio_list = [gradio_list] + gradio_list = [ tup[0] if isinstance(tup, tuple) else tup for tup in gradio_list ] + + if any( not isinstance(image, (Image.Image, str)) for image in gradio_list): return None + if any( isinstance(image, str) and not has_image_file_extension(image) for image in gradio_list): return None + gradio_list = [ convert_image( Image.open(img) if isinstance(img, str) else img ) for img in gradio_list ] + return gradio_list + + + def process_prompt_and_add_tasks(state, model_choice): - + def ret(): + return gr.update(), gr.update() + if state.get("validate_success",0) != 1: - return + ret() state["validate_success"] = 0 model_filename = state["model_filename"] @@ -200,10 +222,10 @@ def process_prompt_and_add_tasks(state, model_choice): if inputs == None: gr.Warning("Internal state error: Could not retrieve inputs for the model.") queue = gen.get("queue", []) - return get_queue_table(queue) + return ret() model_def = get_model_def(model_type) model_handler = get_model_handler(model_type) - image_outputs = inputs["image_mode"] == 1 + image_outputs = inputs["image_mode"] > 0 any_steps_skipping = model_def.get("tea_cache", False) or model_def.get("mag_cache", False) model_type = get_base_model_type(model_type) inputs["model_filename"] = model_filename @@ -229,7 +251,7 @@ def process_prompt_and_add_tasks(state, model_choice): if len(temporal_upsampling) >0: prompt += ["Temporal Upsampling"] if has_image_file_extension(edit_video_source) and len(temporal_upsampling) > 0: gr.Info("Temporal Upsampling can not be used with an Image") - return + return ret() film_grain_intensity = inputs.get("film_grain_intensity",0) film_grain_saturation = inputs.get("film_grain_saturation",0.5) # if film_grain_intensity >0: prompt += [f"Film Grain: intensity={film_grain_intensity}, saturation={film_grain_saturation}"] @@ -245,7 +267,7 @@ def process_prompt_and_add_tasks(state, model_choice): else: if audio_source is None: gr.Info("You must provide a custom Audio") - return + return ret() prompt += ["Custom Audio"] repeat_generation == 1 @@ -255,32 +277,32 @@ def process_prompt_and_add_tasks(state, model_choice): gr.Info("You must choose at least one Remux Method") else: gr.Info("You must choose at least one Post Processing Method") - return + return ret() inputs["prompt"] = ", ".join(prompt) add_video_task(**inputs) gen["prompts_max"] = 1 + gen.get("prompts_max",0) state["validate_success"] = 1 queue= gen.get("queue", []) - return update_queue_data(queue) + return ret() if hasattr(model_handler, "validate_generative_settings"): error = model_handler.validate_generative_settings(model_type, model_def, inputs) if error is not None and len(error) > 0: gr.Info(error) - return + return ret() if inputs.get("cfg_star_switch", 0) != 0 and inputs.get("apg_switch", 0) != 0: gr.Info("Adaptive 
Progressive Guidance and Classifier Free Guidance Star can not be set at the same time") - return + return ret() prompt = inputs["prompt"] if len(prompt) ==0: gr.Info("Prompt cannot be empty.") gen = get_gen_info(state) queue = gen.get("queue", []) - return get_queue_table(queue) + return ret() prompt, errors = prompt_parser.process_template(prompt) if len(errors) > 0: gr.Info("Error processing prompt template: " + errors) - return + return ret() model_filename = get_model_filename(model_type) prompts = prompt.replace("\r", "").split("\n") prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")] @@ -288,7 +310,7 @@ def process_prompt_and_add_tasks(state, model_choice): gr.Info("Prompt cannot be empty.") gen = get_gen_info(state) queue = gen.get("queue", []) - return get_queue_table(queue) + return ret() resolution = inputs["resolution"] width, height = resolution.split("x") @@ -330,30 +352,36 @@ def process_prompt_and_add_tasks(state, model_choice): model_switch_phase = inputs["model_switch_phase"] switch_threshold = inputs["switch_threshold"] switch_threshold2 = inputs["switch_threshold2"] - + multi_prompts_gen_type = inputs["multi_prompts_gen_type"] + video_guide_outpainting = inputs["video_guide_outpainting"] + + outpainting_dims = get_outpainting_dims(video_guide_outpainting) + + if server_config.get("fit_canvas", 0) == 2 and outpainting_dims is not None and any_letters(video_prompt_type, "VKF"): + gr.Info("Output Resolution Cropping will be not used for this Generation as it is not compatible with Video Outpainting") if len(loras_multipliers) > 0: _, _, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, nb_phases= guidance_phases) if len(errors) > 0: gr.Info(f"Error parsing Loras Multipliers: {errors}") - return + return ret() if guidance_phases == 3: if switch_threshold < switch_threshold2: gr.Info(f"Phase 1-2 Switch Noise Level ({switch_threshold}) should be Greater than Phase 2-3 Switch Noise Level ({switch_threshold2}). 
As a reminder, noise will gradually go down from 1000 to 0.") - return + return ret() else: model_switch_phase = 1 if not any_steps_skipping: skip_steps_cache_type = "" if not model_def.get("lock_inference_steps", False) and model_type in ["ltxv_13B"] and num_inference_steps < 20: gr.Info("The minimum number of steps should be 20") - return + return ret() if skip_steps_cache_type == "mag": if num_inference_steps > 50: gr.Info("Mag Cache maximum number of steps is 50") - return + return ret() - if image_mode == 1: + if image_mode > 0: audio_prompt_type = "" if "B" in audio_prompt_type or "X" in audio_prompt_type: @@ -361,49 +389,53 @@ def process_prompt_and_add_tasks(state, model_choice): speakers_bboxes, error = parse_speakers_locations(speakers_locations) if len(error) > 0: gr.Info(error) - return + return ret() if MMAudio_setting != 0 and server_config.get("mmaudio_enabled", 0) != 0 and video_length <16: #should depend on the architecture gr.Info("MMAudio can generate an Audio track only if the Video is at least 1s long") if "F" in video_prompt_type: if len(frames_positions.strip()) > 0: - positions = frames_positions.split(" ") + positions = frames_positions.replace(","," ").split(" ") for pos_str in positions: - if not is_integer(pos_str): - gr.Info(f"Invalid Frame Position '{pos_str}'") - return - pos = int(pos_str) - if pos <1 or pos > max_source_video_frames: - gr.Info(f"Invalid Frame Position Value'{pos_str}'") - return + if not pos_str in ["L", "l"] and len(pos_str)>0: + if not is_integer(pos_str): + gr.Info(f"Invalid Frame Position '{pos_str}'") + return ret() + pos = int(pos_str) + if pos <1 or pos > max_source_video_frames: + gr.Info(f"Invalid Frame Position Value'{pos_str}'") + return ret() else: frames_positions = None if audio_source is not None and MMAudio_setting != 0: gr.Info("MMAudio and Custom Audio Soundtrack can't not be used at the same time") - return + return ret() if len(filter_letters(image_prompt_type, "VLG")) > 0 and len(keep_frames_video_source) > 0: if not is_integer(keep_frames_video_source) or int(keep_frames_video_source) == 0: gr.Info("The number of frames to keep must be a non null integer") - return + return ret() else: keep_frames_video_source = "" + if image_outputs: + image_prompt_type = image_prompt_type.replace("V", "").replace("L", "") + if "V" in image_prompt_type: if video_source == None: gr.Info("You must provide a Source Video file to continue") - return + return ret() else: video_source = None if "A" in audio_prompt_type: if audio_guide == None: gr.Info("You must provide an Audio Source") - return + return ret() if "B" in audio_prompt_type: if audio_guide2 == None: gr.Info("You must provide a second Audio Source") - return + return ret() else: audio_guide2 = None else: @@ -417,24 +449,23 @@ def process_prompt_and_add_tasks(state, model_choice): if model_def.get("one_image_ref_needed", False): if image_refs == None : gr.Info("You must provide an Image Reference") - return + return ret() if len(image_refs) > 1: gr.Info("Only one Image Reference (a person) is supported for the moment by this model") - return + return ret() if model_def.get("at_least_one_image_ref_needed", False): if image_refs == None : gr.Info("You must provide at least one Image Reference") - return + return ret() if "I" in video_prompt_type: if image_refs == None or len(image_refs) == 0: - gr.Info("You must provide at least one Refererence Image") - return - if any(isinstance(image[0], str) for image in image_refs) : + gr.Info("You must provide at least one Reference Image") 
+ return ret() + image_refs = clean_image_list(image_refs) + if image_refs == None : gr.Info("A Reference Image should be an Image") - return - if isinstance(image_refs, list): - image_refs = [ convert_image(tup[0]) for tup in image_refs ] + return ret() else: image_refs = None @@ -442,35 +473,36 @@ def process_prompt_and_add_tasks(state, model_choice): if image_outputs: if image_guide is None: gr.Info("You must provide a Control Image") - return + return ret() else: if video_guide is None: gr.Info("You must provide a Control Video") - return + return ret() if "A" in video_prompt_type and not "U" in video_prompt_type: if image_outputs: if image_mask is None: gr.Info("You must provide a Image Mask") - return + return ret() else: if video_mask is None: gr.Info("You must provide a Video Mask") - return + return ret() else: video_mask = None image_mask = None if "G" in video_prompt_type: - gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start a Step no {int(num_inference_steps * (1. - denoising_strength))} ") + if denoising_strength < 1.: + gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start at Step no {int(round(num_inference_steps * (1. - denoising_strength),4))} ") else: denoising_strength = 1.0 if len(keep_frames_video_guide) > 0 and model_type in ["ltxv_13B"]: gr.Info("Keep Frames for Control Video is not supported with LTX Video") - return + return ret() _, error = parse_keep_frames_video_guide(keep_frames_video_guide, video_length) if len(error) > 0: gr.Info(f"Invalid Keep Frames property: {error}") - return + return ret() else: video_guide = None image_guide = None @@ -490,69 +522,75 @@ def process_prompt_and_add_tasks(state, model_choice): if "S" in image_prompt_type: if image_start == None or isinstance(image_start, list) and len(image_start) == 0: gr.Info("You must provide a Start Image") - return - if not isinstance(image_start, list): - image_start = [image_start] - if not all( not isinstance(img[0], str) for img in image_start) : + return ret() + image_start = clean_image_list(image_start) + if image_start == None : gr.Info("Start Image should be an Image") - return - image_start = [ convert_image(tup[0]) for tup in image_start ] + return ret() + if multi_prompts_gen_type == 1 and len(image_start) > 1: + gr.Info("Only one Start Image is supported") + return ret() else: image_start = None + if not any_letters(image_prompt_type, "SVL"): + image_prompt_type = image_prompt_type.replace("E", "") if "E" in image_prompt_type: if image_end == None or isinstance(image_end, list) and len(image_end) == 0: gr.Info("You must provide an End Image") - return - if not isinstance(image_end, list): - image_end = [image_end] - if not all( not isinstance(img[0], str) for img in image_end) : + return ret() + image_end = clean_image_list(image_end) + if image_end == None : gr.Info("End Image should be an Image") - return - if len(image_start) != len(image_end): - gr.Info("The number of Start and End Images should be the same ") - return - image_end = [ convert_image(tup[0]) for tup in image_end ] + return ret() + if multi_prompts_gen_type == 0: + if video_source is not None: + if len(image_end)> 1: + gr.Info("If a Video is to be continued and the option 'Each Text Prompt Will create a new generated Video' is set, there can be only one End Image") + return ret() + elif len(image_start or []) != len(image_end or []): + gr.Info("The number of Start and End Images should be the same when the option 'Each Text Prompt Will create a new generated 
Video'") + return ret() else: image_end = None if test_any_sliding_window(model_type) and image_mode == 0: if video_length > sliding_window_size: - full_video_length = video_length if video_source is None else video_length + sliding_window_overlap + if model_type in ["t2v"] and not "G" in video_prompt_type : + gr.Info(f"You have requested to Generate Sliding Windows with a Text to Video model. Unless you use the Video to Video feature this is useless as a t2v model doesn't see past frames and it will generate the same video in each new window.") + return ret() + full_video_length = video_length if video_source is None else video_length + sliding_window_overlap -1 extra = "" if full_video_length == video_length else f" including {sliding_window_overlap} added for Video Continuation" no_windows = compute_sliding_window_no(full_video_length, sliding_window_size, sliding_window_discard_last_frames, sliding_window_overlap) gr.Info(f"The Number of Frames to generate ({video_length}{extra}) is greater than the Sliding Window Size ({sliding_window_size}), {no_windows} Windows will be generated") - if "recam" in model_filename: - if video_source == None: - gr.Info("You must provide a Source Video") - return - - frames = get_resampled_video(video_source, 0, 81, get_computed_fps(force_fps, model_type , video_guide, video_source )) + if video_guide == None: + gr.Info("You must provide a Control Video") + return ret() + computed_fps = get_computed_fps(force_fps, model_type , video_guide, video_source ) + frames = get_resampled_video(video_guide, 0, 81, computed_fps) if len(frames)<81: - gr.Info("Recammaster source video should be at least 81 frames once the resampling at 16 fps has been done") - return - - + gr.Info(f"Recammaster Control video should be at least 81 frames once the resampling at {computed_fps} fps has been done") + return ret() if "hunyuan_custom_custom_edit" in model_filename: if len(keep_frames_video_guide) > 0: gr.Info("Filtering Frames with this model is not supported") - return + return ret() if inputs["multi_prompts_gen_type"] != 0: if image_start != None and len(image_start) > 1: gr.Info("Only one Start Image must be provided if multiple prompts are used for different windows") - return + return ret() - if image_end != None and len(image_end) > 1: - gr.Info("Only one End Image must be provided if multiple prompts are used for different windows") - return + # if image_end != None and len(image_end) > 1: + # gr.Info("Only one End Image must be provided if multiple prompts are used for different windows") + # return override_inputs = { "image_start": image_start[0] if image_start !=None and len(image_start) > 0 else None, - "image_end": image_end[0] if image_end !=None and len(image_end) > 0 else None, + "image_end": image_end, #[0] if image_end !=None and len(image_end) > 0 else None, "image_refs": image_refs, "audio_guide": audio_guide, "audio_guide2": audio_guide2, @@ -592,7 +630,7 @@ def process_prompt_and_add_tasks(state, model_choice): if len(prompts) >= len(image_start): if len(prompts) % len(image_start) != 0: gr.Info("If there are more text prompts than input images the number of text prompts should be dividable by the number of images") - return + return ret() rep = len(prompts) // len(image_start) new_image_start = [] new_image_end = [] @@ -606,7 +644,7 @@ def process_prompt_and_add_tasks(state, model_choice): else: if len(image_start) % len(prompts) !=0: gr.Info("If there are more input images than text prompts the number of images should be dividable by the number of 
text prompts") - return + return ret() rep = len(image_start) // len(prompts) new_prompts = [] for i, _ in enumerate(image_start): @@ -628,19 +666,23 @@ def process_prompt_and_add_tasks(state, model_choice): override_inputs["prompt"] = single_prompt inputs.update(override_inputs) add_video_task(**inputs) + new_prompts_count = len(prompts) else: + new_prompts_count = 1 override_inputs["prompt"] = "\n".join(prompts) inputs.update(override_inputs) add_video_task(**inputs) - - gen["prompts_max"] = len(prompts) + gen.get("prompts_max",0) + new_prompts_count += gen.get("prompts_max",0) + gen["prompts_max"] = new_prompts_count state["validate_success"] = 1 queue= gen.get("queue", []) - return update_queue_data(queue) + first_time_in_queue = state.get("first_time_in_queue", True) + state["first_time_in_queue"] = True + return update_queue_data(queue, first_time_in_queue), gr.update(open=True) if new_prompts_count > 1 else gr.update() def get_preview_images(inputs): - inputs_to_query = ["image_start", "image_end", "video_source", "video_guide", "image_guide", "video_mask", "image_mask", "image_refs" ] - labels = ["Start Image", "End Image", "Video Source", "Video Guide", "Image Guide", "Video Mask", "Image Mask", "Image Reference"] + inputs_to_query = ["image_start", "video_source", "image_end", "video_guide", "image_guide", "video_mask", "image_mask", "image_refs" ] + labels = ["Start Image", "Video Source", "End Image", "Video Guide", "Image Guide", "Video Mask", "Image Mask", "Image Reference"] start_image_data = None start_image_labels = [] end_image_data = None @@ -690,7 +732,6 @@ def add_video_task(**inputs): "start_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in start_image_data] if start_image_data != None else None, "end_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in end_image_data] if end_image_data != None else None }) - return update_queue_data(queue) def update_task_thumbnails(task, inputs): start_image_data, end_image_data, start_labels, end_labels = get_preview_images(inputs) @@ -1307,14 +1348,12 @@ def get_queue_table(queue): "✖" ]) return data -def update_queue_data(queue): +def update_queue_data(queue, first_time_in_queue =False): update_global_queue_ref(queue) data = get_queue_table(queue) - if len(data) == 0: - return gr.DataFrame(visible=False) - else: - return gr.DataFrame(value=data, visible= True) + return gr.DataFrame(value=data) + def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True): bar_class = "progress-bar-custom idle" if is_idle else "progress-bar-custom" @@ -1349,6 +1388,20 @@ def _parse_args(): help="save proprocessed audio track with extract speakers for debugging or editing" ) + parser.add_argument( + "--debug-gen-form", + action="store_true", + help="View form generation / refresh time" + ) + + parser.add_argument( + "--vram-safety-coefficient", + type=float, + default=0.8, + help="max VRAM (between 0 and 1) that should be allocated to preloaded models" + ) + + parser.add_argument( "--share", action="store_true", @@ -1890,7 +1943,8 @@ def get_model_min_frames_and_step(model_type): mode_def = get_model_def(model_type) frames_minimum = mode_def.get("frames_minimum", 5) frames_steps = mode_def.get("frames_steps", 4) - return frames_minimum, frames_steps + latent_size = mode_def.get("latent_size", frames_steps) + return frames_minimum, frames_steps, latent_size def get_model_fps(model_type): mode_def = get_model_def(model_type) @@ -1954,8 +2008,10 @@ def 
get_model_recursive_prop(model_type, prop = "URLs", sub_prop_name = None, re raise Exception(f"Unknown model type '{model_type}'") -def get_model_filename(model_type, quantization ="int8", dtype_policy = "", module_type = None, submodel_no = 1, stack=[]): - if module_type is not None: +def get_model_filename(model_type, quantization ="int8", dtype_policy = "", module_type = None, submodel_no = 1, URLs = None, stack=[]): + if URLs is not None: + pass + elif module_type is not None: base_model_type = get_base_model_type(model_type) # model_type_handler = model_types_handlers[base_model_type] # modules_files = model_type_handler.query_modules_files() if hasattr(model_type_handler, "query_modules_files") else {} @@ -1982,7 +2038,8 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = "", modu if len(stack) > 10: raise Exception(f"Circular Reference in Model {key_name} dependencies: {stack}") return get_model_filename(URLs, quantization=quantization, dtype_policy=dtype_policy, submodel_no = submodel_no, stack = stack + [URLs]) - choices = [ ("ckpts/" + os.path.basename(path) if path.startswith("http") else path) for path in URLs ] + # choices = [ ("ckpts/" + os.path.basename(path) if path.startswith("http") else path) for path in URLs ] + choices = URLs if len(quantization) == 0: quantization = "bf16" @@ -1992,13 +2049,13 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = "", modu raw_filename = choices[0] else: if quantization in ("int8", "fp8"): - sub_choices = [ name for name in choices if quantization in name or quantization.upper() in name] + sub_choices = [ name for name in choices if quantization in os.path.basename(name) or quantization.upper() in os.path.basename(name)] else: - sub_choices = [ name for name in choices if "quanto" not in name] + sub_choices = [ name for name in choices if "quanto" not in os.path.basename(name)] if len(sub_choices) > 0: dtype_str = "fp16" if dtype == torch.float16 else "bf16" - new_sub_choices = [ name for name in sub_choices if dtype_str in name or dtype_str.upper() in name] + new_sub_choices = [ name for name in sub_choices if dtype_str in os.path.basename(name) or dtype_str.upper() in os.path.basename(name)] sub_choices = new_sub_choices if len(new_sub_choices) > 0 else sub_choices raw_filename = sub_choices[0] else: @@ -2040,8 +2097,6 @@ def fix_settings(model_type, ui_defaults): if image_prompt_type != None : if not isinstance(image_prompt_type, str): image_prompt_type = "S" if image_prompt_type == 0 else "SE" - # if model_type == "flf2v_720p" and not "E" in image_prompt_type: - # image_prompt_type = "SE" if settings_version <= 2: image_prompt_type = image_prompt_type.replace("G","") ui_defaults["image_prompt_type"] = image_prompt_type @@ -2059,12 +2114,13 @@ def fix_settings(model_type, ui_defaults): audio_prompt_type ="A" ui_defaults["audio_prompt_type"] = audio_prompt_type + if settings_version < 2.35 and any_audio_track(base_model_type): + audio_prompt_type = audio_prompt_type or "" + audio_prompt_type += "V" + ui_defaults["audio_prompt_type"] = audio_prompt_type video_prompt_type = ui_defaults.get("video_prompt_type", "") - any_reference_image = model_def.get("reference_image", False) - if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar", "phantom_14B", "phantom_1.3B"] or any_reference_image: - if not "I" in video_prompt_type: # workaround for settings corruption - video_prompt_type += "I" + if base_model_type in ["hunyuan"]: video_prompt_type = 
video_prompt_type.replace("I", "") @@ -2103,10 +2159,28 @@ def fix_settings(model_type, ui_defaults): del ui_defaults["tea_cache_start_step_perc"] ui_defaults["skip_steps_start_step_perc"] = tea_cache_start_step_perc + image_prompt_type = ui_defaults.get("image_prompt_type", "") + if len(image_prompt_type) > 0: + image_prompt_types_allowed = model_def.get("image_prompt_types_allowed","") + image_prompt_type = filter_letters(image_prompt_type, image_prompt_types_allowed) + ui_defaults["image_prompt_type"] = image_prompt_type + + video_prompt_type = ui_defaults.get("video_prompt_type", "") + image_ref_choices_list = model_def.get("image_ref_choices", {}).get("choices", []) + if model_def.get("guide_custom_choices", None) is None: + if len(image_ref_choices_list)==0: + video_prompt_type = del_in_sequence(video_prompt_type, "IK") + else: + first_choice = image_ref_choices_list[0][1] + if "I" in first_choice and not "I" in video_prompt_type: video_prompt_type += "I" + if len(image_ref_choices_list)==1 and "K" in first_choice and not "K" in video_prompt_type: video_prompt_type += "K" + ui_defaults["video_prompt_type"] = video_prompt_type + model_handler = get_model_handler(base_model_type) if hasattr(model_handler, "fix_settings"): model_handler.fix_settings(base_model_type, settings_version, model_def, ui_defaults) + def get_default_settings(model_type): def get_default_prompt(i2v): if i2v: @@ -2138,7 +2212,8 @@ def get_default_settings(model_type): "slg_switch": 0, "slg_layers": [9], "slg_start_perc": 10, - "slg_end_perc": 90 + "slg_end_perc": 90, + "audio_prompt_type": "V", } model_handler = get_model_handler(model_type) model_handler.update_default_settings(base_model_type, model_def, ui_defaults) @@ -2286,7 +2361,7 @@ if args.compile: #args.fastest or lock_ui_compile = True -def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_module = False, filter = None, no_fp16_main_model = True ): +def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_module = False, filter = None, no_fp16_main_model = True, module_source_no = 1): model_def = get_model_def(model_type) # To save module and quantized modules # 1) set Transformer Model Quantization Type to 16 bits @@ -2297,10 +2372,10 @@ def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_mod if model_def == None: return if is_module: url_key = "modules" - source_key = "module_source" + source_key = "module_source" if module_source_no <=1 else "module_source2" else: url_key = "URLs" if submodel_no <=1 else "URLs" + str(submodel_no) - source_key = "source" + source_key = "source" if submodel_no <=1 else "source2" URLs= model_def.get(url_key, None) if URLs is None: return if isinstance(URLs, str): @@ -2315,6 +2390,9 @@ def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_mod print("Target Module files are missing") return URLs= URLs[0] + if isinstance(URLs, dict): + url_dict_key = "URLs" if module_source_no ==1 else "URLs2" + URLs = URLs[url_dict_key] for url in URLs: if "quanto" not in url and dtypestr in url: model_filename = os.path.basename(url) @@ -2348,8 +2426,12 @@ def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_mod elif not os.path.isfile(quanto_filename): offload.save_model(model, quanto_filename, config_file_path=config_file, do_quantize= True, filter_sd=filter) print(f"New quantized file '{quanto_filename}' had been created for finetune Id '{model_type}'.") - model_def[url_key][0].append(quanto_filename) - 
saved_finetune_def["model"][url_key][0].append(quanto_filename) + if isinstance(model_def[url_key][0],dict): + model_def[url_key][0][url_dict_key].append(quanto_filename) + saved_finetune_def["model"][url_key][0][url_dict_key].append(quanto_filename) + else: + model_def[url_key][0].append(quanto_filename) + saved_finetune_def["model"][url_key][0].append(quanto_filename) update_model_def = True if update_model_def: with open(finetune_file, "w", encoding="utf-8") as writer: @@ -2404,6 +2486,14 @@ def get_loras_preprocessor(transformer, model_type): return preprocessor_wrapper +def get_local_model_filename(model_filename): + if model_filename.startswith("http"): + local_model_filename = os.path.join("ckpts", os.path.basename(model_filename)) + else: + local_model_filename = model_filename + return local_model_filename + + def process_files_def(repoId, sourceFolderList, fileList): targetRoot = "ckpts/" @@ -2429,7 +2519,8 @@ def download_mmaudio(): } process_files_def(**enhancer_def) -def download_models(model_filename = None, model_type= None, module_type = None, submodel_no = 1): +download_shared_done = False +def download_models(model_filename = None, model_type= None, module_type = False, submodel_no = 1): def computeList(filename): if filename == None: return [] @@ -2440,15 +2531,16 @@ def download_models(model_filename = None, model_type= None, module_type = None, from urllib.request import urlretrieve - from shared.utils.utils import create_progress_hook + from shared.utils.download import create_progress_hook shared_def = { "repoId" : "DeepBeepMeep/Wan2.1", - "sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "pyannote", "det_align", "" ], + "sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "roformer", "pyannote", "det_align", "" ], "fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["raft-things.pth"], - ["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], + ["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors", "model.safetensors", "config.json"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"], ["config.json", "pytorch_model.bin", "preprocessor_config.json"], + ["model_bs_roformer_ep_317_sdr_12.9755.ckpt", "model_bs_roformer_ep_317_sdr_12.9755.yaml", "download_checks.json"], ["pyannote_model_wespeaker-voxceleb-resnet34-LM.bin", "pytorch_model_segmentation-3.0.bin"], ["detface.pt"], [ "flownet.pkl" ] ] } process_files_def(**shared_def) @@ -2471,6 +2563,9 @@ def download_models(model_filename = None, model_type= None, module_type = None, process_files_def(**enhancer_def) download_mmaudio() + global download_shared_done + download_shared_done = True + if model_filename is None: return def download_file(url,filename): @@ -2496,37 +2591,25 @@ def download_models(model_filename = None, model_type= None, module_type = None, base_model_type = get_base_model_type(model_type) model_def = get_model_def(model_type) - source = model_def.get("source", None) - module_source = model_def.get("module_source", None) + any_source = ("source2" if submodel_no ==2 else "source") in model_def + any_module_source = ("module_source2" if submodel_no ==2 else "module_source") in model_def model_type_handler = model_types_handlers[base_model_type] - - if source is not None and 
module_type is None or module_source is not None and module_type is not None: + local_model_filename = get_local_model_filename(model_filename) + + if any_source and not module_type or any_module_source and module_type: model_filename = None else: - if not os.path.isfile(model_filename): - if module_type is not None: - key_name = "modules" - URLs = module_type - if isinstance(module_type, str): - URLs = get_model_recursive_prop(module_type, key_name, sub_prop_name="_list", return_list= False) - else: - key_name = "URLs" if submodel_no <= 1 else f"URLs{submodel_no}" - URLs = get_model_recursive_prop(model_type, key_name, return_list= False) - if isinstance(URLs, str): - raise Exception("Missing model " + URLs) - use_url = model_filename - for url in URLs: - if os.path.basename(model_filename) in url: - use_url = url - break - if not url.startswith("http"): - raise Exception(f"Model '{model_filename}' in field '{key_name}' was not found locally and no URL was provided to download it. Please add an URL in the model definition file.") - try: - download_file(use_url, model_filename) - except Exception as e: - if os.path.isfile(model_filename): os.remove(model_filename) - raise Exception(f"{key_name} '{use_url}' is invalid for Model '{model_filename}' : {str(e)}'") + if not os.path.isfile(local_model_filename): + url = model_filename + if not url.startswith("http"): + raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it. Please add an URL in the model definition file.") + try: + download_file(url, local_model_filename) + except Exception as e: + if os.path.isfile(local_model_filename): os.remove(local_model_filename) + raise Exception(f"'{url}' is invalid for Model '{local_model_filename}' : {str(e)}'") + if module_type: return model_filename = None preload_URLs = get_model_recursive_prop(model_type, "preload_URLs", return_list= True) @@ -2552,6 +2635,7 @@ def download_models(model_filename = None, model_type= None, module_type = None, except Exception as e: if os.path.isfile(filename): os.remove(filename) raise Exception(f"Lora URL '{url}' is invalid: {str(e)}'") + if module_type: return model_files = model_type_handler.query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization) if not isinstance(model_files, list): model_files = [model_files] for one_repo in model_files: @@ -2738,7 +2822,8 @@ def load_models(model_type, override_profile = -1): model_filename2 = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy, submodel_no=2) # !!!! 
else: model_filename2 = None - modules = get_model_recursive_prop(model_type, "modules", return_list= True) + modules = get_model_recursive_prop(model_type, "modules", return_list= True) + modules = [get_model_recursive_prop(module, "modules", sub_prop_name ="_list", return_list= True) if isinstance(module, str) else module for module in modules ] if save_quantized and "quanto" in model_filename: save_quantized = False print("Need to provide a non quantized model to create a quantized model to be saved") @@ -2755,6 +2840,7 @@ def load_models(model_type, override_profile = -1): transformer_dtype = torch.bfloat16 if "bf16" in model_filename or "BF16" in model_filename else transformer_dtype transformer_dtype = torch.float16 if "fp16" in model_filename or"FP16" in model_filename else transformer_dtype perc_reserved_mem_max = args.perc_reserved_mem_max + vram_safety_coefficient = args.vram_safety_coefficient model_file_list = [model_filename] model_type_list = [model_type] module_type_list = [None] @@ -2762,27 +2848,40 @@ def load_models(model_type, override_profile = -1): if model_filename2 != None: model_file_list += [model_filename2] model_type_list += [model_type] - module_type_list += [None] + module_type_list += [False] model_submodel_no_list += [2] for module_type in modules: - model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, module_type= module_type)) - model_type_list.append(model_type) - module_type_list.append(module_type) - model_submodel_no_list.append(0) + if isinstance(module_type,dict): + URLs1 = module_type.get("URLs", None) + if URLs1 is None: raise Exception(f"No URLs defined for Module {module_type}") + model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, URLs = URLs1)) + URLs2 = module_type.get("URLs2", None) + if URLs2 is None: raise Exception(f"No URL2s defined for Module {module_type}") + model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, URLs = URLs2)) + model_type_list += [model_type] * 2 + module_type_list += [True] * 2 + model_submodel_no_list += [1,2] + else: + model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, module_type= module_type)) + model_type_list.append(model_type) + module_type_list.append(True) + model_submodel_no_list.append(0) + local_model_file_list= [] for filename, file_model_type, file_module_type, submodel_no in zip(model_file_list, model_type_list, module_type_list, model_submodel_no_list): download_models(filename, file_model_type, file_module_type, submodel_no) + local_model_file_list.append( get_local_model_filename(filename) ) VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float mixed_precision_transformer = server_config.get("mixed_precision","0") == "1" transformer_type = None - for submodel_no, filename in zip(model_submodel_no_list, model_file_list): - if submodel_no>=1: + for module_type, filename in zip(module_type_list, local_model_file_list): + if module_type is None: print(f"Loading Model '{filename}' ...") else: print(f"Loading Module '{filename}' ...") wan_model, pipe = model_types_handlers[base_model_type].load_model( - model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, text_encoder_quantization = text_encoder_quantization, - dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = 
save_quantized)
+            local_model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, text_encoder_quantization = text_encoder_quantization,
+            dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized, submodel_no_list = model_submodel_no_list, )
     kwargs = {}
     profile = init_pipe(pipe, kwargs, override_profile)
@@ -2791,7 +2890,7 @@ def load_models(model_type, override_profile = -1):
     loras_transformer = ["transformer"]
     if "transformer2" in pipe:
         loras_transformer += ["transformer2"]
-    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = loras_transformer, coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)
+    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = loras_transformer, coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , vram_safety_coefficient = vram_safety_coefficient , convertWeightsFloatTo = transformer_dtype, **kwargs)
     if len(args.gpu) > 0:
         torch.set_default_device(args.gpu)
     transformer_type = model_type
@@ -2820,7 +2919,7 @@ def generate_header(model_type, compile, attention_mode):
     description_container = [""]
     get_model_name(model_type, description_container)
-    model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) or ""
+    model_filename = os.path.basename(get_model_filename(model_type, transformer_quantization, transformer_dtype_policy)) or ""
     description = description_container[0]
     header = f"
{description}
" overridden_attention = get_overridden_attention(model_type) @@ -3292,7 +3391,8 @@ def select_video(state, input_file_list, event_data: gr.EventData): if not all_letters(src, pos): return False if neg is not None and any_letters(src, neg): return False return True - map_video_prompt = {"V" : "Control Video", ("VA", "U") : "Mask Video", "I" : "Reference Images"} + image_outputs = configs.get("image_mode",0) > 0 + map_video_prompt = {"V" : "Control Image" if image_outputs else "Control Video", ("VA", "U") : "Mask Image" if image_outputs else "Mask Video", "I" : "Reference Images"} map_image_prompt = {"V" : "Source Video", "L" : "Last Video", "S" : "Start Image", "E" : "End Image"} map_audio_prompt = {"A" : "Audio Source", "B" : "Audio Source #2"} video_other_prompts = [ v for s,v in map_image_prompt.items() if all_letters(video_image_prompt_type,s)] \ @@ -3343,6 +3443,7 @@ def select_video(state, input_file_list, event_data: gr.EventData): if multiple_submodels: video_guidance_scale += f" + Model Switch at {video_switch_threshold if video_model_switch_phase ==1 else video_switch_threshold2}" video_flow_shift = configs.get("flow_shift", None) + if image_outputs: video_flow_shift = None video_video_guide_outpainting = configs.get("video_guide_outpainting", "") video_outpainting = "" if len(video_video_guide_outpainting) > 0 and not video_video_guide_outpainting.startswith("#") \ @@ -3364,7 +3465,7 @@ def select_video(state, input_file_list, event_data: gr.EventData): if len(video_other_prompts) >0 : values += [video_other_prompts] labels += ["Other Prompts"] - if len(video_outpainting) >0 and any_letters(video_image_prompt_type, "VFK"): + if len(video_outpainting) >0: values += [video_outpainting] labels += ["Outpainting"] video_sample_solver = configs.get("sample_solver", "") @@ -3431,10 +3532,17 @@ def convert_image(image): from PIL import ImageOps from typing import cast + if isinstance(image, str): + image = Image.open(image) image = image.convert('RGB') return cast(Image, ImageOps.exif_transpose(image)) def get_resampled_video(video_in, start_frame, max_frames, target_fps, bridge='torch'): + if isinstance(video_in, str) and has_image_file_extension(video_in): + video_in = Image.open(video_in) + if isinstance(video_in, Image.Image): + return torch.from_numpy(np.array(video_in).astype(np.uint8)).unsqueeze(0) + from shared.utils.utils import resample import decord @@ -3450,6 +3558,48 @@ def get_resampled_video(video_in, start_frame, max_frames, target_fps, bridge='t # print(f"frame nos: {frame_nos}") return frames_list +# def get_resampled_video(video_in, start_frame, max_frames, target_fps): +# from torchvision.io import VideoReader +# import torch +# from shared.utils.utils import resample + +# vr = VideoReader(video_in, "video") +# meta = vr.get_metadata()["video"] + +# fps = round(float(meta["fps"][0])) +# duration_s = float(meta["duration"][0]) +# num_src_frames = int(round(duration_s * fps)) # robust length estimate + +# if max_frames < 0: +# max_frames = max(int(num_src_frames / fps * target_fps + max_frames), 0) + +# frame_nos = resample( +# fps, num_src_frames, +# max_target_frames_count=max_frames, +# target_fps=target_fps, +# start_target_frame=start_frame +# ) +# if len(frame_nos) == 0: +# return torch.empty((0,)) # nothing to return + +# target_ts = [i / fps for i in frame_nos] + +# # Read forward once, grabbing frames when we pass each target timestamp +# frames = [] +# vr.seek(target_ts[0]) +# idx = 0 +# tol = 0.5 / fps # half-frame tolerance +# for frame in vr: +# t = 
float(frame["pts"]) # seconds +# if idx < len(target_ts) and t + tol >= target_ts[idx]: +# frames.append(frame["data"].permute(1,2,0)) # Tensor [H, W, C] +# idx += 1 +# if idx >= len(target_ts): +# break + +# return frames + + def get_preprocessor(process_type, inpaint_color): if process_type=="pose": from preprocessing.dwpose.pose import PoseBodyFaceVideoAnnotator @@ -3514,19 +3664,22 @@ def get_preprocessor(process_type, inpaint_color): def process_images_multithread(image_processor, items, process_type, wrap_in_list = True, max_workers: int = os.cpu_count()/ 2) : if not items: return [] - max_workers = 11 + import concurrent.futures start_time = time.time() # print(f"Preprocessus:{process_type} started") if process_type in ["prephase", "upsample"]: if wrap_in_list : items = [ [img] for img in items] - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)} - results = [None] * len(items) - for future in concurrent.futures.as_completed(futures): - idx = futures[future] - results[idx] = future.result() + if max_workers == 1: + results = [image_processor(img) for img in items] + else: + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)} + results = [None] * len(items) + for future in concurrent.futures.as_completed(futures): + idx = futures[future] + results[idx] = future.result() if wrap_in_list: results = [ img[0] for img in results] @@ -3536,10 +3689,68 @@ def process_images_multithread(image_processor, items, process_type, wrap_in_lis end_time = time.time() # print(f"duration:{end_time-start_time:.1f}") - return results + return results -def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1): - from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions +def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512): + if not input_video_path or max_frames <= 0: + return None, None + pad_frames = 0 + if start_frame < 0: + pad_frames= -start_frame + max_frames += start_frame + start_frame = 0 + + any_mask = input_mask_path != None + video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps) + if len(video) == 0: return None + frame_height, frame_width, _ = video[0].shape + num_frames = len(video) + if any_mask: + mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps) + num_frames = min(num_frames, len(mask_video)) + if num_frames == 0: return None + video = video[:num_frames] + if any_mask: + mask_video = mask_video[:num_frames] + + from preprocessing.face_preprocessor import FaceProcessor + face_processor = FaceProcessor() + + face_list = [] + for frame_idx in range(num_frames): + frame = video[frame_idx].cpu().numpy() + # video[frame_idx] = None + if any_mask: + mask = Image.fromarray(mask_video[frame_idx].cpu().numpy()) + # mask_video[frame_idx] = None + if (frame_width, frame_height) != mask.size: + mask = mask.resize((frame_width, frame_height), 
resample=Image.Resampling.LANCZOS) + mask = np.array(mask) + alpha_mask = np.zeros((frame_height, frame_width, 3), dtype=np.uint8) + alpha_mask[mask > 127] = 1 + frame = frame * alpha_mask + frame = Image.fromarray(frame) + face = face_processor.process(frame, resize_to=size, face_crop_scale = 1) + face_list.append(face) + + face_processor = None + gc.collect() + torch.cuda.empty_cache() + + face_tensor= torch.tensor(np.stack(face_list, dtype= np.float32) / 127.5 - 1).permute(-1, 0, 1, 2 ) # t h w c -> c t h w + if pad_frames > 0: + face_tensor= torch.cat([face_tensor[:, -1:].expand(-1, pad_frames, -1, -1), face_tensor ], dim=2) + + if args.save_masks: + from preprocessing.dwpose.pose import save_one_video + saved_faces_frames = [np.array(face) for face in face_list ] + save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None) + return face_tensor + +def get_default_workers(): + return os.cpu_count()/ 2 + +def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1): def mask_to_xyxy_box(mask): rows, cols = np.where(mask == 255) @@ -3554,7 +3765,13 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, box = [xmin, ymin, xmax, ymax] box = [int(x) for x in box] return box - + inpaint_color = int(inpaint_color) + pad_frames = 0 + if start_frame < 0: + pad_frames= -start_frame + max_frames += start_frame + start_frame = 0 + if not input_video_path or max_frames <= 0: return None, None any_mask = input_mask_path != None @@ -3580,6 +3797,9 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, if len(video) == 0 or any_mask and len(mask_video) == 0: return None, None + if fit_crop and outpainting_dims != None: + fit_crop = False + fit_canvas = 0 if fit_canvas is not None else None frame_height, frame_width, _ = video[0].shape @@ -3594,7 +3814,7 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, if outpainting_dims != None: final_height, final_width = height, width - height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 8) + height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1) if any_mask: num_frames = min(len(video), len(mask_video)) @@ -3611,14 +3831,20 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, # for frame_idx in range(num_frames): def prep_prephase(frame_idx): frame = Image.fromarray(video[frame_idx].cpu().numpy()) #.asnumpy() - frame = frame.resize((width, height), resample=Image.Resampling.LANCZOS) + if fit_crop: + frame = rescale_and_crop(frame, width, height) + else: + frame = frame.resize((width, height), resample=Image.Resampling.LANCZOS) frame = np.array(frame) if any_mask: if any_identity_mask: mask = np.full( (height, width, 3), 0, dtype= np.uint8) else: mask = Image.fromarray(mask_video[frame_idx].cpu().numpy()) #.asnumpy() - mask = mask.resize((width, height), resample=Image.Resampling.LANCZOS) + if fit_crop: + mask = rescale_and_crop(mask, width, height) + else: + mask = mask.resize((width, height), resample=Image.Resampling.LANCZOS) mask = np.array(mask) if 
len(mask.shape) == 3 and mask.shape[2] == 3: @@ -3649,8 +3875,8 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, return (target_frame, frame, mask) else: return (target_frame, None, None) - - proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False) + max_workers = get_default_workers() + proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False, max_workers=max_workers) proc_list, proc_list_outside, proc_mask = [None] * len(proc_lists), [None] * len(proc_lists), [None] * len(proc_lists) for frame_idx, frame_group in enumerate(proc_lists): proc_list[frame_idx], proc_list_outside[frame_idx], proc_mask[frame_idx] = frame_group @@ -3659,11 +3885,11 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, mask_video = None if preproc2 != None: - proc_list2 = process_images_multithread(preproc2, proc_list, process_type2) + proc_list2 = process_images_multithread(preproc2, proc_list, process_type2, max_workers=max_workers) #### to be finished ...or not - proc_list = process_images_multithread(preproc, proc_list, process_type) + proc_list = process_images_multithread(preproc, proc_list, process_type, max_workers=max_workers) if any_mask: - proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask) + proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask, max_workers=max_workers) else: proc_list_outside = proc_mask = len(proc_list) * [None] @@ -3681,7 +3907,7 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, full_frame= torch.full( (final_height, final_width, mask.shape[-1]), 255, dtype= torch.uint8, device= mask.device) full_frame[margin_top:margin_top+height, margin_left:margin_left+width] = mask mask = full_frame - masks.append(mask) + masks.append(mask[:, :, 0:1].clone()) else: masked_frame = processed_img @@ -3701,28 +3927,33 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, proc_list[frame_no] = proc_list_outside[frame_no] = proc_mask[frame_no] = None - if args.save_masks: - from preprocessing.dwpose.pose import save_one_video - saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ] - save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None) - if any_mask: - saved_masks = [mask.cpu().numpy() for mask in masks ] - save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None) + # if args.save_masks: + # from preprocessing.dwpose.pose import save_one_video + # saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ] + # save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None) + # if any_mask: + # saved_masks = [mask.cpu().numpy() for mask in masks ] + # save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None) preproc = None preproc_outside = None gc.collect() torch.cuda.empty_cache() + if pad_frames > 0: + masked_frames = masked_frames[0] * pad_frames + masked_frames + if any_mask: masked_frames = masks[0] * pad_frames + masks + masked_frames = torch.stack(masked_frames).permute(-1,0,1,2).float().div_(127.5).sub_(1.) 
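# --- Minimal illustrative sketch (not part of the patch; names and inputs are assumptions) ---
# It shows the pad-then-stack idiom used just above, assuming `frames` and `masks` are lists of
# H x W x C uint8 torch tensors and `pad_frames` counts the missing leading frames: prepend
# copies of the first element as *list* entries, then stack into c, t, h, w layout and normalize
# frames to [-1, 1] and masks to [0, 1].
import torch

def pad_and_stack(frames, masks=None, pad_frames=0):
    if pad_frames > 0:
        frames = [frames[0]] * pad_frames + frames            # prepend pad_frames copies of the first frame
        if masks is not None:
            masks = [masks[0]] * pad_frames + masks
    frames_t = torch.stack(frames).permute(-1, 0, 1, 2).float().div_(127.5).sub_(1.)   # t h w c -> c t h w, range [-1, 1]
    masks_t = torch.stack(masks).permute(-1, 0, 1, 2).float().div_(255) if masks is not None else None  # range [0, 1]
    return frames_t, masks_t
# --- end sketch ---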
+ masks = torch.stack(masks).permute(-1,0,1,2).float().div_(255) if any_mask else None - return torch.stack(masked_frames), torch.stack(masks) if any_mask else None + return masked_frames, masks -def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_canvas = None, target_fps = 16, block_size = 16): +def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size = 16): frames_list = get_resampled_video(video_in, start_frame, max_frames, target_fps) if len(frames_list) == 0: return None - if fit_canvas == None: + if fit_canvas == None or fit_crop: new_height = height new_width = width else: @@ -3740,7 +3971,10 @@ def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_can processed_frames_list = [] for frame in frames_list: frame = Image.fromarray(np.clip(frame.cpu().numpy(), 0, 255).astype(np.uint8)) - frame = frame.resize((new_width,new_height), resample=Image.Resampling.LANCZOS) + if fit_crop: + frame = rescale_and_crop(frame, new_width, new_height) + else: + frame = frame.resize((new_width,new_height), resample=Image.Resampling.LANCZOS) processed_frames_list.append(frame) np_frames = [np.array(frame) for frame in processed_frames_list] @@ -3839,7 +4073,7 @@ def perform_spatial_upsampling(sample, spatial_upsampling): frames_to_upsample = [sample[:, i] for i in range( sample.shape[1]) ] def upsample_frames(frame): return resize_lanczos(frame, h, w).unsqueeze(1) - sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False), dim=1) + sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False, max_workers=get_default_workers()), dim=1) frames_to_upsample = None return sample @@ -4080,9 +4314,10 @@ def process_prompt_enhancer(prompt_enhancer, original_prompts, image_start, ori prompt_images = [] if "I" in prompt_enhancer: if image_start != None: - prompt_images.append(image_start) + prompt_images += image_start if original_image_refs != None: - prompt_images += original_image_refs[:1] + prompt_images += original_image_refs[:1] + prompt_images = [Image.open(img) if isinstance(img,str) else img for img in prompt_images] if len(original_prompts) == 0 and not "T" in prompt_enhancer: return None else: @@ -4152,6 +4387,7 @@ def enhance_prompt(state, prompt, prompt_enhancer, multi_images_gen_type, overri download_models() gen = get_gen_info(state) + original_process_status = None while True: with gen_lock: process_status = gen.get("process_status", None) @@ -4187,7 +4423,7 @@ def enhance_prompt(state, prompt, prompt_enhancer, multi_images_gen_type, overri original_image_refs = inputs["image_refs"] if original_image_refs is not None: original_image_refs = [ convert_image(tup[0]) for tup in original_image_refs ] - is_image = inputs["image_mode"] == 1 + is_image = inputs["image_mode"] > 0 seed = inputs["seed"] seed = set_seed(seed) enhanced_prompts = [] @@ -4219,6 +4455,9 @@ def enhance_prompt(state, prompt, prompt_enhancer, multi_images_gen_type, overri gr.Info(f'Prompt "{original_prompts[0][:100]}" has been enhanced') return prompt, prompt +def get_outpainting_dims(video_guide_outpainting): + return None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")] + def generate_video( task, send_cmd, @@ -4307,10 +4546,6 @@ 
def generate_video( model_filename, mode, ): - # import os - # os.environ.pop("TORCH_LOGS", None) # make sure no env var is suppressing/overriding - # import torch._logging as tlog - # tlog.set_logs(recompiles=True, guards=True, graph_breaks=True) @@ -4335,7 +4570,7 @@ def generate_video( model_def = get_model_def(model_type) - is_image = image_mode == 1 + is_image = image_mode > 0 if is_image: if min_frames_if_references >= 1000: video_length = min_frames_if_references - 1000 @@ -4346,18 +4581,17 @@ def generate_video( temp_filenames_list = [] if image_guide is not None and isinstance(image_guide, Image.Image): - video_guide = convert_image_to_video(image_guide) - temp_filenames_list.append(video_guide) - image_guide = None + video_guide = image_guide + image_guide = None if image_mask is not None and isinstance(image_mask, Image.Image): - video_mask = convert_image_to_video(image_mask) - temp_filenames_list.append(video_mask) - image_mask = None + video_mask = image_mask + image_mask = None + if model_def.get("no_background_removal", False): remove_background_images_ref = 0 + base_model_type = get_base_model_type(model_type) model_family = get_model_family(base_model_type) - fit_canvas = server_config.get("fit_canvas", 0) model_handler = get_model_handler(base_model_type) block_size = model_handler.get_vae_block_size(base_model_type) if hasattr(model_handler, "get_vae_block_size") else 16 @@ -4383,7 +4617,7 @@ def generate_video( return width, height = resolution.split("x") - width, height = int(width), int(height) + width, height = int(width) // block_size * block_size, int(height) // block_size * block_size default_image_size = (height, width) if slg_switch == 0: @@ -4442,22 +4676,14 @@ def generate_video( current_video_length = video_length # VAE Tiling device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576 - - i2v = test_class_i2v(model_type) - diffusion_forcing = "diffusion_forcing" in model_filename - t2v = base_model_type in ["t2v"] - ltxv = "ltxv" in model_filename - vace = test_vace_module(base_model_type) - hunyuan_t2v = "hunyuan_video_720" in model_filename - hunyuan_i2v = "hunyuan_video_i2v" in model_filename + guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5) + extract_guide_from_window_start = model_def.get("extract_guide_from_window_start", False) hunyuan_custom = "hunyuan_video_custom" in model_filename hunyuan_custom_audio = hunyuan_custom and "audio" in model_filename hunyuan_custom_edit = hunyuan_custom and "edit" in model_filename hunyuan_avatar = "hunyuan_video_avatar" in model_filename fantasy = base_model_type in ["fantasy"] multitalk = model_def.get("multitalk_class", False) - standin = model_def.get("standin_class", False) - infinitetalk = base_model_type in ["infinitetalk"] if "B" in audio_prompt_type or "X" in audio_prompt_type: from models.wan.multitalk.multitalk import parse_speakers_locations @@ -4486,50 +4712,34 @@ def generate_video( if test_any_sliding_window(model_type) : if video_source is not None: - current_video_length += sliding_window_overlap + current_video_length += sliding_window_overlap - 1 sliding_window = current_video_length > sliding_window_size reuse_frames = min(sliding_window_size - 4, sliding_window_overlap) else: sliding_window = False + sliding_window_size = current_video_length reuse_frames = 0 - _, latent_size = get_model_min_frames_and_step(model_type) - if diffusion_forcing: latent_size = 4 + _, _, latent_size = get_model_min_frames_and_step(model_type) original_image_refs = image_refs 
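# --- Rough sketch of the sliding-window arithmetic used above (an assumption for illustration,
# not the actual compute_sliding_window_no implementation): the first window yields
# `size - discard_last` usable frames, and every following window reuses `reuse` frames from the
# previous one and drops `discard_last` frames at its end.
import math

def estimate_window_count(total_frames, size, discard_last, reuse):
    first_yield = size - discard_last
    if total_frames <= first_yield:
        return 1
    per_window_yield = size - reuse - discard_last
    return 1 + math.ceil((total_frames - first_yield) / per_window_yield)

# e.g. with hypothetical numbers, 157 requested frames, an 81-frame window, 8 reused frames and
# no discarded frames gives estimate_window_count(157, 81, 0, 8) == 3 windows.
# --- end sketch ---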
+ image_refs = None if image_refs is None else [] + image_refs # work on a copy as it is going to be modified # image_refs = None # nb_frames_positions= 0 - frames_to_inject = [] - any_background_ref = False - outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")] # Output Video Ratio Priorities: # Source Video or Start Image > Control Video > Image Ref (background or positioned frames only) > UI Width, Height # Image Ref (non background and non positioned frames) are boxed in a white canvas in order to keep their own width/height ratio + frames_to_inject = [] + any_background_ref = 0 + if "K" in video_prompt_type: + any_background_ref = 2 if model_def.get("all_image_refs_are_background_ref", False) else 1 + + outpainting_dims = get_outpainting_dims(video_guide_outpainting) + fit_canvas = server_config.get("fit_canvas", 0) + fit_crop = fit_canvas == 2 + if fit_crop and outpainting_dims is not None: + fit_crop = False + fit_canvas = 0 - if image_refs is not None and len(image_refs) > 0: - frames_positions_list = [ int(pos)-1 for pos in frames_positions.split(" ")] if frames_positions is not None and len(frames_positions)> 0 else [] - frames_positions_list = frames_positions_list[:len(image_refs)] - nb_frames_positions = len(frames_positions_list) - if nb_frames_positions > 0: - frames_to_inject = [None] * (max(frames_positions_list) + 1) - for i, pos in enumerate(frames_positions_list): - frames_to_inject[pos] = image_refs[i] - if video_guide == None and video_source == None and not "L" in image_prompt_type and (nb_frames_positions > 0 or "K" in video_prompt_type) : - from shared.utils.utils import get_outpainting_full_area_dimensions - w, h = image_refs[0].size - if outpainting_dims != None: - h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims) - default_image_size = calculate_new_dimensions(height, width, h, w, fit_canvas) - fit_canvas = None - # if there is a source video and a background image ref, the height/width ratio will need to be processed later by the code for the model (we dont know the source video dimensions at this point) - if len(image_refs) > nb_frames_positions: - any_background_ref = "K" in video_prompt_type - if remove_background_images_ref > 0: - send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) - os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg") - from shared.utils.utils import resize_and_remove_background - image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (any_background_ref or vace or standin) ) # no fit for vace ref images as it is done later - update_task_thumbnails(task, locals()) - send_cmd("output") joint_pass = boost ==1 #and profile != 1 and profile != 3 skip_steps_cache = None if len(skip_steps_cache_type) == 0 else DynamicClass(cache_type = skip_steps_cache_type) @@ -4554,29 +4764,47 @@ def generate_video( output_new_audio_data = None output_new_audio_filepath = None original_audio_guide = audio_guide + original_audio_guide2 = audio_guide2 audio_proj_split = None audio_proj_full = None audio_scale = None audio_context_lens = None if (fantasy or multitalk or hunyuan_avatar or hunyuan_custom_audio) and audio_guide != None: from models.wan.fantasytalking.infer import 
parse_audio + from preprocessing.extract_vocals import get_vocals import librosa duration = librosa.get_duration(path=audio_guide) combination_type = "add" + clean_audio_files = "V" in audio_prompt_type if audio_guide2 is not None: duration2 = librosa.get_duration(path=audio_guide2) if "C" in audio_prompt_type: duration += duration2 else: duration = min(duration, duration2) combination_type = "para" if "P" in audio_prompt_type else "add" + if clean_audio_files: + audio_guide = get_vocals(original_audio_guide, get_available_filename(save_path, audio_guide, "_clean", ".wav")) + audio_guide2 = get_vocals(original_audio_guide2, get_available_filename(save_path, audio_guide2, "_clean2", ".wav")) + temp_filenames_list += [audio_guide, audio_guide2] else: if "X" in audio_prompt_type: + # dual speaker, voice separation from preprocessing.speakers_separator import extract_dual_audio combination_type = "para" if args.save_speakers: audio_guide, audio_guide2 = "speaker1.wav", "speaker2.wav" else: audio_guide, audio_guide2 = get_available_filename(save_path, audio_guide, "_tmp1", ".wav"), get_available_filename(save_path, audio_guide, "_tmp2", ".wav") - extract_dual_audio(original_audio_guide, audio_guide, audio_guide2 ) + temp_filenames_list += [audio_guide, audio_guide2] + if clean_audio_files: + clean_audio_guide = get_vocals(original_audio_guide, get_available_filename(save_path, original_audio_guide, "_clean", ".wav")) + temp_filenames_list += [clean_audio_guide] + extract_dual_audio(clean_audio_guide if clean_audio_files else original_audio_guide, audio_guide, audio_guide2) + + elif clean_audio_files: + # Single Speaker + audio_guide = get_vocals(original_audio_guide, get_available_filename(save_path, audio_guide, "_clean", ".wav")) + temp_filenames_list += [audio_guide] + output_new_audio_filepath = original_audio_guide current_video_length = min(int(fps * duration //latent_size) * latent_size + latent_size + 1, current_video_length) @@ -4588,10 +4816,10 @@ def generate_video( # pad audio_proj_full if aligned to beginning of window to simulate source window overlap min_audio_duration = current_video_length/fps if reset_control_aligment else video_source_duration + current_video_length/fps audio_proj_full, output_new_audio_data = get_full_audio_embeddings(audio_guide1 = audio_guide, audio_guide2= audio_guide2, combination_type= combination_type , num_frames= max_source_video_frames, sr= audio_sampling_rate, fps =fps, padded_frames_for_embeddings = (reuse_frames if reset_control_aligment else 0), min_audio_duration = min_audio_duration) - if output_new_audio_data is not None: output_new_audio_filepath= None # need to build original speaker track if it changed size (due to padding at the end) or if it has been combined - if not args.save_speakers and "X" in audio_prompt_type: - os.remove(audio_guide) - os.remove(audio_guide2) + if output_new_audio_data is not None: # not none if modified + if clean_audio_files: # need to rebuild the sum of audios with original audio + _, output_new_audio_data = get_full_audio_embeddings(audio_guide1 = original_audio_guide, audio_guide2= original_audio_guide2, combination_type= combination_type , num_frames= max_source_video_frames, sr= audio_sampling_rate, fps =fps, padded_frames_for_embeddings = (reuse_frames if reset_control_aligment else 0), min_audio_duration = min_audio_duration, return_sum_only= True) + output_new_audio_filepath= None # need to build original speaker track if it changed size (due to padding at the end) or if it has been combined if 
hunyuan_custom_edit and video_guide != None: import cv2 @@ -4599,6 +4827,7 @@ def generate_video( length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) current_video_length = min(current_video_length, length) + seed = set_seed(seed) torch.set_grad_enabled(False) @@ -4615,9 +4844,9 @@ def generate_video( repeat_no = 0 extra_generation = 0 initial_total_windows = 0 - discard_last_frames = sliding_window_discard_last_frames default_requested_frames_to_generate = current_video_length + nb_frames_positions = 0 if sliding_window: initial_total_windows= compute_sliding_window_no(default_requested_frames_to_generate, sliding_window_size, discard_last_frames, reuse_frames) current_video_length = sliding_window_size @@ -4636,7 +4865,7 @@ def generate_video( if repeat_no >= total_generation: break repeat_no +=1 gen["repeat_no"] = repeat_no - src_video, src_mask, src_ref_images = None, None, None + src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = sparse_video_image = None prefix_video = pre_video_frame = None source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before ) @@ -4669,8 +4898,7 @@ def generate_video( while not abort: enable_RIFLEx = RIFLEx_setting == 0 and current_video_length > (6* get_model_fps(base_model_type)+1) or RIFLEx_setting == 1 - if sliding_window: - prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1] + prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1] new_extra_windows = gen.get("extra_windows",0) gen["extra_windows"] = 0 extra_windows += new_extra_windows @@ -4691,25 +4919,22 @@ def generate_video( return_latent_slice = None if reuse_frames > 0: return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) ) - refresh_preview = {"image_guide" : None, "image_mask" : None} + refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {} - src_ref_images = image_refs image_start_tensor = image_end_tensor = None if window_no == 1 and (video_source is not None or image_start is not None): if image_start is not None: - new_height, new_width = calculate_new_dimensions(height, width, image_start.height, image_start.width, sample_fit_canvas, block_size = block_size) - image_start_tensor = image_start.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + image_start_tensor, new_height, new_width = calculate_dimensions_and_resize_image(image_start, height, width, sample_fit_canvas, fit_crop, block_size = block_size) + if fit_crop: refresh_preview["image_start"] = image_start_tensor image_start_tensor = convert_image_to_tensor(image_start_tensor) pre_video_guide = prefix_video = image_start_tensor.unsqueeze(1) - if image_end is not None: - image_end_tensor = image_end.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) - image_end_tensor = convert_image_to_tensor(image_end_tensor) else: - if "L" in image_prompt_type: - refresh_preview["video_source"] = get_video_frame(video_source, 0) - prefix_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, target_fps = fps, block_size = block_size ) + prefix_video = preprocess_video(width=width, 
height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size = block_size ) prefix_video = prefix_video.permute(3, 0, 1, 2) prefix_video = prefix_video.float().div_(127.5).sub_(1.) # c, f, h, w + if fit_crop or "L" in image_prompt_type: refresh_preview["video_source"] = convert_tensor_to_image(prefix_video, 0) + + new_height, new_width = prefix_video.shape[-2:] pre_video_guide = prefix_video[:, -reuse_frames:] pre_video_frame = convert_tensor_to_image(prefix_video[:, -1]) source_video_overlap_frames_count = pre_video_guide.shape[1] @@ -4718,7 +4943,15 @@ def generate_video( image_size = pre_video_guide.shape[-2:] sample_fit_canvas = None guide_start_frame = prefix_video.shape[1] - + if image_end is not None: + image_end_list= image_end if isinstance(image_end, list) else [image_end] + if len(image_end_list) >= window_no: + new_height, new_width = image_size + image_end_tensor, _, _ = calculate_dimensions_and_resize_image(image_end_list[window_no-1], new_height, new_width, sample_fit_canvas, fit_crop, block_size = block_size) + # image_end_tensor =image_end_list[window_no-1].resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + refresh_preview["image_end"] = image_end_tensor + image_end_tensor = convert_image_to_tensor(image_end_tensor) + image_end_list= None window_start_frame = guide_start_frame - (reuse_frames if window_no > 1 else source_video_overlap_frames_count) guide_end_frame = guide_start_frame + current_video_length - (source_video_overlap_frames_count if window_no == 1 else reuse_frames) alignment_shift = source_video_frames_count if reset_control_aligment else 0 @@ -4732,118 +4965,168 @@ def generate_video( # special treatment for start frame pos when alignement to first frame requested as otherwise the start frame number will be negative due to overlapped frames (has been previously compensated later with padding) audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length) + if repeat_no == 1 and window_no == 1 and image_refs is not None and len(image_refs) > 0: + frames_positions_list = [] + if frames_positions is not None and len(frames_positions)> 0: + positions = frames_positions.replace(","," ").split(" ") + cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count) + last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count + joker_used = False + project_window_no = 1 + for pos in positions : + if len(pos) > 0: + if pos in ["L", "l"]: + cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length + if cur_end_pos >= last_frame_no-1 and not joker_used: + joker_used = True + cur_end_pos = last_frame_no -1 + project_window_no += 1 + frames_positions_list.append(cur_end_pos) + cur_end_pos -= sliding_window_discard_last_frames + reuse_frames + else: + frames_positions_list.append(int(pos)-1 + alignment_shift) + frames_positions_list = frames_positions_list[:len(image_refs)] + nb_frames_positions = len(frames_positions_list) + if nb_frames_positions > 0: + frames_to_inject = [None] * (max(frames_positions_list) + 1) + for i, pos in enumerate(frames_positions_list): + frames_to_inject[pos] = image_refs[i] + + + video_guide_processed = video_mask_processed = video_guide_processed2 = 
video_mask_processed2 = sparse_video_image = None if video_guide is not None: - keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate) + keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate) if len(error) > 0: raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") - keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ] - if infinitetalk and video_guide is not None: - src_image = get_video_frame(video_guide, aligned_guide_start_frame-1, return_last_if_missing = True, return_PIL = True) - new_height, new_width = calculate_new_dimensions(image_size[0], image_size[1], src_image.height, src_image.width, sample_fit_canvas, block_size = block_size) - src_image = src_image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) - refresh_preview["video_guide"] = src_image - src_video = convert_image_to_tensor(src_image).unsqueeze(1) - if sample_fit_canvas != None: - image_size = src_video.shape[-2:] - sample_fit_canvas = None - if ltxv and video_guide is not None: - preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw") - status_info = "Extracting " + processes_names[preprocess_type] - send_cmd("progress", [0, get_latest_status(state, status_info)]) - # start one frame ealier to facilitate latents merging later - src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size ) - if src_video != None: - src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ] - refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) - src_video = src_video.permute(3, 0, 1, 2) - src_video = src_video.float().div_(127.5).sub_(1.) 
# c, f, h, w - if sample_fit_canvas != None: - image_size = src_video.shape[-2:] - sample_fit_canvas = None + guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame + extra_control_frames = model_def.get("extra_control_frames", 0) + if extra_control_frames > 0 and aligned_guide_start_frame >= extra_control_frames: guide_frames_extract_start -= extra_control_frames + + keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start <0 else [] + keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ] + guide_frames_extract_count = len(keep_frames_parsed) - if t2v and "G" in video_prompt_type: - video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, target_fps = fps) - if video_guide_processed == None: - src_video = pre_video_guide - else: - if sample_fit_canvas != None: - image_size = video_guide_processed.shape[-3: -1] - sample_fit_canvas = None - src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2) - if pre_video_guide != None: - src_video = torch.cat( [pre_video_guide, src_video], dim=1) + # Extract Faces to video + if "B" in video_prompt_type: + send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")]) + src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps) + if src_faces is not None and src_faces.shape[1] < current_video_length: + src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1) - if vace : - image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications - context_scale = [ control_net_weight] - video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None - if "V" in video_prompt_type: - process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None) - preprocess_type, preprocess_type2 = "raw", None - for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PDSLCMU")): - if process_num == 0: - preprocess_type = process_map_video_guide.get(process_letter, "raw") - else: - preprocess_type2 = process_map_video_guide.get(process_letter, None) - status_info = "Extracting " + processes_names[preprocess_type] - extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask]) - if len(extra_process_list) == 1: - status_info += " and " + processes_names[extra_process_list[0]] - elif len(extra_process_list) == 2: - status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]] - if preprocess_type2 is not None: - context_scale = [ control_net_weight /2, control_net_weight2 /2] - send_cmd("progress", [0, get_latest_status(state, status_info)]) - video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= 
len(keep_frames_parsed) , start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1 ) - if preprocess_type2 != None: - video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 ) + # Sparse Video to Video + sparse_video_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, target_fps = fps, return_PIL = True) if "R" in video_prompt_type else None - if video_guide_processed != None: - if sample_fit_canvas != None: - image_size = video_guide_processed.shape[-3: -1] - sample_fit_canvas = None - refresh_preview["video_guide"] = Image.fromarray(video_guide_processed[0].cpu().numpy()) - if video_guide_processed2 != None: - refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())] - if video_mask_processed != None: - refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy()) - frames_to_inject_parsed = frames_to_inject[aligned_guide_start_frame: aligned_guide_end_frame] - - src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2], - [video_mask_processed] if video_guide_processed2 == None else [video_mask_processed, video_mask_processed2], - [image_refs_copy] if video_guide_processed2 == None else [image_refs_copy, image_refs_copy], - current_video_length, image_size = image_size, device ="cpu", - keep_video_guide_frames=keep_frames_parsed, - start_frame = aligned_guide_start_frame, - pre_src_video = [pre_video_guide] if video_guide_processed2 == None else [pre_video_guide, pre_video_guide], - fit_into_canvas = sample_fit_canvas, - inject_frames= frames_to_inject_parsed, - outpainting_dims = outpainting_dims, - any_background_ref = any_background_ref - ) - if len(frames_to_inject_parsed) or any_background_ref: - new_image_refs = [convert_tensor_to_image(src_video[0], frame_no) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] - if any_background_ref: - new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:] + # Generic Video Preprocessing + process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None) + preprocess_type, preprocess_type2 = "raw", None + for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")): + if process_num == 0: + preprocess_type = process_map_video_guide.get(process_letter, "raw") else: - new_image_refs += image_refs[nb_frames_positions:] - refresh_preview["image_refs"] = new_image_refs - new_image_refs = None + preprocess_type2 = process_map_video_guide.get(process_letter, None) + status_info = "Extracting " + processes_names[preprocess_type] + extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if 
process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask]) + if len(extra_process_list) == 1: + status_info += " and " + processes_names[extra_process_list[0]] + elif len(extra_process_list) == 2: + status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]] + context_scale = [control_net_weight /2, control_net_weight2 /2] if preprocess_type2 is not None else [control_net_weight] + if not (preprocess_type == "identity" and preprocess_type2 is None and video_mask is None):send_cmd("progress", [0, get_latest_status(state, status_info)]) + inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask == "inpaint" else guide_inpaint_color + video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide if sparse_video_image is None else sparse_video_image, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color, block_size = block_size ) + if preprocess_type2 != None: + video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2, block_size = block_size ) - if sample_fit_canvas != None: - image_size = src_video[0].shape[-2:] + if video_guide_processed is not None and sample_fit_canvas is not None: + image_size = video_guide_processed.shape[-2:] sample_fit_canvas = None - elif hunyuan_custom_edit: - if "P" in video_prompt_type: - progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")] - else: - progress_args = [0, get_latest_status(state,"Extracting Video and Mask")] - send_cmd("progress", progress_args) - src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0) - refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) - if src_mask != None: - refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy()) + if window_no == 1 and image_refs is not None and len(image_refs) > 0: + if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) : + from shared.utils.utils import get_outpainting_full_area_dimensions + w, h = image_refs[0].size + if outpainting_dims != None: + h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims) + image_size = calculate_new_dimensions(height, width, h, w, fit_canvas) + sample_fit_canvas = None + if repeat_no == 1: + if fit_crop: + if any_background_ref == 2: + 
end_ref_position = len(image_refs) + elif any_background_ref == 1: + end_ref_position = nb_frames_positions + 1 + else: + end_ref_position = nb_frames_positions + for i, img in enumerate(image_refs[:end_ref_position]): + image_refs[i] = rescale_and_crop(img, default_image_size[1], default_image_size[0]) + refresh_preview["image_refs"] = image_refs + + if len(image_refs) > nb_frames_positions: + src_ref_images = image_refs[nb_frames_positions:] + if remove_background_images_ref > 0: + send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) + + src_ref_images, src_ref_masks = resize_and_remove_background(src_ref_images , image_size[1], image_size[0], + remove_background_images_ref > 0, any_background_ref, + fit_into_canvas= model_def.get("fit_into_canvas_image_refs", 1), + block_size=block_size, + outpainting_dims =outpainting_dims, + background_ref_outpainted = model_def.get("background_ref_outpainted", True), + return_tensor= model_def.get("return_image_refs_tensor", False) ) + + + frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame] + if video_guide is not None or len(frames_to_inject_parsed) > 0 or model_def.get("forced_guide_mask_inputs", False): + any_mask = video_mask is not None or model_def.get("forced_guide_mask_inputs", False) + any_guide_padding = model_def.get("pad_guide_video", False) + from shared.utils.utils import prepare_video_guide_and_mask + src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed] + ([] if video_guide_processed2 is None else [video_guide_processed2]), + [video_mask_processed] + ([] if video_guide_processed2 is None else [video_mask_processed2]), + None if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else pre_video_guide, + image_size, current_video_length, latent_size, + any_mask, any_guide_padding, guide_inpaint_color, + keep_frames_parsed, frames_to_inject_parsed , outpainting_dims) + video_guide_processed = video_guide_processed2 = video_mask_processed = video_mask_processed2 = None + if len(src_videos) == 1: + src_video, src_video2, src_mask, src_mask2 = src_videos[0], None, src_masks[0], None + else: + src_video, src_video2 = src_videos + src_mask, src_mask2 = src_masks + src_videos = src_masks = None + if src_video is None: + abort = True + break + if src_faces is not None: + if src_faces.shape[1] < src_video.shape[1]: + src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1) + else: + src_faces = src_faces[:, :src_video.shape[1]] + if video_guide is not None or len(frames_to_inject_parsed) > 0: + if args.save_masks: + if src_video is not None: + save_video( src_video, "masked_frames.mp4", fps) + if any_mask: save_video( src_mask, "masks.mp4", fps, value_range=(0, 1)) + if src_video2 is not None: + save_video( src_video2, "masked_frames2.mp4", fps) + if any_mask: save_video( src_mask2, "masks2.mp4", fps, value_range=(0, 1)) + if video_guide is not None: + preview_frame_no = 0 if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else (guide_start_frame - window_start_frame) + refresh_preview["video_guide"] = convert_tensor_to_image(src_video, preview_frame_no) + if src_video2 is not None: + refresh_preview["video_guide"] = [refresh_preview["video_guide"], convert_tensor_to_image(src_video2, preview_frame_no)] + if 
src_mask is not None and video_mask is not None: + refresh_preview["video_mask"] = convert_tensor_to_image(src_mask, preview_frame_no, mask_levels = True) + + if src_ref_images is not None or nb_frames_positions: + if len(frames_to_inject_parsed): + new_image_refs = [convert_tensor_to_image(src_video, frame_no + (0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame)) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] + else: + new_image_refs = [] + if src_ref_images is not None: + new_image_refs += [convert_tensor_to_image(img) if torch.is_tensor(img) else img for img in src_ref_images ] + refresh_preview["image_refs"] = new_image_refs + new_image_refs = None + if len(refresh_preview) > 0: new_inputs= locals() new_inputs.update(refresh_preview) @@ -4880,17 +5163,21 @@ def generate_video( input_prompt = prompt, image_start = image_start_tensor, image_end = image_end_tensor, - input_frames = src_video, + input_frames = src_video, + input_frames2 = src_video2, input_ref_images= src_ref_images, + input_ref_masks = src_ref_masks, input_masks = src_mask, + input_masks2 = src_mask2, input_video= pre_video_guide, + input_faces = src_faces, denoising_strength=denoising_strength, prefix_frames_count = source_video_overlap_frames_count if window_no <= 1 else reuse_frames, frame_num= (current_video_length // latent_size)* latent_size + 1, batch_size = batch_size, - height = height, - width = width, - fit_into_canvas = fit_canvas == 1, + height = image_size[0], + width = image_size[1], + fit_into_canvas = fit_canvas, shift=flow_shift, sample_solver=sample_solver, sampling_steps=num_inference_steps, @@ -4947,6 +5234,7 @@ def generate_video( pre_video_frame = pre_video_frame, original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [], image_refs_relative_size = image_refs_relative_size, + outpainting_dims = outpainting_dims, ) except Exception as e: if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0: @@ -5014,6 +5302,7 @@ def generate_video( send_cmd("output") else: sample = samples.cpu() + abort = not is_image and sample.shape[1] < current_video_length # if True: # for testing # torch.save(sample, "output.pt") # else: @@ -5271,7 +5560,7 @@ def process_tasks(state): while True: with gen_lock: process_status = gen.get("process_status", None) - if process_status is None: + if process_status is None or process_status == "process:main": gen["process_status"] = "process:main" break time.sleep(1) @@ -5888,6 +6177,8 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None if target == "settings": return inputs + image_outputs = inputs.get("image_mode",0) > 0 + pop=[] if "force_fps" in inputs and len(inputs["force_fps"])== 0: pop += ["force_fps"] @@ -5902,13 +6193,13 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None pop += ["MMAudio_setting", "MMAudio_prompt", "MMAudio_neg_prompt"] video_prompt_type = inputs["video_prompt_type"] - if not base_model_type in ["t2v"]: + if not "G" in video_prompt_type: pop += ["denoising_strength"] if not (server_config.get("enhancer_enabled", 0) > 0 and server_config.get("enhancer_mode", 0) == 0): pop += ["prompt_enhancer"] - if not recammaster and not diffusion_forcing and not flux: + if model_def.get("model_modes", None) is None: pop += ["model_mode"] if not vace and not phantom and not hunyuan_video_custom: @@ -5922,8 +6213,11 @@ def prepare_inputs_dict(target, inputs, model_type = 
None, model_filename = None pop += ["image_refs_relative_size"] if not vace: - pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", "control_net_weight2"] + pop += ["frames_positions", "control_net_weight", "control_net_weight2"] + if model_def.get("video_guide_outpainting", None) is None: + pop += ["video_guide_outpainting"] + if not (vace or t2v): pop += ["min_frames_if_references"] @@ -5953,7 +6247,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None if guidance_max_phases < 3 or guidance_phases < 3: pop += ["guidance3_scale", "switch_threshold2", "model_switch_phase"] - if ltxv: + if ltxv or image_outputs: pop += ["flow_shift"] if model_def.get("no_negative_prompt", False) : @@ -6012,10 +6306,16 @@ def video_to_source_video(state, input_file_list, choice): def image_to_ref_image_add(state, input_file_list, choice, target, target_name): file_list, file_settings_list = get_file_list(state, input_file_list) if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list): return gr.update() - gr.Info(f"Selected Image was added to {target_name}") - if target == None: - target =[] - target.append( file_list[choice]) + model_type = state["model_type"] + model_def = get_model_def(model_type) + if model_def.get("one_image_ref_needed", False): + gr.Info(f"Selected Image was set to {target_name}") + target =[file_list[choice]] + else: + gr.Info(f"Selected Image was added to {target_name}") + if target == None: + target =[] + target.append( file_list[choice]) return target def image_to_ref_image_set(state, input_file_list, choice, target, target_name): @@ -6024,6 +6324,18 @@ def image_to_ref_image_set(state, input_file_list, choice, target, target_name): gr.Info(f"Selected Image was copied to {target_name}") return file_list[choice] +def image_to_ref_image_guide(state, input_file_list, choice): + file_list, file_settings_list = get_file_list(state, input_file_list) + if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list): return gr.update(), gr.update() + ui_settings = get_current_model_settings(state) + gr.Info(f"Selected Image was copied to Control Image") + new_image = file_list[choice] + if ui_settings["image_mode"]==2 or True: + return new_image, new_image + else: + return new_image, None + + def apply_post_processing(state, input_file_list, choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation): gen = get_gen_info(state) @@ -6090,14 +6402,6 @@ def eject_video_from_gallery(state, input_file_list, choice): choice = min(choice, len(file_list)) return gr.Gallery(value = file_list, selected_index= choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0) -def has_video_file_extension(filename): - extension = os.path.splitext(filename)[-1] - return extension in [".mp4"] - -def has_image_file_extension(filename): - extension = os.path.splitext(filename)[-1] - return extension in [".jpeg", ".jpg", ".png", ".webp", ".bmp", ".tiff"] - def add_videos_to_gallery(state, input_file_list, choice, files_to_load): gen = get_gen_info(state) if files_to_load == None: @@ -6228,6 +6532,7 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw if not "WanGP" in configs.get("type", ""): configs = None except: configs = None + if configs is None: return None, False current_model_filename = state["model_filename"] @@ -6257,7 +6562,7 @@ def get_settings_from_file(state, file_path, 
allow_json, merge_with_defaults, sw return configs, any_image_or_video def record_image_mode_tab(state, evt:gr.SelectData): - state["image_mode_tab"] = 0 if evt.index ==0 else 1 + state["image_mode_tab"] = evt.index def switch_image_mode(state): image_mode = state.get("image_mode_tab", 0) @@ -6265,7 +6570,18 @@ def switch_image_mode(state): ui_defaults = get_model_settings(state, model_type) ui_defaults["image_mode"] = image_mode - + video_prompt_type = ui_defaults.get("video_prompt_type", "") + model_def = get_model_def( model_type) + inpaint_support = model_def.get("inpaint_support", False) + if inpaint_support: + if image_mode == 1: + video_prompt_type = del_in_sequence(video_prompt_type, "VAG" + all_guide_processes) + video_prompt_type = add_to_sequence(video_prompt_type, "KI") + elif image_mode == 2: + video_prompt_type = del_in_sequence(video_prompt_type, "KI" + all_guide_processes) + video_prompt_type = add_to_sequence(video_prompt_type, "VAG") + ui_defaults["video_prompt_type"] = video_prompt_type + return str(time.time()) def load_settings_from_file(state, file_path): @@ -6298,6 +6614,7 @@ def load_settings_from_file(state, file_path): def save_inputs( target, + image_mask_guide, lset_name, image_mode, prompt, @@ -6383,13 +6700,19 @@ def save_inputs( state, ): - - # if state.get("validate_success",0) != 1: - # return + model_filename = state["model_filename"] model_type = state["model_type"] + if image_mask_guide is not None and image_mode >= 1 and video_prompt_type is not None and "A" in video_prompt_type and not "U" in video_prompt_type: + # if image_mask_guide is not None and image_mode == 2: + if "background" in image_mask_guide: + image_guide = image_mask_guide["background"] + if "layers" in image_mask_guide and len(image_mask_guide["layers"])>0: + image_mask = image_mask_guide["layers"][0] + image_mask_guide = None inputs = get_function_arguments(save_inputs, locals()) inputs.pop("target") + inputs.pop("image_mask_guide") cleaned_inputs = prepare_inputs_dict(target, inputs) if target == "settings": defaults_filename = get_settings_file_name(model_type) @@ -6493,11 +6816,16 @@ def change_model(state, model_choice): return header -def fill_inputs(state): +def get_current_model_settings(state): model_type = state["model_type"] - ui_defaults = get_model_settings(state, model_type) + ui_defaults = get_model_settings(state, model_type) if ui_defaults == None: ui_defaults = get_default_settings(model_type) + set_model_settings(state, model_type, ui_defaults) + return ui_defaults + +def fill_inputs(state): + ui_defaults = get_current_model_settings(state) return generate_video_tab(update_form = True, state_dict = state, ui_defaults = ui_defaults) @@ -6535,11 +6863,13 @@ def any_letters(source_str, letters): return True return False -def filter_letters(source_str, letters): +def filter_letters(source_str, letters, default= ""): ret = "" for letter in letters: if letter in source_str: ret += letter + if len(ret) == 0: + return default return ret def add_to_sequence(source_str, letters): @@ -6561,16 +6891,34 @@ def refresh_audio_prompt_type_remux(state, audio_prompt_type, remux): audio_prompt_type = add_to_sequence(audio_prompt_type, remux) return audio_prompt_type +def refresh_remove_background_sound(state, audio_prompt_type, remove_background_sound): + audio_prompt_type = del_in_sequence(audio_prompt_type, "V") + if remove_background_sound: + audio_prompt_type = add_to_sequence(audio_prompt_type, "V") + return audio_prompt_type + + def refresh_audio_prompt_type_sources(state, 
audio_prompt_type, audio_prompt_type_sources): audio_prompt_type = del_in_sequence(audio_prompt_type, "XCPAB") audio_prompt_type = add_to_sequence(audio_prompt_type, audio_prompt_type_sources) - return audio_prompt_type, gr.update(visible = "A" in audio_prompt_type), gr.update(visible = "B" in audio_prompt_type), gr.update(visible = ("B" in audio_prompt_type or "X" in audio_prompt_type)) + return audio_prompt_type, gr.update(visible = "A" in audio_prompt_type), gr.update(visible = "B" in audio_prompt_type), gr.update(visible = ("B" in audio_prompt_type or "X" in audio_prompt_type)), gr.update(visible= any_letters(audio_prompt_type, "ABX")) -def refresh_image_prompt_type(state, image_prompt_type): - any_video_source = len(filter_letters(image_prompt_type, "VLG"))>0 - return gr.update(visible = "S" in image_prompt_type ), gr.update(visible = "E" in image_prompt_type ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = any_video_source) +def refresh_image_prompt_type_radio(state, image_prompt_type, image_prompt_type_radio): + image_prompt_type = del_in_sequence(image_prompt_type, "VLTS") + image_prompt_type = add_to_sequence(image_prompt_type, image_prompt_type_radio) + any_video_source = len(filter_letters(image_prompt_type, "VL"))>0 + model_def = get_model_def(state["model_type"]) + image_prompt_types_allowed = model_def.get("image_prompt_types_allowed", "") + end_visible = "E" in image_prompt_types_allowed and any_letters(image_prompt_type, "SVL") + return image_prompt_type, gr.update(visible = "S" in image_prompt_type ), gr.update(visible = end_visible and ("E" in image_prompt_type) ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = any_video_source), gr.update(visible = end_visible) -def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_type_image_refs): +def refresh_image_prompt_type_endcheckbox(state, image_prompt_type, image_prompt_type_radio, end_checkbox): + image_prompt_type = del_in_sequence(image_prompt_type, "E") + if end_checkbox: image_prompt_type += "E" + image_prompt_type = add_to_sequence(image_prompt_type, image_prompt_type_radio) + return image_prompt_type, gr.update(visible = "E" in image_prompt_type ) + +def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_type_image_refs, image_mode): model_type = state["model_type"] model_def = get_model_def(model_type) image_ref_choices = model_def.get("image_ref_choices", None) @@ -6580,49 +6928,97 @@ def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_ video_prompt_type = del_in_sequence(video_prompt_type, "KFI") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_image_refs) visible = "I" in video_prompt_type - vace= test_vace_module(state["model_type"]) - + any_outpainting= image_mode in model_def.get("video_guide_outpainting", []) rm_bg_visible= visible and not model_def.get("no_background_removal", False) img_rel_size_visible = visible and model_def.get("any_image_refs_relative_size", False) - return video_prompt_type, gr.update(visible = visible),gr.update(visible = rm_bg_visible), gr.update(visible = img_rel_size_visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "K" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace ) + return video_prompt_type, gr.update(visible = visible),gr.update(visible = rm_bg_visible), gr.update(visible = img_rel_size_visible), gr.update(visible = visible and 
"F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "K" in video_prompt_type_image_refs or "V" in video_prompt_type) and any_outpainting ) -def refresh_video_prompt_type_video_mask(state, video_prompt_type, video_prompt_type_video_mask, image_mode): +def switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): + if image_mode == 0: return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) + mask_in_old = "A" in old_video_prompt_type and not "U" in old_video_prompt_type + mask_in_new = "A" in video_prompt_type and not "U" in video_prompt_type + image_mask_guide_value, image_mask_value, image_guide_value = {}, {}, {} + visible = "V" in video_prompt_type + if mask_in_old != mask_in_new: + if mask_in_new: + if old_image_mask_value is None: + image_mask_guide_value["value"] = old_image_guide_value + else: + image_mask_guide_value["value"] = {"background" : old_image_guide_value, "composite" : None, "layers": [rgb_bw_to_rgba_mask(old_image_mask_value)]} + image_guide_value["value"] = image_mask_value["value"] = None + else: + if old_image_mask_guide_value is not None and "background" in old_image_mask_guide_value: + image_guide_value["value"] = old_image_mask_guide_value["background"] + if "layers" in old_image_mask_guide_value: + image_mask_value["value"] = old_image_mask_guide_value["layers"][0] if len(old_image_mask_guide_value["layers"]) >=1 else None + image_mask_guide_value["value"] = {"background" : None, "composite" : None, "layers": []} + + image_mask_guide = gr.update(visible= visible and mask_in_new, **image_mask_guide_value) + image_guide = gr.update(visible = visible and not mask_in_new, **image_guide_value) + image_mask = gr.update(visible = False, **image_mask_value) + return image_mask_guide, image_guide, image_mask + +def refresh_video_prompt_type_video_mask(state, video_prompt_type, video_prompt_type_video_mask, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): + old_video_prompt_type = video_prompt_type video_prompt_type = del_in_sequence(video_prompt_type, "XYZWNA") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_mask) visible= "A" in video_prompt_type model_type = state["model_type"] model_def = get_model_def(model_type) - image_outputs = image_mode == 1 - return video_prompt_type, gr.update(visible= visible and not image_outputs), gr.update(visible= visible and image_outputs), gr.update(visible= visible ) + image_outputs = image_mode > 0 + image_mask_guide, image_guide, image_mask = switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ) + return video_prompt_type, gr.update(visible= visible and not image_outputs), image_mask_guide, image_guide, image_mask, gr.update(visible= visible ) def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_type_video_guide): video_prompt_type = del_in_sequence(video_prompt_type, "T") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) return video_prompt_type -def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode): - video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMGUV") +all_guide_processes ="PDESLCMUVB" + +def refresh_video_prompt_type_video_guide(state, filter_type, 
video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): + model_type = state["model_type"] + model_def = get_model_def(model_type) + old_video_prompt_type = video_prompt_type + if filter_type == "alt": + guide_custom_choices = model_def.get("guide_custom_choices",{}) + letter_filter = guide_custom_choices.get("letters_filter","") + else: + letter_filter = all_guide_processes + video_prompt_type = del_in_sequence(video_prompt_type, letter_filter) video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) visible = "V" in video_prompt_type - model_type = state["model_type"] - base_model_type = get_base_model_type(model_type) + any_outpainting= image_mode in model_def.get("video_guide_outpainting", []) mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type - model_def = get_model_def(model_type) - image_outputs = image_mode == 1 - vace= test_vace_module(model_type) + image_outputs = image_mode > 0 keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False) - return video_prompt_type, gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible) - -def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt): - video_prompt_type = del_in_sequence(video_prompt_type, "UVQKI") - video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide_alt) - control_video_visible = "V" in video_prompt_type + image_mask_guide, image_guide, image_mask = switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ) + # mask_video_input_visible = image_mode == 0 and mask_visible + mask_preprocessing = model_def.get("mask_preprocessing", None) + if mask_preprocessing is not None: + mask_selector_visible = mask_preprocessing.get("visible", True) + else: + mask_selector_visible = True ref_images_visible = "I" in video_prompt_type - return video_prompt_type, gr.update(visible = control_video_visible), gr.update(visible = ref_images_visible ) - -# def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide): -# video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0] -# return refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide) + return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and mask_selector_visible and not "U" in video_prompt_type ) , gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible), gr.update(visible = ref_images_visible ) + + +# def 
refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): +# old_video_prompt_type = video_prompt_type +# model_def = get_model_def(state["model_type"]) +# guide_custom_choices = model_def.get("guide_custom_choices",{}) +# video_prompt_type = del_in_sequence(video_prompt_type, guide_custom_choices.get("letters_filter","")) +# video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide_alt) +# image_outputs = image_mode > 0 +# control_video_visible = "V" in video_prompt_type +# ref_images_visible = "I" in video_prompt_type +# denoising_strength_visible = "G" in video_prompt_type +# mask_expand_visible = control_video_visible and "A" in video_prompt_type and not "U" in video_prompt_type +# mask_video_input_visible = image_mode == 0 and mask_expand_visible +# image_mask_guide, image_guide, image_mask = switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ) +# keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False) + +# return video_prompt_type, gr.update(visible = control_video_visible and image_mode ==0), gr.update(visible = control_video_visible and image_mode >=1), gr.update(visible = ref_images_visible ), gr.update(visible = denoising_strength_visible ), gr.update(visible = mask_video_input_visible ), gr.update(visible = mask_expand_visible), image_mask_guide, image_guide, image_mask, gr.update(visible = keep_frames_video_guide_visible) def refresh_preview(state): gen = get_gen_info(state) @@ -6644,9 +7040,12 @@ def get_prompt_labels(multi_prompts_gen_type, image_outputs = False): new_line_text = "each new line of prompt will be used for a window" if multi_prompts_gen_type != 0 else "each new line of prompt will generate " + ("a new image" if image_outputs else "a new video") return "Prompts (" + new_line_text + ", # lines = comments, ! 
lines = macros)", "Prompts (" + new_line_text + ", # lines = comments)" +def get_image_end_label(multi_prompts_gen_type): + return "Images as ending points for new Videos in the Generation Queue" if multi_prompts_gen_type == 0 else "Images as ending points for each new Window of the same Video Generation" + def refresh_prompt_labels(multi_prompts_gen_type, image_mode): - prompt_label, wizard_prompt_label = get_prompt_labels(multi_prompts_gen_type, image_mode == 1) - return gr.update(label=prompt_label), gr.update(label = wizard_prompt_label) + prompt_label, wizard_prompt_label = get_prompt_labels(multi_prompts_gen_type, image_mode > 0) + return gr.update(label=prompt_label), gr.update(label = wizard_prompt_label), gr.update(label=get_image_end_label(multi_prompts_gen_type)) def show_preview_column_modal(state, column_no): column_no = int(column_no) @@ -6778,28 +7177,38 @@ def categorize_resolution(resolution_str): return group return "1440p" -def group_resolutions(resolutions, selected_resolution): +def group_resolutions(model_def, resolutions, selected_resolution): + + model_resolutions = model_def.get("resolutions", None) + if model_resolutions is not None: + selected_group ="Locked" + available_groups = [selected_group ] + selected_group_resolutions = model_resolutions + else: + grouped_resolutions = {} + for resolution in resolutions: + group = categorize_resolution(resolution[1]) + if group not in grouped_resolutions: + grouped_resolutions[group] = [] + grouped_resolutions[group].append(resolution) + + available_groups = [group for group in group_thresholds if group in grouped_resolutions] - grouped_resolutions = {} - for resolution in resolutions: - group = categorize_resolution(resolution[1]) - if group not in grouped_resolutions: - grouped_resolutions[group] = [] - grouped_resolutions[group].append(resolution) - - available_groups = [group for group in group_thresholds if group in grouped_resolutions] - - selected_group = categorize_resolution(selected_resolution) - selected_group_resolutions = grouped_resolutions.get(selected_group, []) - available_groups.reverse() + selected_group = categorize_resolution(selected_resolution) + selected_group_resolutions = grouped_resolutions.get(selected_group, []) + available_groups.reverse() return available_groups, selected_group_resolutions, selected_group def change_resolution_group(state, selected_group): model_type = state["model_type"] model_def = get_model_def(model_type) model_resolutions = model_def.get("resolutions", None) - resolution_choices, _ = get_resolution_choices(None, model_resolutions) - group_resolution_choices = [ resolution for resolution in resolution_choices if categorize_resolution(resolution[1]) == selected_group ] + resolution_choices, _ = get_resolution_choices(None, model_resolutions) + if model_resolutions is None: + group_resolution_choices = [ resolution for resolution in resolution_choices if categorize_resolution(resolution[1]) == selected_group ] + else: + last_resolution = group_resolution_choices[0][1] + return gr.update(choices= group_resolution_choices, value= last_resolution) last_resolution_per_group = state["last_resolution_per_group"] last_resolution = last_resolution_per_group.get(selected_group, "") @@ -6810,6 +7219,11 @@ def change_resolution_group(state, selected_group): def record_last_resolution(state, resolution): + + model_type = state["model_type"] + model_def = get_model_def(model_type) + model_resolutions = model_def.get("resolutions", None) + if model_resolutions is not None: return 
server_config["last_resolution_choice"] = resolution selected_group = categorize_resolution(resolution) last_resolution_per_group = state["last_resolution_per_group"] @@ -6845,11 +7259,21 @@ def detect_auto_save_form(state, evt:gr.SelectData): return gr.update() def compute_video_length_label(fps, current_video_length): - return f"Number of frames ({fps} frames = 1s), current duration: {(current_video_length / fps):.1f}s", + if fps is None: + return f"Number of frames" + else: + return f"Number of frames ({fps} frames = 1s), current duration: {(current_video_length / fps):.1f}s", -def refresh_video_length_label(state, current_video_length): - fps = get_model_fps(get_base_model_type(state["model_type"])) - return gr.update(label= compute_video_length_label(fps, current_video_length)) +def refresh_video_length_label(state, current_video_length, force_fps, video_guide, video_source): + base_model_type = get_base_model_type(state["model_type"]) + computed_fps = get_computed_fps(force_fps, base_model_type , video_guide, video_source ) + return gr.update(label= compute_video_length_label(computed_fps, current_video_length)) + +def get_default_value(choices, current_value, default_value = None): + for label, value in choices: + if value == current_value: + return current_value + return default_value def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_family = None, model_choice = None, header = None, main = None, main_tabs= None): global inputs_names #, advanced @@ -6962,7 +7386,6 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non diffusion_forcing = "diffusion_forcing" in model_filename ltxv = "ltxv" in model_filename lock_inference_steps = model_def.get("lock_inference_steps", False) - model_reference_image = model_def.get("reference_image", False) any_tea_cache = model_def.get("tea_cache", False) any_mag_cache = model_def.get("mag_cache", False) recammaster = base_model_type in ["recam_1.3B"] @@ -6989,321 +7412,245 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non fps = get_model_fps(base_model_type) image_prompt_type_value = "" video_prompt_type_value = "" - any_start_image = False - any_end_image = False - any_reference_image = False + any_start_image = any_end_image = any_reference_image = any_image_mask = False v2i_switch_supported = (vace or t2v or standin) and not image_outputs ti2v_2_2 = base_model_type in ["ti2v_2_2"] + gallery_height = 350 + def get_image_gallery(label ="", value = None, single_image_mode = False, visible = False ): + with gr.Row(visible = visible) as gallery_row: + gallery_amg = AdvancedMediaGallery(media_mode="image", height=gallery_height, columns=4, label=label, initial = value , single_image_mode = single_image_mode ) + gallery_amg.mount(update_form=update_form) + return gallery_row, gallery_amg.gallery, [gallery_row] + gallery_amg.get_toggable_elements() image_mode_value = ui_defaults.get("image_mode", 1 if image_outputs else 0 ) if not v2i_switch_supported and not image_outputs: image_mode_value = 0 else: - image_outputs = image_mode_value == 1 + image_outputs = image_mode_value > 0 + inpaint_support = model_def.get("inpaint_support", False) image_mode = gr.Number(value =image_mode_value, visible = False) - - with gr.Tabs(visible = v2i_switch_supported, selected= "t2i" if image_mode_value == 1 else "t2v" ) as image_mode_tabs: - with gr.Tab("Text to Video", id = "t2v", elem_classes="compact_tab"): + image_mode_tab_selected= "t2i" if image_mode_value == 1 else 
("inpaint" if image_mode_value == 2 else "t2v") + with gr.Tabs(visible = v2i_switch_supported or inpaint_support, selected= image_mode_tab_selected ) as image_mode_tabs: + with gr.Tab("Text to Video", id = "t2v", elem_classes="compact_tab", visible = v2i_switch_supported) as tab_t2v: pass with gr.Tab("Text to Image", id = "t2i", elem_classes="compact_tab"): pass + with gr.Tab("Image Inpainting", id = "inpaint", elem_classes="compact_tab", visible=inpaint_support) as tab_inpaint: + pass - - with gr.Column(visible= test_class_i2v(model_type) or hunyuan_i2v or diffusion_forcing or ltxv or recammaster or vace or ti2v_2_2) as image_prompt_column: - if vace or infinitetalk: - image_prompt_type_value= ui_defaults.get("image_prompt_type","") - image_prompt_type_value = "" if image_prompt_type_value == "S" else image_prompt_type_value - image_prompt_type = gr.Radio( [("New Video", ""),("Continue Video File", "V"),("Continue Last Video", "L")], value =image_prompt_type_value, label="Source Video", show_label= False, visible= not image_outputs , scale= 3) - - image_start = gr.Gallery(visible = False) - image_end = gr.Gallery(visible = False) - video_source = gr.Video(label= "Video Source", visible = "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None)) - model_mode = gr.Dropdown(visible = False) - keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VLG"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" ) + image_prompt_types_allowed = model_def.get("image_prompt_types_allowed", "") + model_mode_choices = model_def.get("model_modes", None) + model_modes_visibility = [0,1,2] + if model_mode_choices is not None: model_modes_visibility= model_mode_choices.get("image_modes", model_modes_visibility) + + with gr.Column(visible= image_mode_value == 0 and len(image_prompt_types_allowed)> 0 or model_mode_choices is not None and image_mode_value in model_modes_visibility ) as image_prompt_column: + # Video Continue / Start Frame / End Frame + image_prompt_type_value= ui_defaults.get("image_prompt_type","") + image_prompt_type = gr.Text(value= image_prompt_type_value, visible= False) + image_prompt_type_choices = [] + if "T" in image_prompt_types_allowed: + image_prompt_type_choices += [("Text Prompt Only" if "S" in image_prompt_types_allowed else "New Video", "")] + if "S" in image_prompt_types_allowed: + image_prompt_type_choices += [("Start Video with Image", "S")] + any_start_image = True + if "V" in image_prompt_types_allowed: any_video_source = True - - elif diffusion_forcing or ltxv or ti2v_2_2: - image_prompt_type_value= ui_defaults.get("image_prompt_type","T") - # image_prompt_type = gr.Radio( [("Start Video with Image", "S"),("Start and End Video with Images", "SE"), ("Continue Video", "V"),("Text Prompt Only", "T")], value =image_prompt_type_value, label="Location", show_label= False, visible= True, scale= 3) - image_prompt_type_choices = [("Text Prompt Only", "T"),("Start Video with Image", "S")] - if ltxv: - image_prompt_type_choices += [("Use both a Start and an End Image", "SE")] - if sliding_window_enabled: - any_video_source = True - image_prompt_type_choices += [("Continue Video", "V")] - image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= True , scale= 3) - - # image_start = gr.Image(label= "Image as a starting point for a 
new video", type ="pil",value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) - image_start = gr.Gallery(preview= True, - label="Images as starting points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value) - image_end = gr.Gallery(preview= True, - label="Images as ending points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) - video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) - if not diffusion_forcing: - model_mode = gr.Dropdown( - choices=[ - ], value=None, - visible= False - ) - else: - model_mode = gr.Dropdown( - choices=[ - ("Synchronous", 0), - ("Asynchronous (better quality but around 50% extra steps added)", 5), - ], - value=ui_defaults.get("model_mode", 0), - label="Generation Type", scale = 3, - visible= True - ) - keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= "V" in image_prompt_type_value, scale = 2, label= "Truncate Video beyond this number of Frames of Video (empty=Keep All)" ) - elif recammaster: - image_prompt_type = gr.Radio(choices=[("Source Video", "V")], value="V") - image_start = gr.Gallery(value = None, visible = False) - image_end = gr.Gallery(value = None, visible= False) - video_source = gr.Video(label= "Video Source", visible = True, value= ui_defaults.get("video_source", None),) - model_mode = gr.Dropdown( - choices=[ - ("Pan Right", 1), - ("Pan Left", 2), - ("Tilt Up", 3), - ("Tilt Down", 4), - ("Zoom In", 5), - ("Zoom Out", 6), - ("Translate Up (with rotation)", 7), - ("Translate Down (with rotation)", 8), - ("Arc Left (with rotation)", 9), - ("Arc Right (with rotation)", 10), - ], - value=ui_defaults.get("model_mode", 1), - label="Camera Movement Type", scale = 3, - visible= True - ) - keep_frames_video_source = gr.Text(visible=False) - else: - if test_class_i2v(model_type) or hunyuan_i2v: - # image_prompt_type_value= ui_defaults.get("image_prompt_type","SE" if flf2v else "S" ) - image_prompt_type_value= ui_defaults.get("image_prompt_type","S" ) - image_prompt_type_choices = [("Start Video with Image", "S")] - image_prompt_type_choices += [("Use both a Start and an End Image", "SE")] - if not hunyuan_i2v: - any_video_source = True - image_prompt_type_choices += [("Continue Video", "V")] - - image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= not hunyuan_i2v, scale= 3) - any_start_image = True - any_end_image = True - image_start = gr.Gallery(preview= True, - label="Images as starting points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value) - - image_end = gr.Gallery(preview= True, - label="Images as ending points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) - if 
hunyuan_i2v: - video_source = gr.Video(value=None, visible=False) + image_prompt_type_choices += [("Continue Video", "V")] + if "L" in image_prompt_types_allowed: + any_video_source = True + image_prompt_type_choices += [("Continue Last Video", "L")] + with gr.Group(visible= len(image_prompt_types_allowed)>1 and image_mode_value == 0) as image_prompt_type_group: + with gr.Row(): + image_prompt_type_radio_allowed_values= filter_letters(image_prompt_types_allowed, "SVL") + image_prompt_type_radio_value = filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1] if len(image_prompt_type_choices) > 0 else "") + if len(image_prompt_type_choices) > 0: + image_prompt_type_radio = gr.Radio( image_prompt_type_choices, value = image_prompt_type_radio_value, label="Location", show_label= False, visible= len(image_prompt_types_allowed)>1, scale= 3) else: - video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) - else: - image_prompt_type = gr.Radio(choices=[("", "")], value="") - image_start = gr.Gallery(value=None) - image_end = gr.Gallery(value=None) - video_source = gr.Video(value=None, visible=False) + image_prompt_type_radio = gr.Radio(choices=[("", "")], value="", visible= False) + if "E" in image_prompt_types_allowed: + image_prompt_type_endcheckbox = gr.Checkbox( value ="E" in image_prompt_type_value, label="End Image(s)", show_label= False, visible= any_letters(image_prompt_type_radio_value, "SVL") and not image_outputs , scale= 1) + any_end_image = True + else: + image_prompt_type_endcheckbox = gr.Checkbox( value =False, show_label= False, visible= False , scale= 1) + image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new Videos in the Generation Queue", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) + video_source = gr.Video(label= "Video to Continue", height = gallery_height, visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) + image_end_row, image_end, image_end_extra = get_image_gallery(label= get_image_end_label(ui_defaults.get("multi_prompts_gen_type", 0)), value = ui_defaults.get("image_end", None), visible= any_letters(image_prompt_type_value, "SVL") and ("E" in image_prompt_type_value) ) + if model_mode_choices is None or image_mode_value not in model_modes_visibility: model_mode = gr.Dropdown(value=None, visible=False) - keep_frames_video_source = gr.Text(visible=False) + else: + model_mode_value = get_default_value(model_mode_choices["choices"], ui_defaults.get("model_mode", None), model_mode_choices["default"] ) + model_mode = gr.Dropdown(choices=model_mode_choices["choices"], value=model_mode_value, label=model_mode_choices["label"], visible=True) + keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VL"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" ) - with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or standin or ltxv or infinitetalk or flux and model_reference_image or qwen and model_reference_image) as video_prompt_column: + any_control_video = any_control_image = False + if image_mode_value ==2: + guide_preprocessing = { "selection": ["V", "VG"]} + 
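# Illustrative sketch: video_prompt_type is a short string of single-letter flags
# ("V" = control video, "A" = mask, "G" = denoising guide, "I" = reference images, ...)
# manipulated through filter_letters / add_to_sequence / del_in_sequence.
# filter_letters_sketch follows the patched helper (including its new default argument);
# the other two are reconstructed from their usage and may differ from the real
# implementations. The _sketch names are the editor's assumptions.
def filter_letters_sketch(source_str, letters, default=""):
    ret = "".join(letter for letter in letters if letter in source_str)
    return ret if ret else default

def add_to_sequence_sketch(source_str, letters):
    # append only the letters that are not already present
    return source_str + "".join(letter for letter in letters if letter not in source_str)

def del_in_sequence_sketch(source_str, letters):
    # drop every occurrence of the given letters
    return "".join(letter for letter in source_str if letter not in letters)

# Typical round trip when the user picks another Control Video process:
flags = "PVA"
flags = del_in_sequence_sketch(flags, "PDESLCMUVB")   # all_guide_processes -> "A"
flags = add_to_sequence_sketch(flags, "DV")           # Transfer Depth      -> "ADV"
assert filter_letters_sketch(flags, "XYZWNA") == "A"  # mask selection still intact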
mask_preprocessing = { "selection": ["A"]} + else: + guide_preprocessing = model_def.get("guide_preprocessing", None) + mask_preprocessing = model_def.get("mask_preprocessing", None) + guide_custom_choices = model_def.get("guide_custom_choices", None) + image_ref_choices = model_def.get("image_ref_choices", None) + + # with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or standin or ltxv or infinitetalk or recammaster or (flux or qwen ) and model_reference_image and image_mode_value >=1) as video_prompt_column: + with gr.Column(visible= guide_preprocessing is not None or mask_preprocessing is not None or guide_custom_choices is not None or image_ref_choices is not None) as video_prompt_column: video_prompt_type_value= ui_defaults.get("video_prompt_type","") video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False) - any_control_video = True - any_control_image = image_outputs - with gr.Row(): - if t2v: - video_prompt_type_video_guide = gr.Dropdown( - choices=[ - ("Use Text Prompt Only", ""), - ("Image to Image guided by Text Prompt" if image_outputs else "Video to Video guided by Text Prompt", "GUV"), - ], - value=filter_letters(video_prompt_type_value, "GUV"), - label="Video to Video", scale = 2, show_label= False, visible= True - ) - elif vace : + with gr.Row(visible = image_mode_value!=2) as guide_selection_row: + # Control Video Preprocessing + if guide_preprocessing is None: + video_prompt_type_video_guide = gr.Dropdown(choices=[("","")], value="", label="Control Video", scale = 2, visible= False, show_label= True, ) + else: pose_label = "Pose" if image_outputs else "Motion" - video_prompt_type_video_guide = gr.Dropdown( - choices=[ - ("No Control Image" if image_outputs else "No Control Video", ""), - ("Keep Control Image Unchanged" if image_outputs else "Keep Control Video Unchanged", "UV"), - (f"Transfer Human {pose_label}" , "PV"), - ("Transfer Depth", "DV"), - ("Transfer Shapes", "SV"), - ("Transfer Flow", "LV"), - ("Recolorize", "CV"), - ("Perform Inpainting", "MV"), - ("Use Vace raw format", "V"), - (f"Transfer Human {pose_label} & Depth", "PDV"), - (f"Transfer Human {pose_label} & Shapes", "PSV"), - (f"Transfer Human {pose_label} & Flow", "PLV"), - ("Transfer Depth & Shapes", "DSV"), - ("Transfer Depth & Flow", "DLV"), - ("Transfer Shapes & Flow", "SLV"), - ], - value=filter_letters(video_prompt_type_value, "PDSLCMGUV"), - label="Control Image Process" if image_outputs else "Control Video Process", scale = 2, visible= True, show_label= True, - ) - elif ltxv: - video_prompt_type_video_guide = gr.Dropdown( - choices=[ - ("No Control Video", ""), - ("Transfer Human Motion", "PV"), - ("Transfer Depth", "DV"), - ("Transfer Canny Edges", "EV"), - ("Use LTXV raw format", "V"), - ], - value=filter_letters(video_prompt_type_value, "PDEV"), - label="Control Video Process", scale = 2, visible= True, show_label= True, - ) + guide_preprocessing_labels_all = { + "": "No Control Video", + "UV": "Keep Control Video Unchanged", + "PV": f"Transfer Human {pose_label}", + "DV": "Transfer Depth", + "EV": "Transfer Canny Edges", + "SV": "Transfer Shapes", + "LV": "Transfer Flow", + "CV": "Recolorize", + "MV": "Perform Inpainting", + "V": "Use Vace raw format", + "PDV": f"Transfer Human {pose_label} & Depth", + "PSV": f"Transfer Human {pose_label} & Shapes", + "PLV": f"Transfer Human {pose_label} & Flow" , + "DSV": "Transfer Depth & Shapes", + "DLV": "Transfer Depth & Flow", + "SLV": "Transfer Shapes & Flow", 
+ } + guide_preprocessing_choices = [] + guide_preprocessing_labels = guide_preprocessing.get("labels", {}) + for process_type in guide_preprocessing["selection"]: + process_label = guide_preprocessing_labels.get(process_type, None) + process_label = guide_preprocessing_labels_all.get(process_type,process_type) if process_label is None else process_label + if image_outputs: process_label = process_label.replace("Video", "Image") + guide_preprocessing_choices.append( (process_label, process_type) ) - elif hunyuan_video_custom_edit: + video_prompt_type_video_guide_label = guide_preprocessing.get("label", "Control Video Process") + if image_outputs: video_prompt_type_video_guide_label = video_prompt_type_video_guide_label.replace("Video", "Image") video_prompt_type_video_guide = gr.Dropdown( - choices=[ - ("Inpaint Control Image" if image_outputs else "Inpaint Control Video", "MV"), - ("Transfer Human Motion", "PMV"), - ], - value=filter_letters(video_prompt_type_value, "PDSLCMUV"), - label="Image to Image" if image_outputs else "Video to Video", scale = 3, visible= True, show_label= True, - ) - elif infinitetalk: - video_prompt_type_video_guide = gr.Dropdown(value="", choices = [("","")], visible=False) + guide_preprocessing_choices, + value=filter_letters(video_prompt_type_value, all_guide_processes, guide_preprocessing.get("default", "") ), + label= video_prompt_type_video_guide_label , scale = 2, visible= guide_preprocessing.get("visible", True) , show_label= True, + ) + any_control_video = True + any_control_image = image_outputs + + # Alternate Control Video Preprocessing / Options + if guide_custom_choices is None: + video_prompt_type_video_guide_alt = gr.Dropdown(choices=[("","")], value="", label="Control Video", visible= False, scale = 2 ) else: - any_control_video = False - any_control_image = False - video_prompt_type_video_guide = gr.Dropdown(visible= False) - - if infinitetalk: + video_prompt_type_video_guide_alt_label = guide_custom_choices.get("label", "Control Video Process") + if image_outputs: video_prompt_type_video_guide_alt_label = video_prompt_type_video_guide_alt_label.replace("Video", "Image") + video_prompt_type_video_guide_alt_choices = [(label.replace("Video", "Image") if image_outputs else label, value) for label,value in guide_custom_choices["choices"] ] + guide_custom_choices_value = get_default_value(video_prompt_type_video_guide_alt_choices, filter_letters(video_prompt_type_value, guide_custom_choices["letters_filter"]), guide_custom_choices.get("default", "") ) video_prompt_type_video_guide_alt = gr.Dropdown( - choices=[ - ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "UV"), - ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QUV"), - ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"), - ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"), - ], - value=filter_letters(video_prompt_type_value, "UVQKI"), - label="Video to Video", scale = 3, visible= True, show_label= False, - ) - else: - video_prompt_type_video_guide_alt = gr.Dropdown(value="", choices = [("","")], visible=False) + choices= video_prompt_type_video_guide_alt_choices, + # value=filter_letters(video_prompt_type_value, guide_custom_choices["letters_filter"], guide_custom_choices.get("default", "") ), + value=guide_custom_choices_value, + 
visible = guide_custom_choices.get("visible", True), + label= video_prompt_type_video_guide_alt_label, show_label= guide_custom_choices.get("show_label", True), scale = 2 + ) + any_control_video = True + any_control_image = image_outputs - # video_prompt_video_guide_trigger = gr.Text(visible=False, value="") - if t2v: - video_prompt_type_video_mask = gr.Dropdown(value = "", choices = [""], visible = False) - elif hunyuan_video_custom_edit: - video_prompt_type_video_mask = gr.Dropdown( - choices=[ - ("Masked Area", "A"), - ("Non Masked Area", "NA"), - ], - value= filter_letters(video_prompt_type_value, "NA"), - visible= "V" in video_prompt_type_value, - label="Area Processed", scale = 2, show_label= True, - ) - elif ltxv: - video_prompt_type_video_mask = gr.Dropdown( - choices=[ - ("Whole Frame", ""), - ("Masked Area", "A"), - ("Non Masked Area", "NA"), - ("Masked Area, rest Inpainted", "XA"), - ("Non Masked Area, rest Inpainted", "XNA"), - ], - value= filter_letters(video_prompt_type_value, "XNA"), - visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value, - label="Area Processed", scale = 2, show_label= True, - ) + # Control Mask Preprocessing + if mask_preprocessing is None: + video_prompt_type_video_mask = gr.Dropdown(choices=[("","")], value="", label="Video Mask", scale = 2, visible= False, show_label= True, ) + any_image_mask = image_outputs else: + mask_preprocessing_labels_all = { + "": "Whole Frame", + "A": "Masked Area", + "NA": "Non Masked Area", + "XA": "Masked Area, rest Inpainted", + "XNA": "Non Masked Area, rest Inpainted", + "YA": "Masked Area, rest Depth", + "YNA": "Non Masked Area, rest Depth", + "WA": "Masked Area, rest Shapes", + "WNA": "Non Masked Area, rest Shapes", + "ZA": "Masked Area, rest Flow", + "ZNA": "Non Masked Area, rest Flow" + } + + mask_preprocessing_choices = [] + mask_preprocessing_labels = mask_preprocessing.get("labels", {}) + for process_type in mask_preprocessing["selection"]: + process_label = mask_preprocessing_labels.get(process_type, None) + process_label = mask_preprocessing_labels_all.get(process_type, process_type) if process_label is None else process_label + mask_preprocessing_choices.append( (process_label, process_type) ) + + video_prompt_type_video_mask_label = mask_preprocessing.get("label", "Area Processed") video_prompt_type_video_mask = gr.Dropdown( - choices=[ - ("Whole Frame", ""), - ("Masked Area", "A"), - ("Non Masked Area", "NA"), - ("Masked Area, rest Inpainted", "XA"), - ("Non Masked Area, rest Inpainted", "XNA"), - ("Masked Area, rest Depth", "YA"), - ("Non Masked Area, rest Depth", "YNA"), - ("Masked Area, rest Shapes", "WA"), - ("Non Masked Area, rest Shapes", "WNA"), - ("Masked Area, rest Flow", "ZA"), - ("Non Masked Area, rest Flow", "ZNA"), - ], - value= filter_letters(video_prompt_type_value, "XYZWNA"), - visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and not hunyuan_video_custom and not ltxv, - label="Area Processed", scale = 2, show_label= True, - ) - image_ref_choices = model_def.get("image_ref_choices", None) - if image_ref_choices is not None: - video_prompt_type_image_refs = gr.Dropdown( - choices= image_ref_choices["choices"], - value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]), - visible = True, - label=image_ref_choices["label"], show_label= True, scale = 2 - ) - elif t2v: - video_prompt_type_image_refs = gr.Dropdown(value="", label="Ref Image", choices=[""], visible =False) - elif vace: - video_prompt_type_image_refs = 
gr.Dropdown( - choices=[ - ("None", ""), - ("Inject only People / Objects", "I"), - ("Inject Landscape and then People / Objects", "KI"), - ("Inject Frames and then People / Objects", "FI"), - ], - value=filter_letters(video_prompt_type_value, "KFI"), - visible = True, - label="Reference Images", show_label= True, scale = 2 - ) - elif standin: # and not vace - video_prompt_type_image_refs = gr.Dropdown( - choices=[ - ("No Reference Image", ""), - ("Reference Image is a Person Face", "I"), - ], - value=filter_letters(video_prompt_type_value, "I"), - visible = True, - show_label=False, - label="Reference Image", scale = 2 + mask_preprocessing_choices, + value=filter_letters(video_prompt_type_value, "XYZWNA", mask_preprocessing.get("default", "")), + label= video_prompt_type_video_mask_label , scale = 2, visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and mask_preprocessing.get("visible", True), + show_label= True, ) - elif (flux or qwen) and model_reference_image: + + # Image Refs Selection + if image_ref_choices is None: video_prompt_type_image_refs = gr.Dropdown( - choices=[ - ("None", ""), - ("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"), - ("Conditional Images are People / Objects", "I"), - ], - value=filter_letters(video_prompt_type_value, "KI"), - visible = True, - show_label=False, - label="Reference Images Combination Method", scale = 2 - ) - else: - video_prompt_type_image_refs = gr.Dropdown( - choices=[ ("None", ""),("Start", "KI"),("Ref Image", "I")], - value=filter_letters(video_prompt_type_value, "KI"), + # choices=[ ("None", ""),("Start", "KI"),("Ref Image", "I")], + choices=[ ("None", ""),], + value=filter_letters(video_prompt_type_value, ""), visible = False, label="Start / Reference Images", scale = 2 ) - image_guide = gr.Image(label= "Control Image", type ="pil", visible= image_outputs and "V" in video_prompt_type_value, value= ui_defaults.get("image_guide", None)) - video_guide = gr.Video(label= "Control Video", visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None)) + any_reference_image = False + else: + any_reference_image = True + video_prompt_type_image_refs = gr.Dropdown( + choices= image_ref_choices["choices"], + value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]), + visible = image_ref_choices.get("visible", True), + label=image_ref_choices.get("label", "Inject Reference Images"), show_label= True, scale = 2 + ) - denoising_strength = gr.Slider(0, 1, value= ui_defaults.get("denoising_strength" ,0.5), step=0.01, label="Denoising Strength (the Lower the Closer to the Control Video)", visible = "G" in video_prompt_type_value, show_reset_button= False) + image_guide = gr.Image(label= "Control Image", height = 800, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and ("U" in video_prompt_type_value or not "A" in video_prompt_type_value ) , value= ui_defaults.get("image_guide", None)) + video_guide = gr.Video(label= "Control Video", height = gallery_height, visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None)) + if image_mode_value >= 1: + image_guide_value = ui_defaults.get("image_guide", None) + image_mask_value = ui_defaults.get("image_mask", None) + if image_guide_value is None: + image_mask_guide_value = None + else: + image_mask_guide_value = { "background" : image_guide_value, "composite" : None} + 
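# Illustrative sketch: rgb_bw_to_rgba_mask() (defined elsewhere in WanGP) is used on the
# next line to turn a black & white mask image into an editor layer. The stand-in below
# is only a plausible reading of what it must do for gr.ImageEditor: white (control)
# pixels become opaque white, black (unchanged) pixels become transparent. The _sketch
# name and the threshold parameter are assumptions.
from PIL import Image

def rgb_bw_to_rgba_mask_sketch(mask_img, threshold=127):
    gray = mask_img.convert("L")                               # collapse to grayscale
    alpha = gray.point(lambda v: 255 if v > threshold else 0)  # binary alpha channel
    rgba = Image.new("RGBA", mask_img.size, (255, 255, 255, 0))
    rgba.putalpha(alpha)                                       # opaque only where the mask is white
    return rgba

# Usage mirroring the layer construction that follows:
# layers = [] if image_mask_value is None else [rgb_bw_to_rgba_mask_sketch(image_mask_value)]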
image_mask_guide_value["layers"] = [] if image_mask_value is None else [rgb_bw_to_rgba_mask(image_mask_value)] + + image_mask_guide = gr.ImageEditor( + label="Control Image to be Inpainted" if image_mode_value == 2 else "Control Image and Mask", + value = image_mask_guide_value, + type='pil', + sources=["upload", "webcam"], + image_mode='RGB', + layers=False, + brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"), + # fixed_canvas= True, + # width=800, + height=800, + # transforms=None, + # interactive=True, + elem_id="img_editor", + visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value + ) + any_control_image = True + else: + image_mask_guide = gr.ImageEditor(value = None, visible = False, elem_id="img_editor") + + + denoising_strength = gr.Slider(0, 1, value= ui_defaults.get("denoising_strength" ,0.5), step=0.01, label=f"Denoising Strength (the Lower the Closer to the Control {'Image' if image_outputs else 'Video'})", visible = "G" in video_prompt_type_value, show_reset_button= False) keep_frames_video_guide_visible = not image_outputs and "V" in video_prompt_type_value and not model_def.get("keep_frames_video_guide_not_supported", False) keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= keep_frames_video_guide_visible , scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last - - with gr.Column(visible= ("V" in video_prompt_type_value or "K" in video_prompt_type_value or "F" in video_prompt_type_value) and vace) as video_guide_outpainting_col: + video_guide_outpainting_modes = model_def.get("video_guide_outpainting", []) + with gr.Column(visible= ("V" in video_prompt_type_value or "K" in video_prompt_type_value or "F" in video_prompt_type_value) and image_mode_value in video_guide_outpainting_modes) as video_guide_outpainting_col: video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#") video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False) with gr.Group(): - video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") ) + video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Positioned Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") ) with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row: video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")] @@ -7311,29 +7658,29 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_guide_outpainting_bottom = gr.Slider(0, 100, value= video_guide_outpainting_list[1], step=5, label="Bottom %", show_reset_button= False) video_guide_outpainting_left = gr.Slider(0, 100, value= video_guide_outpainting_list[2], step=5, label="Left %", show_reset_button= False) video_guide_outpainting_right = gr.Slider(0, 100, value= 
video_guide_outpainting_list[3], step=5, label="Right %", show_reset_button= False) - any_image_mask = image_outputs and vace - image_mask = gr.Image(label= "Image Mask Area (for Inpainting, white = Control Area, black = Unchanged)", type ="pil", visible= image_outputs and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("image_mask", None)) - video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= (not image_outputs) and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("video_mask", None)) + # image_mask = gr.Image(label= "Image Mask Area (for Inpainting, white = Control Area, black = Unchanged)", type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , height = gallery_height, value= ui_defaults.get("image_mask", None)) + image_mask = gr.Image(label= "Image Mask Area (for Inpainting, white = Control Area, black = Unchanged)", type ="pil", visible= False, height = gallery_height, value= ui_defaults.get("image_mask", None)) + video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= (not image_outputs) and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , height = gallery_height, value= ui_defaults.get("video_mask", None)) mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value ) - any_reference_image = vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or infinitetalk or (flux or qwen) and model_reference_image - image_refs = gr.Gallery(preview= True, label ="Start Image" if hunyuan_video_avatar else "Reference Images" + (" (each Image will start a new Clip)" if infinitetalk else ""), - type ="pil", show_label= True, - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value, - value= ui_defaults.get("image_refs", None), - ) - frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames separated by Spaces (1=first, no position for Objects / People)" ) + image_refs_single_image_mode = model_def.get("one_image_ref_needed", False) + image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will be associated to a Sliding Window)" if infinitetalk else "") + image_refs_row, image_refs, image_refs_extra = get_image_gallery(label= image_refs_label, value = ui_defaults.get("image_refs", None), visible= "I" in video_prompt_type_value, single_image_mode=image_refs_single_image_mode) + + frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames (1=first, L=last of a window) no position for other Image Refs)" ) image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Rescale Internaly Image Ref (% in relation to Output Video) to change Output Composition", 
visible = model_def.get("any_image_refs_relative_size", False) and image_outputs) - no_background_removal = model_def.get("no_background_removal", False) + no_background_removal = model_def.get("no_background_removal", False) or image_ref_choices is None + background_removal_label = model_def.get("background_removal_label", "Remove Background behind People / Objects") + remove_background_images_ref = gr.Dropdown( choices=[ ("Keep Backgrounds behind all Reference Images", 0), - ("Remove Backgrounds only behind People / Objects except main Subject / Landscape" if (flux or qwen) else ("Remove Backgrounds behind People / Objects, keep it for Landscape or positioned Frames" if vace else "Remove Backgrounds behind People / Objects") , 1), + (background_removal_label, 1), ], value=0 if no_background_removal else ui_defaults.get("remove_background_images_ref",1), - label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal + label="Automatic Removal of Background behind People or Objects in Reference Images", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal ) any_audio_voices_support = any_audio_track(base_model_type) @@ -7348,7 +7695,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non speaker_choices=[("None", "")] if any_single_speaker: speaker_choices += [("One Person Speaking Only", "A")] if any_multi_speakers:speaker_choices += [ - ("Two speakers, Auto Separation of Speakers (will work only if there is little background noise)", "XA"), + ("Two speakers, Auto Separation of Speakers (will work only if Voices are distinct)", "XA"), ("Two speakers, Speakers Audio sources are assumed to be played in a Row", "CAB"), ("Two speakers, Speakers Audio sources are assumed to be played in Parallel", "PAB") ] @@ -7363,6 +7710,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non with gr.Row(visible = any_audio_voices_support and not image_outputs) as audio_guide_row: audio_guide = gr.Audio(value= ui_defaults.get("audio_guide", None), type="filepath", label="Voice to follow", show_download_button= True, visible= any_audio_voices_support and "A" in audio_prompt_type_value ) audio_guide2 = gr.Audio(value= ui_defaults.get("audio_guide2", None), type="filepath", label="Voice to follow #2", show_download_button= True, visible= any_audio_voices_support and "B" in audio_prompt_type_value ) + remove_background_sound = gr.Checkbox(label="Video Motion ignores Background Music (to get a better LipSync)", value="V" in audio_prompt_type_value, visible = any_audio_voices_support and any_letters(audio_prompt_type_value, "ABX") and not image_outputs) with gr.Row(visible = any_audio_voices_support and ("B" in audio_prompt_type_value or "X" in audio_prompt_type_value) and not image_outputs ) as speakers_locations_row: speakers_locations = gr.Text( ui_defaults.get("speakers_locations", "0:45 55:100"), label="Speakers Locations separated by a Space. 
Each Location = Left:Right or a BBox Left:Top:Right:Bottom", visible= True) @@ -7412,14 +7760,17 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non visible= True, show_label= not on_demand_prompt_enhancer, ) with gr.Row(): - if server_config.get("fit_canvas", 0) == 1: - label = "Max Resolution (As it maybe less depending on video width / height ratio)" + fit_canvas = server_config.get("fit_canvas", 0) + if fit_canvas == 1: + label = "Outer Box Resolution (one dimension may be less to preserve video W/H ratio)" + elif fit_canvas == 2: + label = "Output Resolution (Input Images will be Cropped if the W/H ratio is different)" else: - label = "Max Resolution (Pixels will be reallocated depending on the output width / height ratio)" + label = "Resolution Budget (Pixels will be reallocated to preserve Input W/H ratio)" current_resolution_choice = ui_defaults.get("resolution","832x480") if update_form or last_resolution is None else last_resolution model_resolutions = model_def.get("resolutions", None) resolution_choices, current_resolution_choice = get_resolution_choices(current_resolution_choice, model_resolutions) - available_groups, selected_group_resolutions, selected_group = group_resolutions(resolution_choices, current_resolution_choice) + available_groups, selected_group_resolutions, selected_group = group_resolutions(model_def,resolution_choices, current_resolution_choice) resolution_group = gr.Dropdown( choices = available_groups, value= selected_group, @@ -7438,12 +7789,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non elif recammaster: video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", get_max_frames(81)), step=4, label="Number of frames (16 = 1s), locked", interactive= False, visible = True) else: - min_frames, frames_step = get_model_min_frames_and_step(base_model_type) + min_frames, frames_step, _ = get_model_min_frames_and_step(base_model_type) current_video_length = ui_defaults.get("video_length", 81 if get_model_family(base_model_type)=="wan" else 97) + computed_fps = get_computed_fps(ui_defaults.get("force_fps",""), base_model_type , ui_defaults.get("video_guide", None), ui_defaults.get("video_source", None)) video_length = gr.Slider(min_frames, get_max_frames(737 if test_any_sliding_window(base_model_type) else 337), value=current_video_length, - step=frames_step, label=compute_video_length_label(fps, current_video_length) , visible = True, interactive= True) + step=frames_step, label=compute_video_length_label(computed_fps, current_video_length) , visible = True, interactive= True) with gr.Row(visible = not lock_inference_steps) as inference_steps_row: num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps", visible = True) @@ -7615,8 +7967,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non MMAudio_neg_prompt = gr.Text(ui_defaults.get("MMAudio_neg_prompt", ""), label="Negative Prompt (1 or 2 keywords)") - with gr.Column(visible = (t2v or vace) and not fantasy) as audio_prompt_type_remux_row: - gr.Markdown("You may transfer the exising audio tracks of a Control Video") + with gr.Column(visible = any_control_video) as audio_prompt_type_remux_row: + gr.Markdown("You may transfer the existing audio tracks of a Control Video") audio_prompt_type_remux = gr.Dropdown( choices=[ ("No Remux", ""), @@ -7705,7 +8057,7 @@ def generate_video_tab(update_form = False, state_dict = None,
ui_defaults = Non ("Generate always a 13 Frames long Video (x2.5 slower)",1013), ("Generate always a 17 Frames long Video (x3.0 slower)",1017), ], - value=ui_defaults.get("min_frames_if_references",5 if vace else 1), + value=ui_defaults.get("min_frames_if_references",9 if vace else 1), visible=True, scale = 1, label="Generate more frames to preserve Reference Image Identity / Control Image Information or improve" @@ -7724,7 +8076,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False) elif ltxv: sliding_window_size = gr.Slider(41, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=8, label="Sliding Window Size") - sliding_window_overlap = gr.Slider(9, 97, value=ui_defaults.get("sliding_window_overlap",9), step=8, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)") + sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",9), step=8, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)") sliding_window_color_correction_strength = gr.Slider(0, 1, visible=False, value =0) sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = False) sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=8, label="Discard Last Frames of a Window (that may have bad quality)", visible = True) @@ -7735,9 +8087,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = False) sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True) else: # Vace, Multitalk + sliding_window_defaults = model_def.get("sliding_window_defaults", {}) sliding_window_size = gr.Slider(5, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=4, label="Sliding Window Size") - sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)") - sliding_window_color_correction_strength = gr.Slider(0, 1, value=ui_defaults.get("sliding_window_color_correction_strength",1), step=0.01, label="Color Correction Strength (match colors of new window with previous one, 0 = disabled)") + sliding_window_overlap = gr.Slider(sliding_window_defaults.get("overlap_min", 1), sliding_window_defaults.get("overlap_max", 97), value=ui_defaults.get("sliding_window_overlap",sliding_window_defaults.get("overlap_default", 5)), step=sliding_window_defaults.get("overlap_step", 4), label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)") + sliding_window_color_correction_strength = gr.Slider(0, 1, value=ui_defaults.get("sliding_window_color_correction_strength",0), step=0.01, label="Color 
Correction Strength (match colors of new window with previous one, 0 = disabled)", visible = True) sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20 if vace else 0), step=1, label="Noise to be added to overlapped frames to reduce blur effect" , visible = vace) sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True) @@ -7747,19 +8100,19 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ("Aligned to the beginning of the First Window of the new Video Sample", "T"), ], value=filter_letters(video_prompt_type_value, "T"), - label="Control Video / Control Audio temporal alignment when any Source Video", + label="Control Video / Control Audio / Positioned Frames Temporal Alignment when any Video to continue", visible = vace or ltxv or t2v or infinitetalk ) multi_prompts_gen_type = gr.Dropdown( choices=[ - ("Will create new generated Video", 0), + ("Will create a new generated Video added to the Generation Queue", 0), ("Will be used for a new Sliding Window of the same Video Generation", 1), ], value=ui_defaults.get("multi_prompts_gen_type",0), visible=True, scale = 1, - label="Text Prompts separated by a Carriage Return" + label="Images & Text Prompts separated by a Carriage Return" if (any_start_image or any_end_image) else "Text Prompts separated by a Carriage Return" ) with gr.Tab("Misc.", visible = True) as misc_tab: @@ -7830,7 +8183,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non output_trigger = gr.Text(interactive= False, visible=False) refresh_form_trigger = gr.Text(interactive= False, visible=False) fill_wizard_prompt_trigger = gr.Text(interactive= False, visible=False) - saveform_trigger = gr.Text(interactive= False, visible=False) + save_form_trigger = gr.Text(interactive= False, visible=False) with gr.Accordion("Video Info and Late Post Processing & Audio Remuxing", open=False) as video_info_accordion: with gr.Tabs() as video_info_tabs: @@ -7839,16 +8192,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_info = gr.HTML(visible=True, min_height=100, value=get_default_video_info()) with gr.Row(**default_visibility) as video_buttons_row: video_info_extract_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm") - video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video ) video_info_to_video_source_btn = gr.Button("To Video Source", min_width= 1, size ="sm", visible = any_video_source) + video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video ) video_info_eject_video_btn = gr.Button("Eject Video", min_width= 1, size ="sm") with gr.Row(**default_visibility) as image_buttons_row: video_info_extract_image_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm") video_info_to_start_image_btn = gr.Button("To Start Image", size ="sm", min_width= 1, visible = any_start_image ) video_info_to_end_image_btn = gr.Button("To End Image", size ="sm", min_width= 1, visible = any_end_image) - video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image ) - video_info_to_image_mask_btn = gr.Button("To Mask Image", min_width= 1, size ="sm", visible = any_image_mask) + 
video_info_to_image_mask_btn = gr.Button("To Mask Image", min_width= 1, size ="sm", visible = any_image_mask and False) video_info_to_reference_image_btn = gr.Button("To Reference Image", min_width= 1, size ="sm", visible = any_reference_image) + video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image ) video_info_eject_image_btn = gr.Button("Eject Image", min_width= 1, size ="sm") with gr.Tab("Post Processing", id= "post_processing", visible = True) as video_postprocessing_tab: with gr.Group(elem_classes= "postprocess"): @@ -7888,7 +8241,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non add_to_queue_trigger = gr.Text(visible = False) with gr.Column(visible= False) as current_gen_column: - with gr.Accordion("Preview", open=False) as queue_accordion: + with gr.Accordion("Preview", open=False): preview = gr.Image(label="Preview", height=200, show_label= False) preview_trigger = gr.Text(visible= False) gen_info = gr.HTML(visible=False, min_height=1) @@ -7925,16 +8278,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non hidden_countdown_state = gr.Number(value=-1, visible=False, elem_id="hidden_countdown_state_num") single_hidden_trigger_btn = gr.Button("trigger_countdown", visible=False, elem_id="trigger_info_single_btn") - extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, + extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, image_prompt_type_group, image_prompt_type_radio, image_prompt_type_endcheckbox, prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, save_lset_prompt_drop, advanced_row, speed_tab, audio_tab, mmaudio_col, quality_tab, sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row, skip_layer_guidance_row, audio_guide_row, RIFLEx_setting_col, - video_prompt_type_video_guide, video_prompt_type_video_guide_alt, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux_row, + video_prompt_type_video_guide, video_prompt_type_video_guide_alt, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux, audio_prompt_type_remux_row, video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right, video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row, video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn, - NAG_col, speakers_locations_row, embedded_guidance_row, guidance_phases_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs, - min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn] # presets_column, + NAG_col, remove_background_sound , speakers_locations_row, embedded_guidance_row, guidance_phases_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, guide_selection_row, image_mode_tabs, + 
min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn, tab_inpaint, tab_t2v] + image_start_extra + image_end_extra + image_refs_extra # presets_column, if update_form: locals_dict = locals() gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs @@ -7944,22 +8297,25 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non target_settings = gr.Text(value = "settings", interactive= False, visible= False) last_choice = gr.Number(value =-1, interactive= False, visible= False) - resolution_group.input(fn=change_resolution_group, inputs=[state, resolution_group], outputs=[resolution]) + resolution_group.input(fn=change_resolution_group, inputs=[state, resolution_group], outputs=[resolution], show_progress="hidden") resolution.change(fn=record_last_resolution, inputs=[state, resolution]) - video_length.release(fn=refresh_video_length_label, inputs=[state, video_length ], outputs = video_length, trigger_mode="always_last" ) + # video_length.release(fn=refresh_video_length_label, inputs=[state, video_length ], outputs = video_length, trigger_mode="always_last" ) + gr.on(triggers=[video_length.release, force_fps.change, video_guide.change, video_source.change], fn=refresh_video_length_label, inputs=[state, video_length, force_fps, video_guide, video_source] , outputs = video_length, trigger_mode="always_last", show_progress="hidden" ) guidance_phases.change(fn=change_guidance_phases, inputs= [state, guidance_phases], outputs =[model_switch_phase, guidance_phases_row, switch_threshold, switch_threshold2, guidance2_scale, guidance3_scale ]) audio_prompt_type_remux.change(fn=refresh_audio_prompt_type_remux, inputs=[state, audio_prompt_type, audio_prompt_type_remux], outputs=[audio_prompt_type]) - audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row]) - image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] ) - # video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand]) - video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col]) - video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand]) - video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs ]) - video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, 
video_prompt_type_video_mask, image_mode], outputs = [video_prompt_type, video_mask, image_mask, mask_expand]) + remove_background_sound.change(fn=refresh_remove_background_sound, inputs=[state, audio_prompt_type, remove_background_sound], outputs=[audio_prompt_type]) + audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row, remove_background_sound]) + image_prompt_type_radio.change(fn=refresh_image_prompt_type_radio, inputs=[state, image_prompt_type, image_prompt_type_radio], outputs=[image_prompt_type, image_start_row, image_end_row, video_source, keep_frames_video_source, image_prompt_type_endcheckbox], show_progress="hidden" ) + image_prompt_type_endcheckbox.change(fn=refresh_image_prompt_type_endcheckbox, inputs=[state, image_prompt_type, image_prompt_type_radio, image_prompt_type_endcheckbox], outputs=[image_prompt_type, image_end_row] ) + video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs,image_mode], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col], show_progress="hidden") + video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, gr.State(""), video_prompt_type, video_prompt_type_video_guide, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand, image_refs_row], show_progress="hidden") + video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, gr.State("alt"),video_prompt_type, video_prompt_type_video_guide_alt, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand, image_refs_row], show_progress="hidden") + # video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, image_refs_row, denoising_strength, video_mask, mask_expand, image_mask_guide, image_guide, image_mask, keep_frames_video_guide ], show_progress="hidden") + video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_mask, image_mask_guide, image_guide, image_mask, mask_expand], show_progress="hidden") video_prompt_type_alignment.input(fn=refresh_video_prompt_type_alignment, inputs = [state, video_prompt_type, video_prompt_type_alignment], outputs = [video_prompt_type]) - multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt]) + multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt, image_end], show_progress="hidden") 
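# The hidden `video_guide_outpainting` text field packs the four outpainting settings into a single
# space-separated string "top bottom left right" (percentages); a leading "#" marks the feature as disabled,
# as the parsing above suggests (`startswith("#")`, `[0] * 4`, `split(" ")`). The actual implementation of
# `update_video_guide_outpainting` is not shown in this patch; a minimal sketch consistent with the wiring
# below (inputs = current value, slider value, slider index 0..3) could look like:
#
#   def update_video_guide_outpainting(value, slider_value, pos):
#       disabled = value.startswith("#")
#       raw = value.lstrip("#")
#       values = [0, 0, 0, 0] if len(raw) == 0 else [int(v) for v in raw.split(" ")]
#       values[pos] = slider_value          # 0 = top, 1 = bottom, 2 = left, 3 = right
#       return ("#" if disabled else "") + " ".join(str(v) for v in values)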
video_guide_outpainting_top.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_top, gr.State(0)], outputs = [video_guide_outpainting], trigger_mode="multiple" ) video_guide_outpainting_bottom.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_bottom,gr.State(1)], outputs = [video_guide_outpainting], trigger_mode="multiple" ) video_guide_outpainting_left.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_left,gr.State(2)], outputs = [video_guide_outpainting], trigger_mode="multiple" ) @@ -7968,8 +8324,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name]).then( fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]) queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container]) - gr.on( triggers=[output.change, output.select], fn=select_video, inputs=[state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab]) - preview_trigger.change(refresh_preview, inputs= [state], outputs= [preview]) + gr.on( triggers=[output.change, output.select], fn=select_video, inputs=[state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab], show_progress="hidden") + preview_trigger.change(refresh_preview, inputs= [state], outputs= [preview], show_progress="hidden") PP_MMAudio_setting.change(fn = lambda value : [gr.update(visible = value == 1), gr.update(visible = value == 0)] , inputs = [PP_MMAudio_setting], outputs = [PP_MMAudio_row, PP_custom_audio_row] ) def refresh_status_async(state, progress=gr.Progress()): gen = get_gen_info(state) @@ -7995,13 +8351,15 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non gen["status_display"] = True return time.time() - start_quit_timer_js, cancel_quit_timer_js, trigger_zip_download_js, trigger_settings_download_js = get_js() + start_quit_timer_js, cancel_quit_timer_js, trigger_zip_download_js, trigger_settings_download_js, click_brush_js = get_js() status_trigger.change(refresh_status_async, inputs= [state] , outputs= [gen_status], show_progress_on= [gen_status]) output_trigger.change(refresh_gallery, inputs = [state], - outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, current_gen_buttons_row, queue_df, abort_btn, onemorewindow_btn]) + outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, current_gen_buttons_row, queue_df, abort_btn, onemorewindow_btn], + show_progress="hidden" + ) preview_column_no.input(show_preview_column_modal, inputs=[state, preview_column_no], outputs=[preview_column_no, modal_image_display, modal_container]) @@ -8017,7 +8375,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non gr.on( triggers=[video_info_extract_settings_btn.click, video_info_extract_image_settings_btn.click], fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, 
wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None @@ -8026,21 +8385,23 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non prompt_enhancer_btn.click(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None ).then( fn=enhance_prompt, inputs =[state, prompt, prompt_enhancer, multi_images_gen_type, override_profile ] , outputs= [prompt, wizard_prompt]) - saveform_trigger.change(fn=validate_wizard_prompt, + save_form_trigger.change(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None ) - main_tabs.select(fn=detect_auto_save_form, inputs= [state], outputs= saveform_trigger, trigger_mode="multiple") + main_tabs.select(fn=detect_auto_save_form, inputs= [state], outputs= save_form_trigger, trigger_mode="multiple") video_info_add_videos_btn.click(fn=add_videos_to_gallery, inputs =[state, output, last_choice, files_to_load], outputs = [output, files_to_load, video_info_tabs] ) @@ -8049,14 +8410,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_info_to_video_source_btn.click(fn=video_to_source_video, inputs =[state, output, last_choice], outputs = [video_source] ) video_info_to_start_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_start, gr.State("Start Image")], outputs = [image_start] ) video_info_to_end_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_end, gr.State("End Image")], outputs = [image_end] ) - video_info_to_image_guide_btn.click(fn=image_to_ref_image_set, inputs =[state, output, last_choice, image_guide, gr.State("Control Image")], outputs = [image_guide] ) + video_info_to_image_guide_btn.click(fn=image_to_ref_image_guide, inputs =[state, output, last_choice], outputs = [image_guide, image_mask_guide]).then(fn=None, inputs=[], outputs=[], js=click_brush_js ) video_info_to_image_mask_btn.click(fn=image_to_ref_image_set, inputs =[state, output, last_choice, image_mask, gr.State("Image Mask")], outputs = [image_mask] ) video_info_to_reference_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_refs, gr.State("Ref Image")], outputs = [image_refs] ) video_info_postprocessing_btn.click(fn=apply_post_processing, inputs =[state, output, last_choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation], outputs = [mode, generate_trigger, add_to_queue_trigger ] ) video_info_remux_audio_btn.click(fn=remux_audio, inputs =[state, output, last_choice, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation, PP_custom_audio], outputs = [mode, generate_trigger, add_to_queue_trigger ] ) save_lset_btn.click(validate_save_lset, inputs=[state, lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop]) delete_lset_btn.click(validate_delete_lset, inputs=[state, 
lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ]) - confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then( + confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt], show_progress="hidden",).then( fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None).then( @@ -8071,7 +8432,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non lset_name.select(fn=update_lset_type, inputs=[state, lset_name], outputs=save_lset_prompt_drop) export_settings_from_file_btn.click(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None @@ -8088,7 +8450,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non image_mode_tabs.select(fn=record_image_mode_tab, inputs=[state], outputs= None ).then(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None @@ -8096,7 +8459,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non settings_file.upload(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None @@ -8110,17 +8474,20 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non refresh_form_trigger.change(fn= fill_inputs, inputs=[state], - outputs=gen_inputs + extra_inputs + outputs=gen_inputs + extra_inputs, + show_progress= "full" if args.debug_gen_form else "hidden", ).then(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ) model_family.input(fn=change_model_family, inputs=[state, model_family], outputs= [model_choice]) model_choice.change(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None @@ -8129,7 +8496,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non outputs= [header] ).then(fn= fill_inputs, inputs=[state], - outputs=gen_inputs + extra_inputs + outputs=gen_inputs + extra_inputs, + show_progress="full" if args.debug_gen_form else "hidden", ).then(fn= preload_model_when_switching, inputs=[state], outputs=[gen_status]) @@ -8138,26 +8506,25 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non generate_trigger.change(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + 
show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None ).then(fn=process_prompt_and_add_tasks, inputs = [state, model_choice], - outputs= queue_df + outputs= [queue_df, queue_accordion], + show_progress="hidden", ).then(fn=prepare_generate_video, inputs= [state], outputs= [generate_btn, add_to_queue_btn, current_gen_column, current_gen_buttons_row] ).then(fn=activate_status, inputs= [state], outputs= [status_trigger], - ).then( - fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(), - inputs=[state], - outputs=[queue_accordion] ).then(fn=process_tasks, inputs= [state], outputs= [preview_trigger, output_trigger], + show_progress="hidden", ).then(finalize_generation, inputs= [state], outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info] @@ -8264,17 +8631,15 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non # gr.on(triggers=[add_to_queue_btn.click, add_to_queue_trigger.change],fn=validate_wizard_prompt, add_to_queue_trigger.change(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None ).then(fn=process_prompt_and_add_tasks, inputs = [state, model_choice], - outputs=queue_df - ).then( - fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(), - inputs=[state], - outputs=[queue_accordion] + outputs=[queue_df, queue_accordion], + show_progress="hidden", ).then( fn=update_status, inputs = [state], @@ -8286,8 +8651,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non outputs=[modal_container] ) - return ( state, loras_choices, lset_name, resolution, - video_guide, image_guide, video_mask, image_mask, image_refs, refresh_form_trigger + return ( state, loras_choices, lset_name, resolution, refresh_form_trigger, save_form_trigger, + # video_guide, image_guide, video_mask, image_mask, image_refs, ) @@ -8323,8 +8688,9 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice fit_canvas_choice = gr.Dropdown( choices=[ - ("Dimensions correspond to the Pixels Budget (as the Prompt Image/Video will be resized to match this pixels budget, output video height or width may exceed the requested dimensions )", 0), - ("Dimensions correspond to the Maximum Width and Height (as the Prompt Image/Video will be resized to fit into these dimensions, the output video may be smaller)", 1), + ("Dimensions correspond to the Pixels Budget (as the Prompt Image/Video will be Resized to match this pixels Budget, output video height or width may exceed the requested dimensions )", 0), + ("Dimensions correspond to the Maximum Width and Height (as the Prompt Image/Video will be Resized to fit into these dimensions, the output video may be smaller)", 1), + ("Dimensions correspond to the Output Width and Height (as the Prompt Image/Video will be Cropped to fit exactly these dimensions)", 2), ], value= server_config.get("fit_canvas", 0), label="Generated Video Dimensions when Prompt contains an Image or a Video", @@ -8347,6 +8713,7 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice ("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without 
WSL)", "xformers"), ("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"), ("Sage2/2++" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"), + ("Sage3" + check("sage3")+ ": x2 faster but worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage3"), ], value= attention_mode, label="Attention Type", @@ -8470,7 +8837,7 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice ("Off", "" ), ], value= compile, - label="Compile Transformer (up to 50% faster and 30% more frames but requires Linux / WSL and Flash or Sage attention)", + label="Compile Transformer: up to 10-20% faster, useful only when running multiple generations at the same number of frames / resolution", interactive= not lock_ui_compile ) @@ -8662,7 +9029,7 @@ def generate_about_tab(): gr.Markdown("- Blackforest Labs for the innovative Flux image generators (https://github.com/black-forest-labs/flux)") gr.Markdown("- Alibaba Qwen Team for their state of the art Qwen Image generators (https://github.com/QwenLM/Qwen-Image)") gr.Markdown("- Lightricks for their super fast LTX Video models (https://github.com/Lightricks/LTX-Video)") - gr.Markdown("- Hugging Face for the providing hosting for the models and developing must have open source libraries such as Tranformers, Diffusers, Accelerate and Gradio (https://huggingface.co/)") + gr.Markdown("- Hugging Face for providing hosting for the models and developing must-have open source libraries such as Transformers, Diffusers, Accelerate and Gradio (https://huggingface.co/)") gr.Markdown("
Huge acknowledgments to these great open source projects used in WanGP:") gr.Markdown("- Rife: temporal upsampler (https://github.com/hzwer/ECCV2022-RIFE)") gr.Markdown("- DwPose: Open Pose extractor (https://github.com/IDEA-Research/DWPose)") @@ -8671,7 +9038,7 @@ def generate_about_tab(): gr.Markdown("- Pyannote: speaker diarization (https://github.com/pyannote/pyannote-audio)") gr.Markdown("
Special thanks to the following people for their support:") - gr.Markdown("- Cocktail Peanuts : QA and simple installation via Pinokio.computer") + gr.Markdown("- Cocktail Peanuts : QA and simple installation via Pinokio.computer") gr.Markdown("- Tophness : created (former) multi tabs and queuing frameworks") gr.Markdown("- AmericanPresidentJimmyCarter : added original support for Skip Layer Guidance") gr.Markdown("- Remade_AI : for their awesome Loras collection") @@ -8775,12 +9142,19 @@ def set_new_tab(tab_state, new_tab_no): tab_state["tab_no"] = 0 return gr.Tabs(selected="video_gen") else: + if not download_shared_done: + download_models() vmc_event_handler(True) tab_state["tab_no"] = new_tab_no return gr.Tabs() def select_tab(tab_state, evt:gr.SelectData): - return set_new_tab(tab_state, evt.index) + old_tab_no = tab_state.get("tab_no",0) + if old_tab_no == 0: + saveform_trigger = get_unique_id() + else: + saveform_trigger = gr.update() + return set_new_tab(tab_state, evt.index), saveform_trigger def get_js(): start_quit_timer_js = """ @@ -8883,7 +9257,21 @@ def get_js(): } } """ - return start_quit_timer_js, cancel_quit_timer_js, trigger_zip_download_js, trigger_settings_download_js + + click_brush_js = """ + () => { + setTimeout(() => { + const brushButton = document.querySelector('button[aria-label="Brush"]'); + if (brushButton) { + brushButton.click(); + console.log('Brush button clicked'); + } else { + console.log('Brush button not found'); + } + }, 1000); + } """ + + return start_quit_timer_js, cancel_quit_timer_js, trigger_zip_download_js, trigger_settings_download_js, click_brush_js def create_ui(): global vmc_event_handler @@ -9214,9 +9602,17 @@ def create_ui(): console.log('Events dispatched for column:', index); } }; - console.log('sendColIndex function attached to window'); - } + + // cancel wheel usage inside image editor + const hit = n => n?.id === "img_editor" || n?.classList?.contains("wheel-pass"); + addEventListener("wheel", e => { + const path = e.composedPath?.() || (() => { let a=[],n=e.target; for(;n;n=n.parentNode||n.host) a.push(n); return a; })(); + if (path.some(hit)) e.stopImmediatePropagation(); + }, { capture: true, passive: true }); + + } + """ if server_config.get("display_stats", 0) == 1: from shared.utils.stats import SystemStatsApp @@ -9247,13 +9643,13 @@ def create_ui(): stats_element = stats_app.get_gradio_element() with gr.Row(): - ( state, loras_choices, lset_name, resolution, - video_guide, image_guide, video_mask, image_mask, image_refs, refresh_form_trigger + ( state, loras_choices, lset_name, resolution, refresh_form_trigger, save_form_trigger + # video_guide, image_guide, video_mask, image_mask, image_refs, ) = generate_video_tab(model_family=model_family, model_choice=model_choice, header=header, main = main, main_tabs =main_tabs) with gr.Tab("Guides", id="info") as info_tab: generate_info_tab() with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator: - matanyone_app.display(main_tabs, tab_state, server_config, video_guide, image_guide, video_mask, image_mask, image_refs) + matanyone_app.display(main_tabs, tab_state, state, refresh_form_trigger, server_config, get_current_model_settings) #, video_guide, image_guide, video_mask, image_mask, image_refs) if not args.lock_config: with gr.Tab("Downloads", id="downloads") as downloads_tab: generate_download_tab(lset_name, loras_choices, state) @@ -9263,7 +9659,7 @@ def generate_about_tab(): generate_about_tab() if stats_app is not None: stats_app.setup_events(main, state) -
main_tabs.select(fn=select_tab, inputs= [tab_state], outputs= main_tabs, trigger_mode="multiple") + main_tabs.select(fn=select_tab, inputs= [tab_state], outputs= [main_tabs, save_form_trigger], trigger_mode="multiple") return main if __name__ == "__main__": @@ -9289,4 +9685,4 @@ if __name__ == "__main__": else: url = "http://" + server_name webbrowser.open(url + ":" + str(server_port), new = 0, autoraise = True) - demo.launch(favicon_path="favicon.png", server_name=server_name, server_port=server_port, share=args.share, allowed_paths=list({save_path, image_save_path})) \ No newline at end of file + demo.launch(favicon_path="favicon.png", server_name=server_name, server_port=server_port, share=args.share, allowed_paths=list({save_path, image_save_path}))
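# For reference, the three "fit_canvas" choices offered in the Configuration tab above (0 = Pixels Budget,
# 1 = Maximum Width and Height, 2 = exact Output Width and Height with cropping) could map a source size onto
# the requested resolution roughly as sketched below. This helper is only an illustration of the described
# behaviour; its name and the rounding to multiples of 16 are assumptions, not code from this patch.
import math

def fit_dimensions(src_w, src_h, req_w, req_h, fit_canvas, multiple=16):
    if fit_canvas == 2:
        # Exact output size: the source is center-cropped to the requested aspect ratio, then resized.
        return req_w, req_h
    if fit_canvas == 1:
        # Outer box: fit inside req_w x req_h while preserving the source aspect ratio (output may be smaller).
        scale = min(req_w / src_w, req_h / src_h)
    else:
        # Pixels budget: keep the source aspect ratio and reallocate req_w * req_h pixels
        # (one dimension may exceed the requested size).
        scale = math.sqrt((req_w * req_h) / (src_w * src_h))
    out_w = max(multiple, round(src_w * scale / multiple) * multiple)
    out_h = max(multiple, round(src_h * scale / multiple) * multiple)
    return out_w, out_h

# Example under these assumptions: fit_dimensions(1920, 1080, 832, 480, 0) -> (848, 480),
# i.e. the width exceeds 832 so that the 16:9 input ratio is kept within roughly the same pixel budget.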