mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
LTXV and Flux updates
This commit is contained in:
parent a356c6af4b
commit a1c228054c
100 README.md
@@ -20,6 +20,21 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep

## 🔥 Latest Updates

+### July 21 2025: WanGP v7.1
+- Flux Family Reunion: *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate, Lora support for the Flux *diffusers* format has also been added.
+- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 minute of video!) in one go without a sliding window. With the distilled model it takes only 5 minutes on an RTX 4090 (you will need 22 GB of VRAM though). Options have been added to select a higher number of frames if you want to experiment.
+- LTX Video ControlNet: a ControlNet that lets you, for instance, transfer human motion or depth from a control video. It is not as powerful as Vace but can produce interesting results, especially now that a 1-minute video can be generated quickly. Under the hood, the IC-Loras (see below) for Pose, Depth and Canny are loaded automatically, no need to add them.
+- LTX IC-Lora support: these are special Loras that consume a conditional image or video.
+Besides the Pose, Depth and Canny IC-Loras that are loaded transparently, there is the *detailer* (https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8), which is basically an upsampler. Add the *detailer* as a Lora and choose LTX Raw Format as the ControlNet option to use it.
+
+And also:
+- easier way to select the video resolution
+- started to optimize Matanyone to reduce VRAM requirements
+
### July 15 2025: WanGP v7.0 is an AI Powered Photoshop
This release turns the Wan models into Image Generators. This goes way beyond generating a video made of a single frame:
- Multiple Images generated at the same time so that you can choose the one you like best. It is highly VRAM optimized so that you can generate, for instance, four 720p images at the same time with less than 10 GB
@@ -86,84 +101,6 @@ Taking care of your life is not enough, you want new stuff to play with ?

**If you had upgraded to v6.5 please upgrade again to 6.5.1 as this will fix a bug that ignored Loras beyond the first one**

-### June 23 2025: WanGP v6.3, Vace Unleashed. Thought we couldn't squeeze Vace even more?
-- Multithreaded preprocessing when possible for faster generations
-- Multithreaded frames Lanczos Upsampling as a bonus
-- A new Vace preprocessor: *Flow* to extract fluid motion
-- Multi Vace ControlNets: you can now transfer several properties at the same time. This opens new possibilities to explore; for instance, if you transfer *Human Movement* and *Shapes* at the same time, the lighting of your character will take the environment much more into account.
-- Injected Frames Outpainting, in case you missed it in WanGP 6.21
-
-Don't know how to use all of the Vace features? Check the Vace Guide embedded in WanGP as it has also been updated.
-
-### June 19 2025: WanGP v6.2, Vace even more Powercharged
-👋 Have I told you that I am a big fan of Vace? Here are more goodies to unleash its power:
-- If you ever wanted to watch Star Wars in 4:3, just use the new *Outpainting* feature and it will add the missing bits of image at the top and the bottom of the screen. The best thing is *Outpainting* can be combined with all the other Vace modifications, so you can, for instance, change the main character of your favorite movie at the same time.
-- More processing steps can be combined at the same time (for instance the depth process can be applied outside the mask)
-- Upgraded the depth extractor to Depth Anything 2, which is much more detailed
-
-As a bonus, I have added two finetunes based on the Self-Forcing technology (which requires only 4 steps to generate a video): Wan 2.1 text2video Self-Forcing and Vace Self-Forcing. I know there is a Lora around but the quality of the Lora is worse (at least with Vace) compared to the full model. Don't hesitate to share your opinion about this on the discord server.
-
-### June 17 2025: WanGP v6.1, Vace Powercharged
-👋 Lots of improvements for Vace the Mother of all Models:
-- masks can now be combined with on-the-fly processing of a control video; for instance you can extract the motion of a specific person defined by a mask
-- on-the-fly modification of masks: reversed masks (with the same mask you can modify the background instead of the people covered by the masks), enlarged masks (you can cover more area if, for instance, the person you are trying to inject is larger than the one in the mask), ...
-- view these modified masks directly inside WanGP during the video generation to check they are really as expected
-- multiple frame injections: multiple frames can be injected at any location of the video
-- expand past videos in one click: just select one generated video to expand it
-
-Of course all this new stuff works on all Vace finetunes (including Vace FusioniX).
-
-Thanks also to Reevoy24 for adding a Notification sound at the end of a generation and for fixing the background color of the current generation summary.
-
-### June 12 2025: WanGP v6.0
-👋 *Finetune models*: You find the 20 models supported by WanGP not sufficient? Too impatient to wait for the next release to get support for a newly released model? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add support for it yourself by just creating a finetune model definition. You can then store this model in the cloud (for instance on Hugging Face) and the very light finetune definition file can easily be shared with other users. WanGP will download the finetuned model for them automatically.
-
-To celebrate the new finetunes support, here are a few finetune gifts (directly accessible from the model selection menu):
-- *Fast Hunyuan Video*: generate t2v in only 6 steps
-- *Hunyuan Video AccVideo*: generate t2v in only 5 steps
-- *Wan FusioniX*: a combo of AccVideo / CausVid and other models that can generate high quality Wan videos in only 8 steps
-
-One more thing...
-
-The new finetune system can be used to combine complementary models: what happens when you combine FusioniX Text2Video and Vace ControlNet?
-
-You get **Vace FusioniX**: the Ultimate Vace Model, fast (10 steps, no need for guidance) and with much better video quality than the original slower model (despite it being the best ControlNet out there). Here goes one more finetune...
-
-Check the *Finetune Guide* to create finetune model definitions and share them on the WanGP discord server.
-
-### June 11 2025: WanGP v5.5
-👋 *Hunyuan Video Custom Audio*: it is similar to Hunyuan Video Avatar except there isn't any lower limit on the number of frames and you can use your reference images in a different context than the image itself\
-*Hunyuan Video Custom Edit*: Hunyuan Video ControlNet, use it to do inpainting and replace a person in a video while still keeping their poses. Similar to Vace but less restricted than the Wan models in terms of content...
-
-### June 6 2025: WanGP v5.41
-👋 Bonus release: Support for the **AccVideo** Lora to speed up Wan video generations by x2. Check the Loras documentation for the AccVideo usage instructions.\
-You will need to do a *pip install -r requirements.txt*
-
-### June 6 2025: WanGP v5.4
-👋 World Exclusive: **Hunyuan Video Avatar** Support! You won't need 80 GB of VRAM nor 32 GB of VRAM, just 10 GB of VRAM will be sufficient to generate up to 15s of high quality speech / song driven video at high speed with no quality degradation. Support for TeaCache included.\
-Here is a link to the original repo where you will find some very interesting documentation and examples: https://github.com/Tencent-Hunyuan/HunyuanVideo-Avatar. Kudos to the Hunyuan Video Avatar team for the best model of its kind.\
-Also many thanks to Reevoy24 for his repackaging / completing the documentation
-
-### May 28 2025: WanGP v5.31
-👋 Added **Phantom 14B**, a model that you can use to transfer objects / people into the video. My preference goes to Vace, which remains the king of controlnets.
-VACE improvements: Better sliding window transitions, image mask support in Matanyone, new Extend Video feature, and enhanced background removal options.
-
-### May 26, 2025: WanGP v5.3
-👋 Settings management revolution! Now you can:
-- Select any generated video and click *Use Selected Video Settings* to instantly reuse its configuration
-- Drag & drop videos to automatically extract their settings metadata
-- Export/import settings as JSON files for easy sharing and backup
-
-### May 20, 2025: WanGP v5.2
-👋 **CausVid support** - Generate videos in just 4-12 steps with the new distilled Wan model! Also added experimental MoviiGen for 1080p generation (20GB+ VRAM required). Check the Loras documentation for the CausVid usage instructions.
-
-### May 18, 2025: WanGP v5.1
-👋 **LTX Video 13B Distilled** - Generate high-quality videos in less than one minute!
-
-### May 17, 2025: WanGP v5.0
-👋 **One App to Rule Them All!** Added Hunyuan Video and LTX Video support, plus Vace 14B and integrated prompt enhancer.
-
See full changelog: **[Changelog](docs/CHANGELOG.md)**

## 📋 Table of Contents
@@ -202,6 +139,7 @@ git pull
pip install -r requirements.txt
```
+

## 📦 Installation

For detailed installation instructions for different GPU generations:
@@ -224,6 +162,12 @@ For detailed installation instructions for different GPU generations:
- **[Changelog](docs/CHANGELOG.md)** - Latest updates and version history
- **[Troubleshooting](docs/TROUBLESHOOTING.md)** - Common issues and solutions

+## 📚 Video Guides
+- Nice video that explains how to use Vace:\
+https://www.youtube.com/watch?v=FMo9oN2EAvE
+- Another Vace guide:\
+https://www.youtube.com/watch?v=T5jNiEhf9xk
+
## 🔗 Related Projects

### Other Models for the GPU Poor
16 defaults/flux_dev.json (new file)
@@ -0,0 +1,16 @@
{
    "model": {
        "name": "Flux 1 Dev 12B",
        "architecture": "flux",
        "description": "FLUX.1 Dev is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev_quanto_bf16_int8.safetensors"
        ],
        "image_outputs": true,
        "flux-model": "flux-dev"
    },
    "prompt": "draw a hat",
    "resolution": "1280x720",
    "video_length": 1
}
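For context, the new `flux-model` field in these definition files is what the Flux loader reads to decide which variant to instantiate (see the `model_factory` change further down, where `self.name = model_def.get("flux-model", "flux-dev")`), and the `URLs` list mixes a bf16 checkpoint with a pre-quantized one. Below is a minimal sketch of how such a definition could drive checkpoint selection; the `pick_checkpoint` helper and the "quanto in the file name" convention are assumptions for illustration, not WanGP's actual loader.

```python
import json

def pick_checkpoint(model_def: dict, prefer_quantized: bool = False) -> tuple[str, str]:
    """Return (flux_variant, checkpoint_url) from a model definition dict.

    The 'flux-model' key selects the variant ("flux-dev", "flux-schnell",
    "flux-dev-kontext"); the 'URLs' list holds bf16 and quantized checkpoints.
    """
    variant = model_def.get("flux-model", "flux-dev")   # same default as model_factory
    urls = model_def["URLs"]
    # assumption: quantized checkpoints carry "quanto" in their file name
    quantized = [u for u in urls if "quanto" in u]
    url = quantized[0] if (prefer_quantized and quantized) else urls[0]
    return variant, url

with open("defaults/flux_dev.json") as f:
    definition = json.load(f)["model"]

variant, url = pick_checkpoint(definition, prefer_quantized=True)
print(variant, url)  # "flux-dev", .../flux1-dev_quanto_bf16_int8.safetensors
```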
@@ -1,12 +1,15 @@
{
    "model": {
-        "name": "Flux Dev Kontext 12B",
+        "name": "Flux 1 Dev Kontext 12B",
-        "architecture": "flux_dev_kontext",
+        "architecture": "flux",
-        "description": "FLUX.1 Kontext is a 12 billion parameter rectified flow transformer capable of editing images based on instructions stored in the Prompt. Please be aware that Flux Kontext is picky on the resolution of the input image the output dimensions may not match the dimensions of the input image.",
+        "description": "FLUX.1 Kontext is a 12 billion parameter rectified flow transformer capable of editing images based on instructions stored in the Prompt. Please be aware that Flux Kontext is picky on the resolution of the input image and the output dimensions may not match the dimensions of the input image.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors"
-        ]
+        ],
+        "image_outputs": true,
+        "reference_image": true,
+        "flux-model": "flux-dev-kontext"
    },
    "prompt": "add a hat",
    "resolution": "1280x720",
17 defaults/flux_schnell.json (new file)
@@ -0,0 +1,17 @@
{
    "model": {
        "name": "Flux 1 Schnell 12B",
        "architecture": "flux",
        "description": "FLUX.1 Schnell is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. As a distilled model it requires fewer denoising steps.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-schnell_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main//flux1-schnell_quanto_bf16_int8.safetensors"
        ],
        "image_outputs": true,
        "flux-model": "flux-schnell"
    },
    "prompt": "draw a hat",
    "resolution": "1280x720",
    "num_inference_steps": 10,
    "video_length": 1
}
@@ -1,14 +1,19 @@
{
    "model":
    {
-        "name": "LTX Video 0.9.7 13B",
+        "name": "LTX Video 0.9.8 13B",
        "architecture" : "ltxv_13B",
-        "description": "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).It is recommended to keep the number of steps to 30 or you will need to update the file 'ltxv_video/configs/ltxv-13b-0.9.7-dev.yaml'.The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer.",
+        "description": "LTX Video is a fast model that can be used to generate very long videos (up to 1800 frames!). It is recommended to keep the number of steps to 30 or you will need to update the file 'ltxv_video/configs/ltxv-13b-0.9.8-dev.yaml'. The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer.",
        "URLs": [
-            "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.7_13B_dev_bf16.safetensors",
+            "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.8_13B_dev_bf16.safetensors",
-            "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.7_13B_dev_quanto_bf16_int8.safetensors"
+            "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.8_13B_dev_quanto_bf16_int8.safetensors"
        ],
-        "LTXV_config": "ltx_video/configs/ltxv-13b-0.9.7-dev.yaml"
+        "preload_URLs" : [
+            "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv-097-ic-lora-pose-control-diffusers.safetensors",
+            "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv-097-ic-lora-depth-control-diffusers.safetensors",
+            "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv-097-ic-lora-canny-control-diffusers.safetensors"
+        ],
+        "LTXV_config": "ltx_video/configs/ltxv-13b-0.9.8-dev.yaml"
    },
    "num_inference_steps": 30
}
@@ -1,14 +1,15 @@
{
    "model":
    {
-        "name": "LTX Video 0.9.7 Distilled 13B",
+        "name": "LTX Video 0.9.8 Distilled 13B",
        "architecture" : "ltxv_13B",
-        "description": "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).This distilled version is a very fast version and retains a high level of quality. The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer.",
+        "description": "LTX Video is a fast model that can be used to generate very long videos (up to 1800 frames!). This distilled version is very fast and retains a high level of quality. The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer.",
-        "URLs": "ltxv_13B",
-        "loras": ["https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"],
-        "loras_multipliers": [ 1 ],
-        "lock_inference_steps": true,
-        "LTXV_config": "ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml"
+        "URLs": [
+            "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.8_13B_distilled_dev_bf16.safetensors",
+            "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.8_13B_distilled_quanto_bf16_int8.safetensors"
+        ],
+        "preload_URLs" : "ltxv_13B",
+        "LTXV_config": "ltx_video/configs/ltxv-13b-0.9.8-distilled.yaml"
    },
    "num_inference_steps": 6
}
@@ -5,7 +5,7 @@
        "architecture" : "t2v_1.3B",
        "description": "The light version of the original Wan Text 2 Video model. Most other models have been built on top of it",
        "URLs": [
-            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_text2video_1.3B_bf16.safetensors"
+            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_text2video_1.3B_mbf16.safetensors"
        ]
    }
}
@@ -3,9 +3,8 @@
    {
        "name": "Vace ControlNet 1.3B",
        "architecture" : "vace_1.3B",
+        "modules": ["vace_1.3B"],
        "description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based on additional custom data: a pose or depth video, or images or objects you want to see in the video.",
-        "URLs": [
-            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Vace_1.3B_mbf16.safetensors"
-        ]
+        "URLs": "t2v_1.3B"
    }
}
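Note that several definitions above now reference another model instead of listing URLs directly (`"URLs": "t2v_1.3B"`, `"preload_URLs" : "ltxv_13B"`), which matches the "chain URLs of finetune models" feature mentioned in the changelog below. A minimal sketch of how such a string reference could be resolved follows; the `load_definition` helper and the `defaults/<id>.json` naming are assumptions for illustration, not WanGP's actual resolution code.

```python
import json

def load_definition(model_id: str) -> dict:
    # assumption: definitions live in defaults/<id>.json
    with open(f"defaults/{model_id}.json") as f:
        return json.load(f)

def resolve_urls(value, load_definition):
    """Resolve a 'URLs' / 'preload_URLs' entry that is either a list or a reference.

    A list is returned as-is; a string is treated as the id of another model
    definition whose 'URLs' list is reused (and may itself be a reference).
    """
    while isinstance(value, str):
        value = load_definition(value)["model"]["URLs"]
    return value

vace = load_definition("vace_1.3B")
print(resolve_urls(vace["model"]["URLs"], load_definition))
# -> the wan2.1_text2video_1.3B_mbf16.safetensors list from the t2v_1.3B definition
```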
@@ -1,16 +1,106 @@
# Changelog

## 🔥 Latest News

+### July 21 2025: WanGP v7.1
+- Flux Family Reunion: *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate, Lora support for the Flux *diffusers* format has also been added.
+- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 minute of video!) in one go without a sliding window. With the distilled model it takes only 5 minutes on an RTX 4090 (you will need 22 GB of VRAM though). Options have been added to select a higher number of frames if you want to experiment.
+- LTX Video ControlNet: a ControlNet that lets you, for instance, transfer human motion or depth from a control video. It is not as powerful as Vace but can produce interesting results, especially now that a 1-minute video can be generated quickly. Under the hood, the IC-Loras (see below) for Pose, Depth and Canny are loaded automatically, no need to add them.
+- LTX IC-Lora support: these are special Loras that consume a conditional image or video.
+Besides the Pose, Depth and Canny IC-Loras that are loaded transparently, there is the *detailer* (https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8), which is basically an upsampler. Add the *detailer* as a Lora and choose LTX Raw Format as the ControlNet option to use it.
+
+And also:
+- easier way to select the video resolution
+- started to optimize Matanyone to reduce VRAM requirements
+
+### July 15 2025: WanGP v7.0 is an AI Powered Photoshop
+This release turns the Wan models into Image Generators. This goes way beyond generating a video made of a single frame:
+- Multiple Images generated at the same time so that you can choose the one you like best. It is highly VRAM optimized so that you can generate, for instance, four 720p images at the same time with less than 10 GB
+- With *image2image* the original text2video model becomes an image upsampler / restorer
+- *Vace image2image* comes out of the box with image outpainting, person / object replacement, ...
+- You can use a newly generated image as the Start Image or Reference Image for a video generation in one click
+
+And to complete the full suite of AI Image Generators, Ladies and Gentlemen please welcome for the first time in WanGP: **Flux Kontext**.\
+As a reminder, Flux Kontext is an image editor: give it an image and a prompt and it will make the change for you.\
+This highly optimized version of Flux Kontext will make you feel that you have been cheated all this time, as WanGP's Flux Kontext requires only 8 GB of VRAM to generate 4 images at the same time with no need for quantization.
+
+WanGP v7 comes with *Image2image* vanilla and *Vace FusioniX*. However, you can build your own finetune that combines a text2video or Vace model with any combination of Loras.
+
+Also in the news:
+- You can now enter the *Bbox* for each speaker in *Multitalk* to precisely locate who is speaking. And to save some headaches, the *Image Mask generator* will give you the *Bbox* coordinates of an area you have selected.
+- *Film Grain* post processing to add a vintage look to your video
+- *First Last Frame to Video* model should work much better now, as I recently discovered its implementation was not complete
+- More power for the finetuners: you can now embed Loras directly in the finetune definition. You can also override the default models (titles, visibility, ...) with your own finetunes. Check the doc, which has been updated.
+
+### July 10 2025: WanGP v6.7, is NAG a game changer? You tell me
+Maybe you knew that already, but most of the *Lora accelerators* we use today (CausVid, FusioniX) don't use *Guidance* at all (that is, *CFG* is set to 1). This makes generations much faster, but the downside is that *Negative Prompts* are completely ignored (including the default ones set by the models). **NAG** (https://github.com/ChenDarYen/Normalized-Attention-Guidance) aims to solve that by injecting the *Negative Prompt* during the *attention* processing phase.
+
+So WanGP 6.7 gives you NAG, and not just any NAG but a *Low VRAM* implementation, since the default one ends up being VRAM greedy. You will find NAG in the *General* advanced tab for most Wan models.
+
+Use NAG especially when Guidance is set to 1. To turn it on, set the **NAG scale** to something around 10. There are other NAG parameters, **NAG tau** and **NAG alpha**, which I recommend changing only if you don't get good results by just playing with the NAG scale. Don't hesitate to share the best combinations for these 3 parameters on the discord server.
+
+The authors of NAG claim that NAG can also be used with Guidance (CFG > 1) to improve prompt adherence.
+
### July 8 2025: WanGP v6.6, WanGP offers you **Vace Multitalk Dual Voices Fusionix Infinite** :
|
||||||
|
**Vace** our beloved super Control Net has been combined with **Multitalk** the new king in town that can animate up to two people speaking (**Dual Voices**). It is accelerated by the **Fusionix** model and thanks to *Sliding Windows* support and *Adaptive Projected Guidance* (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for very a long time (which is an **Infinite** amount of time in the field of video generation).
|
||||||
|
|
||||||
|
Of course you will get as well *Multitalk* vanilla and also *Multitalk 720p* as a bonus.
|
||||||
|
|
||||||
|
And since I am mister nice guy I have enclosed as an exclusivity an *Audio Separator* that will save you time to isolate each voice when using Multitalk with two people.
|
||||||
|
|
||||||
|
As I feel like resting a bit I haven't produced yet a nice sample Video to illustrate all these new capabilities. But here is the thing, I ams sure you will publish in the *Share Your Best Video* channel your *Master Pieces*. The best ones will be added to the *Announcements Channel* and will bring eternal fame to its authors.
|
||||||
|
|
||||||
|
But wait, there is more:
|
||||||
|
- Sliding Windows support has been added anywhere with Wan models, so imagine with text2video recently upgraded in 6.5 into a video2video, you can now upsample very long videos regardless of your VRAM. The good old image2video model can now reuse the last image to produce new videos (as requested by many of you)
|
||||||
|
- I have added also the capability to transfer the audio of the original control video (Misc. advanced tab) and an option to preserve the fps into the generated video, so from now on you will be to upsample / restore your old families video and keep the audio at their original pace. Be aware that the duration will be limited to 1000 frames as I still need to add streaming support for unlimited video sizes.
|
||||||
|
|
||||||
|
Also, of interest too:
|
||||||
|
- Extract video info from Videos that have not been generated by WanGP, even better you can also apply post processing (Upsampling / MMAudio) on non WanGP videos
|
||||||
|
- Force the generated video fps to your liking, works wery well with Vace when using a Control Video
|
||||||
|
- Ability to chain URLs of Finetune models (for instance put the URLs of a model in your main finetune and reference this finetune in other finetune models to save time)
|
||||||
|
|
||||||
|
### July 2 2025: WanGP v6.5.1, WanGP takes care of you: lots of quality of life features:
|
||||||
|
- View directly inside WanGP the properties (seed, resolutions, length, most settings...) of the past generations
|
||||||
|
- In one click use the newly generated video as a Control Video or Source Video to be continued
|
||||||
|
- Manage multiple settings for the same model and switch between them using a dropdown box
|
||||||
|
- WanGP will keep the last generated videos in the Gallery and will remember the last model you used if you restart the app but kept the Web page open
|
||||||
|
- Custom resolutions : add a file in the WanGP folder with the list of resolutions you want to see in WanGP (look at the instruction readme in this folder)
|
||||||
|
|
||||||
|
Taking care of your life is not enough, you want new stuff to play with ?
|
||||||
|
- MMAudio directly inside WanGP : add an audio soundtrack that matches the content of your video. By the way it is a low VRAM MMAudio and 6 GB of VRAM should be sufficient. You will need to go in the *Extensions* tab of the WanGP *Configuration* to enable MMAudio
|
||||||
|
- Forgot to upsample your video during the generation ? want to try another MMAudio variation ? Fear not you can also apply upsampling or add an MMAudio track once the video generation is done. Even better you can ask WangGP for multiple variations of MMAudio to pick the one you like best
|
||||||
|
- MagCache support: a new step skipping approach, supposed to be better than TeaCache. Makes a difference if you usually generate with a high number of steps
|
||||||
|
- SageAttention2++ support : not just the compatibility but also a slightly reduced VRAM usage
|
||||||
|
- Video2Video in Wan Text2Video : this is the paradox, a text2video can become a video2video if you start the denoising process later on an existing video
|
||||||
|
- FusioniX upsampler: this is an illustration of Video2Video in Text2Video. Use the FusioniX text2video model with an output resolution of 1080p and a denoising strength of 0.25 and you will get one of the best upsamplers (in only 2/3 steps, you will need lots of VRAM though). Increase the denoising strength and you will get one of the best Video Restorer
|
||||||
|
- Choice of Wan Samplers / Schedulers
|
||||||
|
- More Lora formats support
|
||||||
|
|
||||||
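The Video2Video-in-Text2Video trick and the FusioniX upsampler recipe above both boil down to skipping the first part of the denoising schedule and starting from a noised version of the input video instead of pure noise. A minimal sketch of that idea with generic names (`add_noise` is an assumed helper, and this is not WanGP's actual pipeline code):

```python
import torch

def video2video_start(latents_video: torch.Tensor,
                      timesteps: list[float],
                      denoising_strength: float,
                      add_noise):
    """Resume denoising partway through the schedule.

    denoising_strength = 1.0  -> plain text2video (full schedule, pure noise)
    denoising_strength = 0.25 -> only the last ~25% of the steps run, so the
    output stays close to the input video (upsampler / restorer behaviour).
    add_noise(latents, t) is assumed to noise the latents to level t.
    """
    num_steps = len(timesteps)
    first_step = int(round((1.0 - denoising_strength) * num_steps))
    first_step = min(max(first_step, 0), num_steps - 1)

    t_start = timesteps[first_step]
    noisy = add_noise(latents_video, t_start)   # partially noised input video
    return noisy, timesteps[first_step:]        # denoise only the remaining steps
```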
+
+**If you had upgraded to v6.5 please upgrade again to 6.5.1 as this will fix a bug that ignored Loras beyond the first one**
+
+### June 23 2025: WanGP v6.3, Vace Unleashed. Thought we couldn't squeeze Vace even more?
+- Multithreaded preprocessing when possible for faster generations
+- Multithreaded frames Lanczos Upsampling as a bonus
+- A new Vace preprocessor: *Flow* to extract fluid motion
+- Multi Vace ControlNets: you can now transfer several properties at the same time. This opens new possibilities to explore; for instance, if you transfer *Human Movement* and *Shapes* at the same time, the lighting of your character will take the environment much more into account.
+- Injected Frames Outpainting, in case you missed it in WanGP 6.21
+
+Don't know how to use all of the Vace features? Check the Vace Guide embedded in WanGP as it has also been updated.
+
### June 19 2025: WanGP v6.2, Vace even more Powercharged
-Have I told you that I am a big fan of Vace? Here are more goodies to unleash its power:
+👋 Have I told you that I am a big fan of Vace? Here are more goodies to unleash its power:
- If you ever wanted to watch Star Wars in 4:3, just use the new *Outpainting* feature and it will add the missing bits of image at the top and the bottom of the screen. The best thing is *Outpainting* can be combined with all the other Vace modifications, so you can, for instance, change the main character of your favorite movie at the same time.
- More processing steps can be combined at the same time (for instance the depth process can be applied outside the mask)
- Upgraded the depth extractor to Depth Anything 2, which is much more detailed

As a bonus, I have added two finetunes based on the Self-Forcing technology (which requires only 4 steps to generate a video): Wan 2.1 text2video Self-Forcing and Vace Self-Forcing. I know there is a Lora around but the quality of the Lora is worse (at least with Vace) compared to the full model. Don't hesitate to share your opinion about this on the discord server.

### June 17 2025: WanGP v6.1, Vace Powercharged
-Lots of improvements for Vace the Mother of all Models:
+👋 Lots of improvements for Vace the Mother of all Models:
- masks can now be combined with on-the-fly processing of a control video; for instance you can extract the motion of a specific person defined by a mask
- on-the-fly modification of masks: reversed masks (with the same mask you can modify the background instead of the people covered by the masks), enlarged masks (you can cover more area if, for instance, the person you are trying to inject is larger than the one in the mask), ...
- view these modified masks directly inside WanGP during the video generation to check they are really as expected

@@ -37,22 +127,6 @@ You get **Vace FusioniX**: the Ultimate Vace Model, Fast (10 steps, no need for

Check the *Finetune Guide* to create finetune model definitions and share them on the WanGP discord server.

-### June 12 2025: WanGP v5.6
-👋 *Finetune models*: You find the 20 models supported by WanGP not sufficient? Too impatient to wait for the next release to get support for a newly released model? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add support for it yourself by just creating a finetune model definition. You can then store this model in the cloud (for instance on Hugging Face) and the very light finetune definition file can easily be shared with other users. WanGP will download the finetuned model for them automatically.
-
-To celebrate the new finetunes support, here are a few finetune gifts (directly accessible from the model selection menu):
-- *Fast Hunyuan Video*: generate t2v in only 6 steps
-- *Hunyuan Video AccVideo*: generate t2v in only 5 steps
-- *Wan FusioniX*: a combo of AccVideo / CausVid and other models that can generate high quality Wan videos in only 8 steps
-
-One more thing...
-
-The new finetune system can be used to combine complementary models: what happens when you combine FusioniX Text2Video and Vace ControlNet?
-
-You get **Vace FusioniX**: the Ultimate Vace Model, fast (10 steps, no need for guidance) and with much better video quality than the original slower model (despite it being the best ControlNet out there). Here goes one more finetune...
-
-Check the *Finetune Guide* to create finetune model definitions and share them on the WanGP discord server.
-
### June 11 2025: WanGP v5.5
👋 *Hunyuan Video Custom Audio*: it is similar to Hunyuan Video Avatar except there isn't any lower limit on the number of frames and you can use your reference images in a different context than the image itself\
*Hunyuan Video Custom Edit*: Hunyuan Video ControlNet, use it to do inpainting and replace a person in a video while still keeping their poses. Similar to Vace but less restricted than the Wan models in terms of content...
@@ -23,6 +23,7 @@ class model_factory:
        checkpoint_dir,
        model_filename = None,
        model_type = None,
+        model_def = None,
        base_model_type = None,
        text_encoder_filename = None,
        quantizeTransformer = False,

@@ -35,15 +36,20 @@ class model_factory:
        self.VAE_dtype = VAE_dtype
        self.dtype = dtype
        torch_device = "cpu"
+        # model_filename = ["c:/temp/flux1-schnell.safetensors"]

        self.t5 = load_t5(torch_device, text_encoder_filename, max_length=512)
        self.clip = load_clip(torch_device)
-        self.name= "flux-dev-kontext"
+        self.name = model_def.get("flux-model", "flux-dev")
+        # self.name= "flux-dev-kontext"
+        # self.name= "flux-dev"
+        # self.name= "flux-schnell"
        self.model = load_flow_model(self.name, model_filename[0], torch_device)

        self.vae = load_ae(self.name, device=torch_device)

        # offload.change_dtype(self.model, dtype, True)
+        # offload.save_model(self.model, "flux-dev.safetensors")
        if save_quantized:
            from wgp import save_quantized_model
            save_quantized_model(self.model, model_type, model_filename[0], dtype, None)

@@ -61,7 +67,7 @@ class model_factory:
        input_ref_images = None,
        width= 832,
        height=480,
-        guide_scale: float = 2.5,
+        embedded_guidance_scale: float = 2.5,
        fit_into_canvas = None,
        callback = None,
        loras_slists = None,

@@ -77,6 +83,8 @@ class model_factory:
            image_ref = input_ref_images[0]
            w, h = image_ref.size
            height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
+        else:
+            image_ref = None

        inp, height, width = prepare_kontext(
            t5=self.t5,

@@ -96,7 +104,7 @@ class model_factory:
        def unpack_latent(x):
            return unpack(x.float(), height, width)
        # denoise initial noise
-        x = denoise(self.model, **inp, timesteps=timesteps, guidance=guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent)
+        x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent)
        if x==None: return None
        # decode latents to pixel space
        x = unpack_latent(x)

@@ -107,3 +115,13 @@ class model_factory:
        x = x.transpose(0, 1)
        return x
+
+def query_model_def(model_type, model_def):
+    flux_model = model_def.get("flux-model", "flux-dev")
+    flux_schnell = flux_model == "flux-schnell"
+    model_def_output = {
+        "image_outputs" : True,
+    }
+    if flux_schnell:
+        model_def_output["no_guidance"] = True
+
+    return model_def_output
@@ -85,18 +85,47 @@ class Flux(nn.Module):
        new_sd = {}
        if len(sd) == 0: return sd
+
+        def swap_scale_shift(weight):
+            shift, scale = weight.chunk(2, dim=0)
+            new_weight = torch.cat([scale, shift], dim=0)
+            return new_weight
+
        first_key= next(iter(sd))
        if first_key.startswith("transformer."):
-            src_list = [".attn.to_q.", ".attn.to_k.", ".attn.to_v."]
-            tgt_list = [".linear1_attn_q.", ".linear1_attn_k.", ".linear1_attn_v."]
+            root_src = ["time_text_embed.timestep_embedder.linear_1", "time_text_embed.timestep_embedder.linear_2", "time_text_embed.text_embedder.linear_1", "time_text_embed.text_embedder.linear_2",
+                        "time_text_embed.guidance_embedder.linear_1", "time_text_embed.guidance_embedder.linear_2",
+                        "x_embedder", "context_embedder", "proj_out" ]
+
+            root_tgt = ["time_in.in_layer", "time_in.out_layer", "vector_in.in_layer", "vector_in.out_layer",
+                        "guidance_in.in_layer", "guidance_in.out_layer",
+                        "img_in", "txt_in", "final_layer.linear" ]
+
+            double_src = ["norm1.linear", "norm1_context.linear", "attn.norm_q", "attn.norm_k", "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2", "attn.to_out.0" ,"attn.to_add_out", "attn.to_out", ".attn.to_", ".attn.add_q_proj.", ".attn.add_k_proj.", ".attn.add_v_proj.", ]
+            double_tgt = ["img_mod.lin", "txt_mod.lin", "img_attn.norm.query_norm", "img_attn.norm.key_norm", "img_mlp.0", "img_mlp.2", "txt_mlp.0", "txt_mlp.2", "img_attn.proj", "txt_attn.proj", "img_attn.proj", ".img_attn.", ".txt_attn.q.", ".txt_attn.k.", ".txt_attn.v."]
+
+            single_src = ["norm.linear", "attn.norm_q", "attn.norm_k", "proj_out",".attn.to_q.", ".attn.to_k.", ".attn.to_v.", ".proj_mlp."]
+            single_tgt = ["modulation.lin","norm.query_norm", "norm.key_norm", "linear2", ".linear1_attn_q.", ".linear1_attn_k.", ".linear1_attn_v.", ".linear1_mlp."]
+
            for k,v in sd.items():
+                if k.startswith("transformer.single_transformer_blocks"):
                    k = k.replace("transformer.single_transformer_blocks", "diffusion_model.single_blocks")
-                k = k.replace("transformer.double_transformer_blocks", "diffusion_model.double_blocks")
-                for src, tgt in zip(src_list, tgt_list):
+                    for src, tgt in zip(single_src, single_tgt):
+                        k = k.replace(src, tgt)
+                elif k.startswith("transformer.transformer_blocks"):
+                    k = k.replace("transformer.transformer_blocks", "diffusion_model.double_blocks")
+                    for src, tgt in zip(double_src, double_tgt):
+                        k = k.replace(src, tgt)
+                else:
+                    k = k.replace("transformer.", "diffusion_model.")
+                    for src, tgt in zip(root_src, root_tgt):
                        k = k.replace(src, tgt)
+
+                if "norm_out.linear" in k:
+                    if "lora_B" in k:
+                        v = swap_scale_shift(v)
+                    k = k.replace("norm_out.linear", "final_layer.adaLN_modulation.1")
                new_sd[k] = v
+
        return new_sd

    def forward(
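To make the remapping above concrete, this is roughly what happens to one diffusers-format key for a single-stream block. The prefix rewrite and the `(src, tgt)` pair come straight from the lists in the hunk; the `lora_A.weight` suffix is just an illustrative example of a typical diffusers Lora key, not something the commit defines.

```python
# diffusers-format Lora key for single-stream block 7 (illustrative)
k = "transformer.single_transformer_blocks.7.attn.to_q.lora_A.weight"

# prefix rewrite applied to single-stream blocks
k = k.replace("transformer.single_transformer_blocks", "diffusion_model.single_blocks")

# one (src, tgt) pair taken from single_src / single_tgt above
k = k.replace(".attn.to_q.", ".linear1_attn_q.")

print(k)  # diffusion_model.single_blocks.7.linear1_attn_q.lora_A.weight
```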
@@ -117,9 +117,10 @@ class ModulationOut:
    gate: Tensor


-def split_mlp(mlp, x, divide = 4):
+def split_mlp(mlp, x, divide = 8):
    x_shape = x.shape
    x = x.view(-1, x.shape[-1])
+    chunk_size = int(x.shape[0]/divide)
    chunk_size = int(x_shape[1]/divide)
    x_chunks = torch.split(x, chunk_size)
    for i, x_chunk in enumerate(x_chunks):
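For background, `split_mlp` runs a block's MLP over the flattened token dimension in chunks, so only one chunk's hidden activations are alive at a time; raising `divide` from 4 to 8 halves that peak at the cost of a longer Python-level loop. A standalone sketch of the pattern follows (simplified, assuming the MLP preserves the feature dimension; this is not the exact WanGP helper):

```python
import torch
from torch import nn

def apply_mlp_in_chunks(mlp: nn.Module, x: torch.Tensor, divide: int = 8) -> torch.Tensor:
    """Apply `mlp` over the flattened token dimension in `divide` chunks.

    Peak memory for the MLP's expanded hidden activations scales with the
    chunk size instead of the full sequence length.
    """
    x_shape = x.shape
    x = x.view(-1, x_shape[-1])                 # (batch * tokens, dim)
    chunk_size = max(1, x.shape[0] // divide)
    out = [mlp(chunk) for chunk in torch.split(x, chunk_size)]
    return torch.cat(out, dim=0).view(x_shape)
```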
@@ -224,6 +224,7 @@ def prepare_kontext(
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

+    if img_cond != None:
        width, height = img_cond.size
        aspect_ratio = width / height

@@ -259,6 +260,11 @@ def prepare_kontext(
            target_width = 8 * width
        if target_height is None:
            target_height = 8 * height
+        img_cond_ids = img_cond_ids.to(device)
+    else:
+        img_cond = None
+        img_cond_ids = None
+        img_cond_orig = None

    img = get_noise(
        bs,

@@ -271,7 +277,7 @@ def prepare_kontext(

    return_dict = prepare(t5, clip, img, prompt)
    return_dict["img_cond_seq"] = img_cond
-    return_dict["img_cond_seq_ids"] = img_cond_ids.to(device)
+    return_dict["img_cond_seq_ids"] = img_cond_ids
    return_dict["img_cond_orig"] = img_cond_orig
    return return_dict, target_height, target_width
@ -1,302 +0,0 @@
|
|||||||
import os
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from glob import iglob
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from fire import Fire
|
|
||||||
from transformers import pipeline
|
|
||||||
|
|
||||||
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
|
|
||||||
from flux.util import (
|
|
||||||
check_onnx_access_for_trt,
|
|
||||||
configs,
|
|
||||||
load_ae,
|
|
||||||
load_clip,
|
|
||||||
load_flow_model,
|
|
||||||
load_t5,
|
|
||||||
save_image,
|
|
||||||
)
|
|
||||||
|
|
||||||
NSFW_THRESHOLD = 0.85
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SamplingOptions:
|
|
||||||
prompt: str
|
|
||||||
width: int
|
|
||||||
height: int
|
|
||||||
num_steps: int
|
|
||||||
guidance: float
|
|
||||||
seed: int | None
|
|
||||||
|
|
||||||
|
|
||||||
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
|
|
||||||
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
|
|
||||||
usage = (
|
|
||||||
"Usage: Either write your prompt directly, leave this field empty "
|
|
||||||
"to repeat the prompt or write a command starting with a slash:\n"
|
|
||||||
"- '/w <width>' will set the width of the generated image\n"
|
|
||||||
"- '/h <height>' will set the height of the generated image\n"
|
|
||||||
"- '/s <seed>' sets the next seed\n"
|
|
||||||
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
|
|
||||||
"- '/n <steps>' sets the number of steps\n"
|
|
||||||
"- '/q' to quit"
|
|
||||||
)
|
|
||||||
|
|
||||||
while (prompt := input(user_question)).startswith("/"):
|
|
||||||
if prompt.startswith("/w"):
|
|
||||||
if prompt.count(" ") != 1:
|
|
||||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
|
||||||
continue
|
|
||||||
_, width = prompt.split()
|
|
||||||
options.width = 16 * (int(width) // 16)
|
|
||||||
print(
|
|
||||||
f"Setting resolution to {options.width} x {options.height} "
|
|
||||||
f"({options.height * options.width / 1e6:.2f}MP)"
|
|
||||||
)
|
|
||||||
elif prompt.startswith("/h"):
|
|
||||||
if prompt.count(" ") != 1:
|
|
||||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
|
||||||
continue
|
|
||||||
_, height = prompt.split()
|
|
||||||
options.height = 16 * (int(height) // 16)
|
|
||||||
print(
|
|
||||||
f"Setting resolution to {options.width} x {options.height} "
|
|
||||||
f"({options.height * options.width / 1e6:.2f}MP)"
|
|
||||||
)
|
|
||||||
elif prompt.startswith("/g"):
|
|
||||||
if prompt.count(" ") != 1:
|
|
||||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
|
||||||
continue
|
|
||||||
_, guidance = prompt.split()
|
|
||||||
options.guidance = float(guidance)
|
|
||||||
print(f"Setting guidance to {options.guidance}")
|
|
||||||
elif prompt.startswith("/s"):
|
|
||||||
if prompt.count(" ") != 1:
|
|
||||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
|
||||||
continue
|
|
||||||
_, seed = prompt.split()
|
|
||||||
options.seed = int(seed)
|
|
||||||
print(f"Setting seed to {options.seed}")
|
|
||||||
elif prompt.startswith("/n"):
|
|
||||||
if prompt.count(" ") != 1:
|
|
||||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
|
||||||
continue
|
|
||||||
_, steps = prompt.split()
|
|
||||||
options.num_steps = int(steps)
|
|
||||||
print(f"Setting number of steps to {options.num_steps}")
|
|
||||||
elif prompt.startswith("/q"):
|
|
||||||
print("Quitting")
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
if not prompt.startswith("/h"):
|
|
||||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
|
||||||
print(usage)
|
|
||||||
if prompt != "":
|
|
||||||
options.prompt = prompt
|
|
||||||
return options
|
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def main(
|
|
||||||
name: str = "flux-schnell",
|
|
||||||
width: int = 1360,
|
|
||||||
height: int = 768,
|
|
||||||
seed: int | None = None,
|
|
||||||
prompt: str = (
|
|
||||||
"a photo of a forest with mist swirling around the tree trunks. The word "
|
|
||||||
'"FLUX" is painted over it in big, red brush strokes with visible texture'
|
|
||||||
),
|
|
||||||
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
|
||||||
num_steps: int | None = None,
|
|
||||||
loop: bool = False,
|
|
||||||
guidance: float = 2.5,
|
|
||||||
offload: bool = False,
|
|
||||||
output_dir: str = "output",
|
|
||||||
add_sampling_metadata: bool = True,
|
|
||||||
trt: bool = False,
|
|
||||||
trt_transformer_precision: str = "bf16",
|
|
||||||
track_usage: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Sample the flux model. Either interactively (set `--loop`) or run for a
|
|
||||||
single image.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Name of the model to load
|
|
||||||
height: height of the sample in pixels (should be a multiple of 16)
|
|
||||||
width: width of the sample in pixels (should be a multiple of 16)
|
|
||||||
seed: Set a seed for sampling
|
|
||||||
output_name: where to save the output image, `{idx}` will be replaced
|
|
||||||
by the index of the sample
|
|
||||||
prompt: Prompt used for sampling
|
|
||||||
device: Pytorch device
|
|
||||||
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
|
|
||||||
loop: start an interactive session and sample multiple times
|
|
||||||
guidance: guidance value used for guidance distillation
|
|
||||||
add_sampling_metadata: Add the prompt to the image Exif metadata
|
|
||||||
trt: use TensorRT backend for optimized inference
|
|
||||||
trt_transformer_precision: specify transformer precision for inference
|
|
||||||
track_usage: track usage of the model for licensing purposes
|
|
||||||
"""
|
|
||||||
    prompt = prompt.split("|")
    if len(prompt) == 1:
        prompt = prompt[0]
        additional_prompts = None
    else:
        additional_prompts = prompt[1:]
        prompt = prompt[0]

    assert not (
        (additional_prompts is not None) and loop
    ), "Do not provide additional prompts and set loop to True"

    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

    if name not in configs:
        available = ", ".join(configs.keys())
        raise ValueError(f"Got unknown model name: {name}, chose from {available}")

    torch_device = torch.device(device)
    if num_steps is None:
        num_steps = 4 if name == "flux-schnell" else 50

    # allow for packing and conversion to latent space
    height = 16 * (height // 16)
    width = 16 * (width // 16)

    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0

    if not trt:
        t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
        clip = load_clip(torch_device)
        model = load_flow_model(name, device="cpu" if offload else torch_device)
        ae = load_ae(name, device="cpu" if offload else torch_device)
    else:
        # lazy import to make install optional
        from flux.trt.trt_manager import ModuleName, TRTManager

        # Check if we need ONNX model access (which requires authentication for FLUX models)
        onnx_dir = check_onnx_access_for_trt(name, trt_transformer_precision)

        trt_ctx_manager = TRTManager(
            trt_transformer_precision=trt_transformer_precision,
            trt_t5_precision=os.getenv("TRT_T5_PRECISION", "bf16"),
        )
        engines = trt_ctx_manager.load_engines(
            model_name=name,
            module_names={
                ModuleName.CLIP,
                ModuleName.TRANSFORMER,
                ModuleName.T5,
                ModuleName.VAE,
            },
            engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
            custom_onnx_paths=onnx_dir or os.environ.get("CUSTOM_ONNX_PATHS", ""),
            trt_image_height=height,
            trt_image_width=width,
            trt_batch_size=1,
            trt_timing_cache=os.getenv("TRT_TIMING_CACHE_FILE", None),
            trt_static_batch=False,
            trt_static_shape=False,
        )

        ae = engines[ModuleName.VAE].to(device="cpu" if offload else torch_device)
        model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device)
        clip = engines[ModuleName.CLIP].to(torch_device)
        t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device)

    rng = torch.Generator(device="cpu")
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    if loop:
        opts = parse_prompt(opts)

    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        t0 = time.perf_counter()

        # prepare input
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=torch_device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        opts.seed = None
        if offload:
            ae = ae.cpu()
            torch.cuda.empty_cache()
            t5, clip = t5.to(torch_device), clip.to(torch_device)
        inp = prepare(t5, clip, x, prompt=opts.prompt)
        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # offload TEs to CPU, load model to gpu
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # denoise initial noise
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)

        # offload model, load autoencoder to gpu
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # decode latents to pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()

        fn = output_name.format(idx=idx)
        print(f"Done in {t1 - t0:.1f}s. Saving {fn}")

        idx = save_image(
            nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage
        )

        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
        elif additional_prompts:
            next_prompt = additional_prompts.pop(0)
            opts.prompt = next_prompt
        else:
            opts = None

    if trt:
        trt_ctx_manager.stop_runtime()


if __name__ == "__main__":
    Fire(main)
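# Example invocation (an editorial sketch, not part of the original file: the
# script filename is assumed, and every flag simply mirrors a keyword argument
# of `main` via Fire):
#
#   python cli.py --name flux-schnell --width 1024 --height 1024 \
#       --prompt "a robot made out of gold" --offload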
@@ -1,390 +0,0 @@
import os
import re
import time
from dataclasses import dataclass
from glob import iglob

import torch
from fire import Fire
from transformers import pipeline

from flux.modules.image_embedders import CannyImageEncoder, DepthImageEncoder
from flux.sampling import denoise, get_noise, get_schedule, prepare_control, unpack
from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image


@dataclass
class SamplingOptions:
    prompt: str
    width: int
    height: int
    num_steps: int
    guidance: float
    seed: int | None
    img_cond_path: str
    lora_scale: float | None


def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the prompt or write a command starting with a slash:\n"
        "- '/w <width>' will set the width of the generated image\n"
        "- '/h <height>' will set the height of the generated image\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )

    while (prompt := input(user_question)).startswith("/"):
        if prompt.startswith("/w"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, width = prompt.split()
            options.width = 16 * (int(width) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/h"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, height = prompt.split()
            options.height = 16 * (int(height) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/g"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, guidance = prompt.split()
            options.guidance = float(guidance)
            print(f"Setting guidance to {options.guidance}")
        elif prompt.startswith("/s"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, seed = prompt.split()
            options.seed = int(seed)
            print(f"Setting seed to {options.seed}")
        elif prompt.startswith("/n"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, steps = prompt.split()
            options.num_steps = int(steps)
            print(f"Setting number of steps to {options.num_steps}")
        elif prompt.startswith("/q"):
            print("Quitting")
            return None
        else:
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)
    if prompt != "":
        options.prompt = prompt
    return options


def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
    if options is None:
        return None

    user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the conditioning image or write a command starting with a slash:\n"
        "- '/q' to quit"
    )

    while True:
        img_cond_path = input(user_question)

        if img_cond_path.startswith("/"):
            if img_cond_path.startswith("/q"):
                print("Quitting")
                return None
            else:
                if not img_cond_path.startswith("/h"):
                    print(f"Got invalid command '{img_cond_path}'\n{usage}")
                print(usage)
            continue

        if img_cond_path == "":
            break

        if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
            (".jpg", ".jpeg", ".png", ".webp")
        ):
            print(f"File '{img_cond_path}' does not exist or is not a valid image file")
            continue

        options.img_cond_path = img_cond_path
        break

    return options


def parse_lora_scale(options: SamplingOptions | None) -> tuple[SamplingOptions | None, bool]:
    changed = False

    if options is None:
        return None, changed

    user_question = "Next lora scale (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the lora scale or write a command starting with a slash:\n"
        "- '/q' to quit"
    )

    while (prompt := input(user_question)).startswith("/"):
        if prompt.startswith("/q"):
            print("Quitting")
            return None, changed
        else:
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)
    if prompt != "":
        options.lora_scale = float(prompt)
        changed = True
    return options, changed


@torch.inference_mode()
def main(
    name: str,
    width: int = 1024,
    height: int = 1024,
    seed: int | None = None,
    prompt: str = "a robot made out of gold",
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    num_steps: int = 50,
    loop: bool = False,
    guidance: float | None = None,
    offload: bool = False,
    output_dir: str = "output",
    add_sampling_metadata: bool = True,
    img_cond_path: str = "assets/robot.webp",
    lora_scale: float | None = 0.85,
    trt: bool = False,
    trt_transformer_precision: str = "bf16",
    track_usage: bool = False,
    **kwargs: dict | None,
):
"""
|
|
||||||
Sample the flux model. Either interactively (set `--loop`) or run for a
|
|
||||||
single image.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
height: height of the sample in pixels (should be a multiple of 16)
|
|
||||||
width: width of the sample in pixels (should be a multiple of 16)
|
|
||||||
seed: Set a seed for sampling
|
|
||||||
output_name: where to save the output image, `{idx}` will be replaced
|
|
||||||
by the index of the sample
|
|
||||||
prompt: Prompt used for sampling
|
|
||||||
device: Pytorch device
|
|
||||||
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
|
|
||||||
loop: start an interactive session and sample multiple times
|
|
||||||
guidance: guidance value used for guidance distillation
|
|
||||||
add_sampling_metadata: Add the prompt to the image Exif metadata
|
|
||||||
img_cond_path: path to conditioning image (jpeg/png/webp)
|
|
||||||
trt: use TensorRT backend for optimized inference
|
|
||||||
trt_transformer_precision: specify transformer precision for inference
|
|
||||||
track_usage: track usage of the model for licensing purposes
|
|
||||||
"""
|
|
||||||
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

    if "lora" in name:
        assert not trt, "TRT does not support LORA"
    assert name in [
        "flux-dev-canny",
        "flux-dev-depth",
        "flux-dev-canny-lora",
        "flux-dev-depth-lora",
    ], f"Got unknown model name: {name}"

    if guidance is None:
        if name in ["flux-dev-canny", "flux-dev-canny-lora"]:
            guidance = 30.0
        elif name in ["flux-dev-depth", "flux-dev-depth-lora"]:
            guidance = 10.0
        else:
            raise NotImplementedError()

    if name not in configs:
        available = ", ".join(configs.keys())
        raise ValueError(f"Got unknown model name: {name}, chose from {available}")

    torch_device = torch.device(device)

    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0

    if name in ["flux-dev-depth", "flux-dev-depth-lora"]:
        img_embedder = DepthImageEncoder(torch_device)
    elif name in ["flux-dev-canny", "flux-dev-canny-lora"]:
        img_embedder = CannyImageEncoder(torch_device)
    else:
        raise NotImplementedError()

    if not trt:
        # init all components
        t5 = load_t5(torch_device, max_length=512)
        clip = load_clip(torch_device)
        model = load_flow_model(name, device="cpu" if offload else torch_device)
        ae = load_ae(name, device="cpu" if offload else torch_device)
    else:
        # lazy import to make install optional
        from flux.trt.trt_manager import ModuleName, TRTManager

        trt_ctx_manager = TRTManager(
            trt_transformer_precision=trt_transformer_precision,
            trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"),
        )

        engines = trt_ctx_manager.load_engines(
            model_name=name,
            module_names={
                ModuleName.CLIP,
                ModuleName.TRANSFORMER,
                ModuleName.T5,
                ModuleName.VAE,
                ModuleName.VAE_ENCODER,
            },
            engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
            custom_onnx_paths=os.environ.get("CUSTOM_ONNX_PATHS", ""),
            trt_image_height=height,
            trt_image_width=width,
            trt_batch_size=1,
            trt_static_batch=kwargs.get("static_batch", True),
            trt_static_shape=kwargs.get("static_shape", True),
        )

        ae = engines[ModuleName.VAE].to(device="cpu" if offload else torch_device)
        model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device)
        clip = engines[ModuleName.CLIP].to(torch_device)
        t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device)

    # set lora scale
    if "lora" in name and lora_scale is not None:
        for _, module in model.named_modules():
            if hasattr(module, "set_scale"):
                module.set_scale(lora_scale)

    rng = torch.Generator(device="cpu")
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
        img_cond_path=img_cond_path,
        lora_scale=lora_scale,
    )

    if loop:
        opts = parse_prompt(opts)
        opts = parse_img_cond_path(opts)
        if "lora" in name:
            opts, changed = parse_lora_scale(opts)
            if changed:
                # update the lora scale:
                for _, module in model.named_modules():
                    if hasattr(module, "set_scale"):
                        module.set_scale(opts.lora_scale)

    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        t0 = time.perf_counter()

        # prepare input
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=torch_device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        opts.seed = None
        if offload:
            t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
        inp = prepare_control(
            t5,
            clip,
            x,
            prompt=opts.prompt,
            ae=ae,
            encoder=img_embedder,
            img_cond_path=opts.img_cond_path,
        )
        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # offload TEs and AE to CPU, load model to gpu
        if offload:
            t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # denoise initial noise
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)

        # offload model, load autoencoder to gpu
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # decode latents to pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()
        print(f"Done in {t1 - t0:.1f}s")

        idx = save_image(
            nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage
        )

        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
            opts = parse_img_cond_path(opts)
            if "lora" in name:
                opts, changed = parse_lora_scale(opts)
                if changed:
                    # update the lora scale:
                    for _, module in model.named_modules():
                        if hasattr(module, "set_scale"):
                            module.set_scale(opts.lora_scale)
        else:
            opts = None

    if trt:
        trt_ctx_manager.stop_runtime()


if __name__ == "__main__":
    Fire(main)
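# Example invocation (an editorial sketch, not part of the original file: the
# script filename is assumed; flags mirror the keyword arguments of `main`):
#
#   python cli_control.py --name flux-dev-depth-lora \
#       --img_cond_path assets/robot.webp --prompt "a robot made out of gold"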
@@ -1,334 +0,0 @@
import os
import re
import time
from dataclasses import dataclass
from glob import iglob

import torch
from fire import Fire
from PIL import Image
from transformers import pipeline

from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack
from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image


@dataclass
class SamplingOptions:
    prompt: str
    width: int
    height: int
    num_steps: int
    guidance: float
    seed: int | None
    img_cond_path: str
    img_mask_path: str


def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the prompt or write a command starting with a slash:\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )

    while (prompt := input(user_question)).startswith("/"):
        if prompt.startswith("/g"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, guidance = prompt.split()
            options.guidance = float(guidance)
            print(f"Setting guidance to {options.guidance}")
        elif prompt.startswith("/s"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, seed = prompt.split()
            options.seed = int(seed)
            print(f"Setting seed to {options.seed}")
        elif prompt.startswith("/n"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, steps = prompt.split()
            options.num_steps = int(steps)
            print(f"Setting number of steps to {options.num_steps}")
        elif prompt.startswith("/q"):
            print("Quitting")
            return None
        else:
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)
    if prompt != "":
        options.prompt = prompt
    return options


def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
    if options is None:
        return None

    user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the conditioning image or write a command starting with a slash:\n"
        "- '/q' to quit"
    )

    while True:
        img_cond_path = input(user_question)

        if img_cond_path.startswith("/"):
            if img_cond_path.startswith("/q"):
                print("Quitting")
                return None
            else:
                if not img_cond_path.startswith("/h"):
                    print(f"Got invalid command '{img_cond_path}'\n{usage}")
                print(usage)
            continue

        if img_cond_path == "":
            break

        if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
            (".jpg", ".jpeg", ".png", ".webp")
        ):
            print(f"File '{img_cond_path}' does not exist or is not a valid image file")
            continue
        else:
            with Image.open(img_cond_path) as img:
                width, height = img.size

            if width % 32 != 0 or height % 32 != 0:
                print(f"Image dimensions must be divisible by 32, got {width}x{height}")
                continue

        options.img_cond_path = img_cond_path
        break

    return options


def parse_img_mask_path(options: SamplingOptions | None) -> SamplingOptions | None:
    if options is None:
        return None

    user_question = "Next conditioning mask (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the conditioning mask or write a command starting with a slash:\n"
        "- '/q' to quit"
    )

    while True:
        img_mask_path = input(user_question)

        if img_mask_path.startswith("/"):
            if img_mask_path.startswith("/q"):
                print("Quitting")
                return None
            else:
                if not img_mask_path.startswith("/h"):
                    print(f"Got invalid command '{img_mask_path}'\n{usage}")
                print(usage)
            continue

        if img_mask_path == "":
            break

        if not os.path.isfile(img_mask_path) or not img_mask_path.lower().endswith(
            (".jpg", ".jpeg", ".png", ".webp")
        ):
            print(f"File '{img_mask_path}' does not exist or is not a valid image file")
            continue
        else:
            with Image.open(img_mask_path) as img:
                width, height = img.size

            if width % 32 != 0 or height % 32 != 0:
                print(f"Image dimensions must be divisible by 32, got {width}x{height}")
                continue
            else:
                with Image.open(options.img_cond_path) as img_cond:
                    img_cond_width, img_cond_height = img_cond.size

                if width != img_cond_width or height != img_cond_height:
                    print(
                        f"Mask dimensions must match conditioning image, got {width}x{height} and {img_cond_width}x{img_cond_height}"
                    )
                    continue

        options.img_mask_path = img_mask_path
        break

    return options


@torch.inference_mode()
def main(
    seed: int | None = None,
    prompt: str = "a white paper cup",
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    num_steps: int = 50,
    loop: bool = False,
    guidance: float = 30.0,
    offload: bool = False,
    output_dir: str = "output",
    add_sampling_metadata: bool = True,
    img_cond_path: str = "assets/cup.png",
    img_mask_path: str = "assets/cup_mask.png",
    track_usage: bool = False,
):
"""
|
|
||||||
Sample the flux model. Either interactively (set `--loop`) or run for a
|
|
||||||
single image. This demo assumes that the conditioning image and mask have
|
|
||||||
the same shape and that height and width are divisible by 32.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seed: Set a seed for sampling
|
|
||||||
output_name: where to save the output image, `{idx}` will be replaced
|
|
||||||
by the index of the sample
|
|
||||||
prompt: Prompt used for sampling
|
|
||||||
device: Pytorch device
|
|
||||||
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
|
|
||||||
loop: start an interactive session and sample multiple times
|
|
||||||
guidance: guidance value used for guidance distillation
|
|
||||||
add_sampling_metadata: Add the prompt to the image Exif metadata
|
|
||||||
img_cond_path: path to conditioning image (jpeg/png/webp)
|
|
||||||
img_mask_path: path to conditioning mask (jpeg/png/webp)
|
|
||||||
track_usage: track usage of the model for licensing purposes
|
|
||||||
"""
|
|
||||||
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

    name = "flux-dev-fill"
    if name not in configs:
        available = ", ".join(configs.keys())
        raise ValueError(f"Got unknown model name: {name}, chose from {available}")

    torch_device = torch.device(device)

    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0

    # init all components
    t5 = load_t5(torch_device, max_length=128)
    clip = load_clip(torch_device)
    model = load_flow_model(name, device="cpu" if offload else torch_device)
    ae = load_ae(name, device="cpu" if offload else torch_device)

    rng = torch.Generator(device="cpu")
    with Image.open(img_cond_path) as img:
        width, height = img.size
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
        img_cond_path=img_cond_path,
        img_mask_path=img_mask_path,
    )

    if loop:
        opts = parse_prompt(opts)
        opts = parse_img_cond_path(opts)

        with Image.open(opts.img_cond_path) as img:
            width, height = img.size
        opts.height = height
        opts.width = width

        opts = parse_img_mask_path(opts)

    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        t0 = time.perf_counter()

        # prepare input
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=torch_device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        opts.seed = None
        if offload:
            t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
        inp = prepare_fill(
            t5,
            clip,
            x,
            prompt=opts.prompt,
            ae=ae,
            img_cond_path=opts.img_cond_path,
            mask_path=opts.img_mask_path,
        )

        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # offload TEs and AE to CPU, load model to gpu
        if offload:
            t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # denoise initial noise
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)

        # offload model, load autoencoder to gpu
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # decode latents to pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()
        print(f"Done in {t1 - t0:.1f}s")

        idx = save_image(
            nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage
        )

        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
            opts = parse_img_cond_path(opts)

            with Image.open(opts.img_cond_path) as img:
                width, height = img.size
            opts.height = height
            opts.width = width

            opts = parse_img_mask_path(opts)
        else:
            opts = None


if __name__ == "__main__":
    Fire(main)
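# Example invocation (an editorial sketch, not part of the original file: the
# script filename is assumed; flags mirror the keyword arguments of `main`):
#
#   python cli_fill.py --img_cond_path assets/cup.png \
#       --img_mask_path assets/cup_mask.png --prompt "a white paper cup"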
@@ -1,368 +0,0 @@
import os
import re
import time
from dataclasses import dataclass
from glob import iglob

import torch
from fire import Fire

from flux.content_filters import PixtralContentFilter
from flux.sampling import denoise, get_schedule, prepare_kontext, unpack
from flux.util import (
    aspect_ratio_to_height_width,
    check_onnx_access_for_trt,
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    save_image,
)


@dataclass
class SamplingOptions:
    prompt: str
    width: int | None
    height: int | None
    num_steps: int
    guidance: float
    seed: int | None
    img_cond_path: str


def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the prompt or write a command starting with a slash:\n"
        "- '/ar <width>:<height>' will set the aspect ratio of the generated image\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )

    while (prompt := input(user_question)).startswith("/"):
        if prompt.startswith("/ar"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, ratio_prompt = prompt.split()
            if ratio_prompt == "auto":
                options.width = None
                options.height = None
                print("Setting resolution to input image resolution.")
            else:
                options.width, options.height = aspect_ratio_to_height_width(ratio_prompt)
                print(f"Setting resolution to {options.width} x {options.height}.")
        elif prompt.startswith("/h"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, height = prompt.split()
            if height == "auto":
                options.height = None
            else:
                options.height = 16 * (int(height) // 16)
            if options.height is not None and options.width is not None:
                print(
                    f"Setting resolution to {options.width} x {options.height} "
                    f"({options.height * options.width / 1e6:.2f}MP)"
                )
            else:
                print(f"Setting resolution to {options.width} x {options.height}.")
        elif prompt.startswith("/g"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, guidance = prompt.split()
            options.guidance = float(guidance)
            print(f"Setting guidance to {options.guidance}")
        elif prompt.startswith("/s"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, seed = prompt.split()
            options.seed = int(seed)
            print(f"Setting seed to {options.seed}")
        elif prompt.startswith("/n"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, steps = prompt.split()
            options.num_steps = int(steps)
            print(f"Setting number of steps to {options.num_steps}")
        elif prompt.startswith("/q"):
            print("Quitting")
            return None
        else:
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)
    if prompt != "":
        options.prompt = prompt
    return options


def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
    if options is None:
        return None

    user_question = "Next input image (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write a path to an image directly, leave this field empty "
        "to repeat the last input image or write a command starting with a slash:\n"
        "- '/q' to quit\n\n"
        "The input image will be edited by FLUX.1 Kontext creating a new image based "
        "on your instruction prompt."
    )

    while True:
        img_cond_path = input(user_question)

        if img_cond_path.startswith("/"):
            if img_cond_path.startswith("/q"):
                print("Quitting")
                return None
            else:
                if not img_cond_path.startswith("/h"):
                    print(f"Got invalid command '{img_cond_path}'\n{usage}")
                print(usage)
            continue

        if img_cond_path == "":
            break

        if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
            (".jpg", ".jpeg", ".png", ".webp")
        ):
            print(f"File '{img_cond_path}' does not exist or is not a valid image file")
            continue

        options.img_cond_path = img_cond_path
        break

    return options


@torch.inference_mode()
def main(
    name: str = "flux-dev-kontext",
    aspect_ratio: str | None = None,
    seed: int | None = None,
    prompt: str = "replace the logo with the text 'Black Forest Labs'",
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    num_steps: int = 30,
    loop: bool = False,
    guidance: float = 2.5,
    offload: bool = False,
    output_dir: str = "output",
    add_sampling_metadata: bool = True,
    img_cond_path: str = "assets/cup.png",
    trt: bool = False,
    trt_transformer_precision: str = "bf16",
    track_usage: bool = False,
):
"""
|
|
||||||
Sample the flux model. Either interactively (set `--loop`) or run for a
|
|
||||||
single image.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
height: height of the sample in pixels (should be a multiple of 16), None
|
|
||||||
defaults to the size of the conditioning
|
|
||||||
width: width of the sample in pixels (should be a multiple of 16), None
|
|
||||||
defaults to the size of the conditioning
|
|
||||||
seed: Set a seed for sampling
|
|
||||||
output_name: where to save the output image, `{idx}` will be replaced
|
|
||||||
by the index of the sample
|
|
||||||
prompt: Prompt used for sampling
|
|
||||||
device: Pytorch device
|
|
||||||
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
|
|
||||||
loop: start an interactive session and sample multiple times
|
|
||||||
guidance: guidance value used for guidance distillation
|
|
||||||
add_sampling_metadata: Add the prompt to the image Exif metadata
|
|
||||||
img_cond_path: path to conditioning image (jpeg/png/webp)
|
|
||||||
trt: use TensorRT backend for optimized inference
|
|
||||||
track_usage: track usage of the model for licensing purposes
|
|
||||||
"""
|
|
||||||
    assert name == "flux-dev-kontext", f"Got unknown model name: {name}"

    torch_device = torch.device(device)

    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0

    if aspect_ratio is None:
        width = None
        height = None
    else:
        width, height = aspect_ratio_to_height_width(aspect_ratio)

    if not trt:
        t5 = load_t5(torch_device, max_length=512)
        clip = load_clip(torch_device)
        model = load_flow_model(name, device="cpu" if offload else torch_device)
    else:
        # lazy import to make install optional
        from flux.trt.trt_manager import ModuleName, TRTManager

        # Check if we need ONNX model access (which requires authentication for FLUX models)
        onnx_dir = check_onnx_access_for_trt(name, trt_transformer_precision)

        trt_ctx_manager = TRTManager(
            trt_transformer_precision=trt_transformer_precision,
            trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"),
        )
        engines = trt_ctx_manager.load_engines(
            model_name=name,
            module_names={
                ModuleName.CLIP,
                ModuleName.TRANSFORMER,
                ModuleName.T5,
            },
            engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
            custom_onnx_paths=onnx_dir or os.environ.get("CUSTOM_ONNX_PATHS", ""),
            trt_image_height=height,
            trt_image_width=width,
            trt_batch_size=1,
            trt_timing_cache=os.getenv("TRT_TIMING_CACHE_FILE", None),
            trt_static_batch=False,
            trt_static_shape=False,
        )

        model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device)
        clip = engines[ModuleName.CLIP].to(torch_device)
        t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device)

    ae = load_ae(name, device="cpu" if offload else torch_device)
    content_filter = PixtralContentFilter(torch.device("cpu"))

    rng = torch.Generator(device="cpu")
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
        img_cond_path=img_cond_path,
    )

    if loop:
        opts = parse_prompt(opts)
        opts = parse_img_cond_path(opts)

    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        t0 = time.perf_counter()

        if content_filter.test_txt(opts.prompt):
            print("Your prompt has been automatically flagged. Please choose another prompt.")
            if loop:
                print("-" * 80)
                opts = parse_prompt(opts)
                opts = parse_img_cond_path(opts)
            else:
                opts = None
            continue
        if content_filter.test_image(opts.img_cond_path):
            print("Your input image has been automatically flagged. Please choose another image.")
            if loop:
                print("-" * 80)
                opts = parse_prompt(opts)
                opts = parse_img_cond_path(opts)
            else:
                opts = None
            continue

        if offload:
            t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
        inp, height, width = prepare_kontext(
            t5=t5,
            clip=clip,
            prompt=opts.prompt,
            ae=ae,
            img_cond_path=opts.img_cond_path,
            target_width=opts.width,
            target_height=opts.height,
            bs=1,
            seed=opts.seed,
            device=torch_device,
        )
        # debug: dump the prepared inputs to disk for inspection
        from safetensors.torch import save_file

        save_file({k: v.cpu().contiguous() for k, v in inp.items()}, "output/noise.sft")
        inp.pop("img_cond_orig")
        opts.seed = None
        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # offload TEs and AE to CPU, load model to gpu
        if offload:
            t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # denoise initial noise
        t00 = time.time()
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
        torch.cuda.synchronize()
        t01 = time.time()
        print(f"Denoising took {t01 - t00:.3f}s")

        # offload model, load autoencoder to gpu
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # decode latents to pixel space
        x = unpack(x.float(), height, width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            ae_dev_t0 = time.perf_counter()
            x = ae.decode(x)
            torch.cuda.synchronize()
            ae_dev_t1 = time.perf_counter()
            print(f"AE decode took {ae_dev_t1 - ae_dev_t0:.3f}s")

        if content_filter.test_image(x.cpu()):
            print(
                "Your output image has been automatically flagged. Choose another prompt/image or try again."
            )
            if loop:
                print("-" * 80)
                opts = parse_prompt(opts)
                opts = parse_img_cond_path(opts)
            else:
                opts = None
            continue

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()
        print(f"Done in {t1 - t0:.1f}s")

        idx = save_image(
            None, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage
        )

        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
            opts = parse_img_cond_path(opts)
        else:
            opts = None


if __name__ == "__main__":
    Fire(main)
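# Example invocation (an editorial sketch, not part of the original file: the
# script filename is assumed; flags mirror the keyword arguments of `main`):
#
#   python cli_kontext.py --img_cond_path assets/cup.png \
#       --prompt "replace the logo with the text 'Black Forest Labs'" --loop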
@@ -1,290 +0,0 @@
import os
import re
import time
from dataclasses import dataclass
from glob import iglob

import torch
from fire import Fire
from transformers import pipeline

from flux.modules.image_embedders import ReduxImageEncoder
from flux.sampling import denoise, get_noise, get_schedule, prepare_redux, unpack
from flux.util import (
    get_checkpoint_path,
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    save_image,
)


@dataclass
class SamplingOptions:
    prompt: str
    width: int
    height: int
    num_steps: int
    guidance: float
    seed: int | None
    img_cond_path: str


def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
    user_question = "Write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Leave this field empty to do nothing "
        "or write a command starting with a slash:\n"
        "- '/w <width>' will set the width of the generated image\n"
        "- '/h <height>' will set the height of the generated image\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )

    while (prompt := input(user_question)).startswith("/"):
        if prompt.startswith("/w"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, width = prompt.split()
            options.width = 16 * (int(width) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/h"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, height = prompt.split()
            options.height = 16 * (int(height) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/g"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, guidance = prompt.split()
            options.guidance = float(guidance)
            print(f"Setting guidance to {options.guidance}")
        elif prompt.startswith("/s"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, seed = prompt.split()
            options.seed = int(seed)
            print(f"Setting seed to {options.seed}")
        elif prompt.startswith("/n"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, steps = prompt.split()
            options.num_steps = int(steps)
            print(f"Setting number of steps to {options.num_steps}")
        elif prompt.startswith("/q"):
            print("Quitting")
            return None
        else:
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)
    return options


def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
    if options is None:
        return None

    user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the conditioning image or write a command starting with a slash:\n"
        "- '/q' to quit"
    )

    while True:
        img_cond_path = input(user_question)

        if img_cond_path.startswith("/"):
            if img_cond_path.startswith("/q"):
                print("Quitting")
                return None
            else:
                if not img_cond_path.startswith("/h"):
                    print(f"Got invalid command '{img_cond_path}'\n{usage}")
                print(usage)
            continue

        if img_cond_path == "":
            break

        if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
            (".jpg", ".jpeg", ".png", ".webp")
        ):
            print(f"File '{img_cond_path}' does not exist or is not a valid image file")
            continue

        options.img_cond_path = img_cond_path
        break

    return options


@torch.inference_mode()
def main(
    name: str = "flux-dev",
    width: int = 1360,
    height: int = 768,
    seed: int | None = None,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    num_steps: int | None = None,
    loop: bool = False,
    guidance: float = 2.5,
    offload: bool = False,
    output_dir: str = "output",
    add_sampling_metadata: bool = True,
    img_cond_path: str = "assets/robot.webp",
    track_usage: bool = False,
):
    """
    Sample the flux model. Either interactively (set `--loop`) or run for a
    single image.

    Args:
        name: Name of the base model to use (either 'flux-dev' or 'flux-schnell')
        height: height of the sample in pixels (should be a multiple of 16)
        width: width of the sample in pixels (should be a multiple of 16)
        seed: Set a seed for sampling
        device: Pytorch device
        num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
        loop: start an interactive session and sample multiple times
        guidance: guidance value used for guidance distillation
        offload: offload models to CPU when not in use
        output_dir: where to save the output images
        add_sampling_metadata: Add the prompt to the image Exif metadata
        img_cond_path: path to conditioning image (jpeg/png/webp)
        track_usage: track usage of the model for licensing purposes
    """

    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

    if name not in (available := ["flux-dev", "flux-schnell"]):
        raise ValueError(f"Got unknown model name: {name}, chose from {available}")

    torch_device = torch.device(device)
    if num_steps is None:
        num_steps = 4 if name == "flux-schnell" else 50

    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0

    # init all components
    t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
    clip = load_clip(torch_device)
    model = load_flow_model(name, device="cpu" if offload else torch_device)
    ae = load_ae(name, device="cpu" if offload else torch_device)

    # Download and initialize the Redux adapter
    redux_path = str(
        get_checkpoint_path("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "FLUX_REDUX")
    )
    img_embedder = ReduxImageEncoder(torch_device, redux_path=redux_path)

    rng = torch.Generator(device="cpu")
    prompt = ""
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
        img_cond_path=img_cond_path,
    )

    if loop:
        opts = parse_prompt(opts)
        opts = parse_img_cond_path(opts)

    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        t0 = time.perf_counter()

        # prepare input
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=torch_device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        opts.seed = None
        if offload:
            ae = ae.cpu()
            torch.cuda.empty_cache()
            t5, clip = t5.to(torch_device), clip.to(torch_device)
        inp = prepare_redux(
            t5,
            clip,
            x,
            prompt=opts.prompt,
            encoder=img_embedder,
            img_cond_path=opts.img_cond_path,
        )
        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # offload TEs to CPU, load model to gpu
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # denoise initial noise
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)

        # offload model, load autoencoder to gpu
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # decode latents to pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()
        print(f"Done in {t1 - t0:.1f}s")

        idx = save_image(
            nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage
        )

        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
            opts = parse_img_cond_path(opts)
        else:
            opts = None


if __name__ == "__main__":
    Fire(main)
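# Example invocation (an editorial sketch, not part of the original file: the
# script filename is assumed; flags mirror the keyword arguments of `main`):
#
#   python cli_redux.py --name flux-dev --img_cond_path assets/robot.webp --loop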
@@ -1,171 +0,0 @@
import torch
from einops import rearrange
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, pipeline

PROMPT_IMAGE_INTEGRITY = """
Task: Analyze an image to identify potential copyright concerns or depictions of public figures.

Output: Respond with only "yes" or "no"

Criteria for "yes":
- The image contains a recognizable character from copyrighted media (movies, TV, comics, games, etc.)
- The image displays a trademarked logo or brand
- The image depicts a recognizable public figure (celebrities, politicians, athletes, influencers, historical figures, etc.)

Criteria for "no":
- All other cases
- When you cannot identify the specific copyrighted work or named individual

Critical Requirements:
1. You must be able to name the exact copyrighted work or specific person depicted
2. General references to demographics or characteristics are not sufficient
3. Base your decision solely on visual content, not interpretation
4. Provide only the one-word answer: "yes" or "no"
""".strip()


PROMPT_IMAGE_INTEGRITY_FOLLOW_UP = "Does this image have copyright concerns or includes public figures?"

PROMPT_TEXT_INTEGRITY = """
Task: Analyze a text prompt to identify potential copyright concerns or requests to depict living public figures.

Output: Respond with only "yes" or "no"

Criteria for "Yes":
- The prompt explicitly names a character from copyrighted media (movies, TV, comics, games, etc.)
- The prompt explicitly mentions a trademarked logo or brand
- The prompt names or describes a specific living public figure (celebrities, politicians, athletes, influencers, etc.)

Criteria for "No":
- All other cases
- When you cannot identify the specific copyrighted work or named individual

Critical Requirements:
1. You must be able to name the exact copyrighted work or specific person referenced
2. General demographic descriptions or characteristics are not sufficient
3. Analyze only the prompt text, not potential image outcomes
4. Provide only the one-word answer: "yes" or "no"

The prompt to check is:
-----
{prompt}
-----

Does this prompt have copyright concerns or includes public figures?
""".strip()


class PixtralContentFilter(torch.nn.Module):
    def __init__(
        self,
        device: torch.device = torch.device("cpu"),
        nsfw_threshold: float = 0.85,
    ):
        super().__init__()

        model_id = "mistral-community/pixtral-12b"
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = LlavaForConditionalGeneration.from_pretrained(model_id, device_map=device)

        self.yes_token, self.no_token = self.processor.tokenizer.encode(["yes", "no"])

        self.nsfw_classifier = pipeline(
            "image-classification", model="Falconsai/nsfw_image_detection", device=device
        )
        self.nsfw_threshold = nsfw_threshold

    def yes_no_logit_processor(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        Sets all tokens but yes/no to the minimum.
        """
        scores_yes_token = scores[:, self.yes_token].clone()
        scores_no_token = scores[:, self.no_token].clone()
        scores_min = scores.min()
        scores[:, :] = scores_min - 1
        scores[:, self.yes_token] = scores_yes_token
        scores[:, self.no_token] = scores_no_token
        return scores
|
|
||||||
def test_image(self, image: Image.Image | str | torch.Tensor) -> bool:
|
|
||||||
if isinstance(image, torch.Tensor):
|
|
||||||
image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c")
|
|
||||||
image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy())
|
|
||||||
elif isinstance(image, str):
|
|
||||||
image = Image.open(image)
|
|
||||||
|
|
||||||
classification = next(c for c in self.nsfw_classifier(image) if c["label"] == "nsfw")
|
|
||||||
if classification["score"] > self.nsfw_threshold:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 512^2 pixels are enough for checking
|
|
||||||
w, h = image.size
|
|
||||||
f = (512**2 / (w * h)) ** 0.5
|
|
||||||
image = image.resize((int(f * w), int(f * h)))
|
|
||||||
|
|
||||||
chat = [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"content": PROMPT_IMAGE_INTEGRITY,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image",
|
|
||||||
"image": image,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"content": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
inputs = self.processor.apply_chat_template(
|
|
||||||
chat,
|
|
||||||
add_generation_prompt=True,
|
|
||||||
tokenize=True,
|
|
||||||
return_dict=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
).to(self.model.device)
|
|
||||||
|
|
||||||
generate_ids = self.model.generate(
|
|
||||||
**inputs,
|
|
||||||
max_new_tokens=1,
|
|
||||||
logits_processor=[self.yes_no_logit_processor],
|
|
||||||
do_sample=False,
|
|
||||||
)
|
|
||||||
return generate_ids[0, -1].item() == self.yes_token
|
|
||||||
|
|
||||||
def test_txt(self, txt: str) -> bool:
|
|
||||||
chat = [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"content": PROMPT_TEXT_INTEGRITY.format(prompt=txt),
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
inputs = self.processor.apply_chat_template(
|
|
||||||
chat,
|
|
||||||
add_generation_prompt=True,
|
|
||||||
tokenize=True,
|
|
||||||
return_dict=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
).to(self.model.device)
|
|
||||||
|
|
||||||
generate_ids = self.model.generate(
|
|
||||||
**inputs,
|
|
||||||
max_new_tokens=1,
|
|
||||||
logits_processor=[self.yes_no_logit_processor],
|
|
||||||
do_sample=False,
|
|
||||||
)
|
|
||||||
return generate_ids[0, -1].item() == self.yes_token
|
|
||||||
@@ -1065,3 +1065,7 @@ class HunyuanVideoSampler(Inference):
        samples = samples.squeeze(0)

        return samples

+
+def query_model_def(model_type, model_def):
+    return None
@@ -28,10 +28,6 @@ def get_linear_split_map():
        "linear1" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"] , "split_sizes": [hidden_size, hidden_size, hidden_size, 7*hidden_size- 3*hidden_size]}
    }
    return split_linear_modules_map
-try:
-    from xformers.ops.fmha.attn_bias import BlockDiagonalPaddedKeysMask
-except ImportError:
-    BlockDiagonalPaddedKeysMask = None


class MMDoubleStreamBlock(nn.Module):
@@ -469,7 +465,7 @@ class MMSingleStreamBlock(nn.Module):
        del img_mod, txt_mod
        x_mod_shape = x_mod.shape
        x_mod = x_mod.view(-1, x_mod.shape[-1])
-        chunk_size = int(x_mod_shape[1]/6)
+        chunk_size = int(x_mod.shape[0]/6)
        x_chunks = torch.split(x_mod, chunk_size)
        attn = attn.view(-1, attn.shape[-1])
        attn_chunks =torch.split(attn, chunk_size)
ltx_video/configs/ltxv-13b-0.9.8-dev.yaml (new file, 34 lines)
@@ -0,0 +1,34 @@
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-dev.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  guidance_scale: [1, 1, 6, 8, 6, 1, 1]
  stg_scale: [0, 0, 4, 4, 4, 2, 1]
  rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
  guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
  skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
  num_inference_steps: 30
  skip_final_inference_steps: 3
  cfg_star_rescale: true

second_pass:
  guidance_scale: [1]
  stg_scale: [1]
  rescaling_scale: [1]
  guidance_timesteps: [1.0]
  skip_block_list: [27]
  num_inference_steps: 30
  skip_initial_inference_steps: 17
  cfg_star_rescale: true
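For orientation, here is a minimal sketch of how a multi-scale config like the one above can be inspected. It assumes PyYAML is installed and the file has been saved locally; the pipeline itself consumes these fields through model_def["LTXV_config"], not through this snippet.

# Minimal sketch: peek at the two-pass multi-scale settings of an LTX Video config.
import yaml

with open("ltxv-13b-0.9.8-dev.yaml", "r", encoding="utf-8") as f:   # hypothetical local path
    cfg = yaml.safe_load(f)

print(cfg["pipeline_type"])        # "multi-scale": low-res first pass, spatial upscale, then refinement pass
print(cfg["downscale_factor"])     # first pass renders at roughly 2/3 of the target resolution
for pass_name in ("first_pass", "second_pass"):
    # dev configs list num_inference_steps, distilled configs list explicit timesteps
    steps = cfg[pass_name].get("num_inference_steps", len(cfg[pass_name].get("timesteps", [])))
    print(pass_name, "steps:", steps)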
ltx_video/configs/ltxv-13b-0.9.8-distilled.yaml (new file, 29 lines)
@@ -0,0 +1,29 @@
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-distilled.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]

second_pass:
  timesteps: [0.9094, 0.7250, 0.4219]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]
  tone_map_compression_ratio: 0.6
@@ -149,6 +149,7 @@ class LTXV:
        self,
        model_filepath: str,
        text_encoder_filepath: str,
+        model_type, base_model_type,
        model_def,
        dtype = torch.bfloat16,
        VAE_dtype = torch.bfloat16,
@@ -159,24 +160,31 @@ class LTXV:
            dtype = torch.bfloat16
        self.mixed_precision_transformer = mixed_precision_transformer
        self.model_def = model_def
+        self.model_type = model_type
        self.pipeline_config = model_def["LTXV_config"]
+        # ckpt_path ="c:/temp/ltxv-13b-0.9.8-dev.safetensors"
        # with safe_open(ckpt_path, framework="pt") as f:
        #     metadata = f.metadata()
        #     config_str = metadata.get("config")
        #     configs = json.loads(config_str)
        #     allowed_inference_steps = configs.get("allowed_inference_steps", None)
+        # with open("c:/temp/ltxv_config.json", "w", encoding="utf-8") as writer:
+        #     writer.write(json.dumps(configs["transformer"]))
+        # with open("c:/temp/vae_config.json", "w", encoding="utf-8") as writer:
+        #     writer.write(json.dumps(configs["vae"]))
        # transformer = Transformer3DModel.from_pretrained(ckpt_path)
-        # transformer = offload.fast_load_transformers_model("c:/temp/ltxdistilled/diffusion_pytorch_model-00001-of-00006.safetensors", forcedConfigPath="c:/temp/ltxdistilled/config.json")
+        # offload.save_model(transformer, "ckpts/ltxv_0.9.8_13B_bf16.safetensors", config_file_path= "c:/temp/ltxv_config.json")
+        # offload.save_model(transformer, "ckpts/ltxv_0.9.8_13B_quanto_bf16_int8.safetensors", do_quantize= True, config_file_path= "c:/temp/ltxv_config.json")

        # vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
        vae = offload.fast_load_transformers_model("ckpts/ltxv_0.9.7_VAE.safetensors", modelClass=CausalVideoAutoencoder)
+        # vae = offload.fast_load_transformers_model("ckpts/ltxv_0.9.8_VAE.safetensors", modelClass=CausalVideoAutoencoder)
        # if VAE_dtype == torch.float16:
        VAE_dtype = torch.bfloat16

        vae = vae.to(VAE_dtype)
        vae._model_dtype = VAE_dtype
-        # vae = offload.fast_load_transformers_model("vae.safetensors", modelClass=CausalVideoAutoencoder, modelPrefix= "vae", forcedConfigPath="config_vae.json")
+        # offload.save_model(vae, "vae.safetensors", config_file_path="c:/temp/config_vae.json")
-        # offload.save_model(vae, "vae.safetensors", config_file_path="config_vae.json")

        # model_filepath = "c:/temp/ltxd/ltxv-13b-0.9.7-distilled.safetensors"
        transformer = offload.fast_load_transformers_model(model_filepath, modelClass=Transformer3DModel)
@@ -193,6 +201,7 @@ class LTXV:
        # offload.save_model(transformer, "ltx_13B_quanto_bf16_int8.safetensors", do_quantize= True, config_file_path="config_transformer.json")

        latent_upsampler = LatentUpsampler.from_pretrained("ckpts/ltxv_0.9.7_spatial_upscaler.safetensors").to("cpu").eval()
+        # latent_upsampler = LatentUpsampler.from_pretrained("ckpts/ltxv_0.9.8_spatial_upscaler.safetensors").to("cpu").eval()
        latent_upsampler.to(VAE_dtype)
        latent_upsampler._model_dtype = VAE_dtype

@@ -259,6 +268,7 @@ class LTXV:
        image_start = None,
        image_end = None,
        input_video = None,
+        input_frames = None,
        sampling_steps = 50,
        image_cond_noise_scale: float = 0.15,
        input_media_path: Optional[str] = None,
@@ -272,6 +282,7 @@ class LTXV:
        callback=None,
        device: Optional[str] = None,
        VAE_tile_size = None,
+        apg_switch = 0,
        **kwargs,
    ):

@@ -280,21 +291,33 @@ class LTXV:
        conditioning_strengths = None
        conditioning_media_paths = []
        conditioning_start_frames = []
+        conditioning_control_frames = []
+        prefix_size = 0
        if input_video != None:
            conditioning_media_paths.append(input_video)
            conditioning_start_frames.append(0)
-            height, width = input_video.shape[-2:]
+            conditioning_control_frames.append(False)
+            prefix_size, height, width = input_video.shape[-3:]
        else:
            if image_start != None:
                frame_width, frame_height = image_start.size
+                if fit_into_canvas != None:
                    height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas, 32)
                conditioning_media_paths.append(image_start)
                conditioning_start_frames.append(0)
+                conditioning_control_frames.append(False)
+                prefix_size = 1
            if image_end != None:
                conditioning_media_paths.append(image_end)
                conditioning_start_frames.append(frame_num-1)
+                conditioning_control_frames.append(False)

+        if input_frames!= None:
+            conditioning_media_paths.append(input_frames)
+            conditioning_start_frames.append(prefix_size)
+            conditioning_control_frames.append(True)
+            height, width = input_frames.shape[-2:]
+            fit_into_canvas = None

        if len(conditioning_media_paths) == 0:
            conditioning_media_paths = None
@@ -380,6 +403,7 @@ class LTXV:
            conditioning_media_paths=conditioning_media_paths,
            conditioning_strengths=conditioning_strengths,
            conditioning_start_frames=conditioning_start_frames,
+            conditioning_control_frames=conditioning_control_frames,
            height=height,
            width=width,
            num_frames=frame_num,
@@ -435,6 +459,7 @@ class LTXV:
            mixed_precision=pipeline_config.get("mixed", self.mixed_precision_transformer),
            callback=callback,
            VAE_tile_size = VAE_tile_size,
+            apg_switch = apg_switch,
            device=device,
            # enhance_prompt=enhance_prompt,
        )
@@ -453,11 +478,29 @@ class LTXV:
        images = images.sub_(0.5).mul_(2).squeeze(0)
        return images

+    def get_loras_transformer(self, get_model_recursive_prop, video_prompt_type, **kwargs):
+        map = {
+            "P" : "pose",
+            "D" : "depth",
+            "S" : "canny",
+        }
+        loras = []
+        preloadURLs = get_model_recursive_prop(self.model_type, "preload_URLs")
+        lora_file_name = ""
+        for letter, signature in map.items():
+            if letter in video_prompt_type:
+                for file_name in preloadURLs:
+                    if signature in file_name:
+                        loras += [ os.path.join("ckpts", os.path.basename(file_name))]
+                        break
+        loras_mult = [1.] * len(loras)
+        return loras, loras_mult

def prepare_conditioning(
    conditioning_media_paths: List[str],
    conditioning_strengths: List[float],
    conditioning_start_frames: List[int],
+    conditioning_control_frames: List[int],
    height: int,
    width: int,
    num_frames: int,
@@ -480,8 +523,8 @@ def prepare_conditioning(
        A list of ConditioningItem objects.
    """
    conditioning_items = []
-    for path, strength, start_frame in zip(
-        conditioning_media_paths, conditioning_strengths, conditioning_start_frames
+    for path, strength, start_frame, conditioning_control_frames in zip(
+        conditioning_media_paths, conditioning_strengths, conditioning_start_frames, conditioning_control_frames
    ):
        if isinstance(path, Image.Image):
            num_input_frames = orig_num_input_frames = 1
@@ -506,7 +549,7 @@ def prepare_conditioning(
            padding=padding,
            just_crop=True,
        )
-        conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength))
+        conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength, conditioning_control_frames))
    return conditioning_items


@@ -561,3 +604,16 @@ def load_media_file(
        raise Exception("video format not supported")
    return media_tensor

+def query_model_def(model_type, model_def):
+    LTXV_config = model_def.get("LTXV_config", "")
+    distilled= "distilled" in LTXV_config
+    model_def_output = {
+        "lock_inference_steps": True,
+        "no_guidance": True,
+    }
+    if distilled:
+        model_def_output.update({
+            "no_negative_prompt" : True,
+        })
+
+    return model_def_output
@@ -253,10 +253,12 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
            if key.startswith("vae.")
        }


+        stats_keys_to_keep = ["per_channel_statistics.std-of-means", "per_channel_statistics.mean-of-means"]
        ckpt_state_dict = {
            key: value
            for key, value in state_dict.items()
-            if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
+            if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX) or key in stats_keys_to_keep
        }

        model_keys = set(name for name, _ in self.named_modules())
@@ -280,21 +282,26 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):

                converted_state_dict[key] = value

+        # data_dict = {
+        #     key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value
+        #     for key, value in state_dict.items()
+        #     if key in stats_keys_to_keep
+        # }
+        for key in stats_keys_to_keep:
+            if key in converted_state_dict: # happens only in the original vae sd
+                v = converted_state_dict.pop(key)
+                converted_state_dict[key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX).replace("-", "_")] = v

        a,b = super().load_state_dict(converted_state_dict, strict=strict, assign=assign)

-        data_dict = {
-            key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value
-            for key, value in state_dict.items()
-            if key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
-        }
-        if len(data_dict) > 0:
-            self.register_buffer("std_of_means", data_dict["std-of-means"],)
-            self.register_buffer(
-                "mean_of_means",
-                data_dict.get(
-                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
-                ),
-            )
+        # if len(data_dict) > 0:
+        #     self.register_buffer("std_of_means", data_dict["std-of-means"],)
+        #     self.register_buffer(
+        #         "mean_of_means",
+        #         data_dict.get(
+        #             "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
+        #         ),
+        #     )
        return a, b

    def last_layer(self):
@@ -44,14 +44,19 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
        self.per_channel_statistics = nn.Module()
        std_of_means = torch.zeros( (128,), dtype= torch.bfloat16)

-        self.per_channel_statistics.register_buffer("std-of-means", std_of_means)
-        self.per_channel_statistics.register_buffer(
-            "mean-of-means",
+        # self.per_channel_statistics.register_buffer("std-of-means", std_of_means)
+        # self.per_channel_statistics.register_buffer(
+        #     "mean-of-means",
+        #     torch.zeros_like(std_of_means)
+        # )
+
+        self.register_buffer("std_of_means", std_of_means)
+        self.register_buffer(
+            "mean_of_means",
            torch.zeros_like(std_of_means)
        )



        # pass init params to Encoder
        self.encoder = encoder
        self.use_quant_conv = use_quant_conv
@@ -120,6 +120,48 @@ ASPECT_RATIO_512_BIN = {
    "4.0": [1024.0, 256.0],
}

+
+class MomentumBuffer:
+    def __init__(self, momentum: float):
+        self.momentum = momentum
+        self.running_average = 0
+
+    def update(self, update_value: torch.Tensor):
+        new_average = self.momentum * self.running_average
+        self.running_average = update_value + new_average
+
+
+def project(
+    v0: torch.Tensor,  # [B, C, T, H, W]
+    v1: torch.Tensor,  # [B, C, T, H, W]
+):
+    dtype = v0.dtype
+    v0, v1 = v0.double(), v1.double()
+    v1 = torch.nn.functional.normalize(v1, dim=[-2, -1])
+    v0_parallel = (v0 * v1).sum(dim=[-2, -1], keepdim=True) * v1
+    v0_orthogonal = v0 - v0_parallel
+    return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
+
+
+def adaptive_projected_guidance(
+    diff: torch.Tensor,  # [B, C, T, H, W]
+    pred_cond: torch.Tensor,  # [B, C, T, H, W]
+    momentum_buffer: MomentumBuffer = None,
+    eta: float = 0.0,
+    norm_threshold: float = 55,
+):
+    if momentum_buffer is not None:
+        momentum_buffer.update(diff)
+        diff = momentum_buffer.running_average
+    if norm_threshold > 0:
+        ones = torch.ones_like(diff)
+        diff_norm = diff.norm(p=2, dim=[-2, -1], keepdim=True)
+        print(f"diff_norm: {diff_norm}")
+        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+        diff = diff * scale_factor
+    diff_parallel, diff_orthogonal = project(diff, pred_cond)
+    normalized_update = diff_orthogonal + eta * diff_parallel
+    return normalized_update
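For reference, a minimal sketch of how the helpers above slot in where plain classifier-free guidance would normally compute "uncond + g * (cond - uncond)". The tensor names are illustrative; the pipeline changes further below wire this in behind the apg_switch flag with the same momentum and threshold values.

# Illustrative only: combining predictions with adaptive projected guidance.
# Shapes are assumed to be [B, C, T, H, W], matching the helpers above.
momentum_buffer = MomentumBuffer(-0.75)          # same momentum the pipeline uses below

def combine_predictions(noise_pred_uncond, noise_pred_text, guidance_scale):
    update = adaptive_projected_guidance(
        noise_pred_text - noise_pred_uncond,     # the usual CFG difference term
        noise_pred_text,                         # projection target
        momentum_buffer=momentum_buffer,
        norm_threshold=55,
    )
    return noise_pred_text + (guidance_scale - 1) * update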
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
@@ -215,6 +257,7 @@ class ConditioningItem:
    media_item: torch.Tensor
    media_frame_number: int
    conditioning_strength: float
+    control_frames: bool = False
    media_x: Optional[int] = None
    media_y: Optional[int] = None

@@ -796,6 +839,7 @@ class LTXVideoPipeline(DiffusionPipeline):
        text_encoder_max_tokens: int = 256,
        stochastic_sampling: bool = False,
        media_items: Optional[torch.Tensor] = None,
+        tone_map_compression_ratio: float = 0.0,
        strength: Optional[float] = 1.0,
        skip_initial_inference_steps: int = 0,
        skip_final_inference_steps: int = 0,
@@ -803,6 +847,7 @@ class LTXVideoPipeline(DiffusionPipeline):
        pass_no: int = -1,
        ltxv_model = None,
        callback=None,
+        apg_switch = 0,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """
@@ -876,6 +921,8 @@ class LTXVideoPipeline(DiffusionPipeline):
            media_items ('torch.Tensor', *optional*):
                The input media item used for image-to-image / video-to-video.
                When provided, they will be noised according to 'strength' and then fully denoised.
+            tone_map_compression_ratio: compression ratio for tone mapping, defaults to 0.0.
+                If set to 0.0, no tone mapping is applied. If set to 1.0 - full compression is applied.
            strength ('floaty', *optional* defaults to 1.0):
                The editing level in image-to-image / video-to-video. The provided input will be noised
                to this level.
@@ -1077,7 +1124,10 @@ class LTXVideoPipeline(DiffusionPipeline):
            )
        )
        init_latents = latents.clone()  # Used for image_cond_noise_update
+        if conditioning_items is not None and len(conditioning_items) > 0 and not conditioning_items[0].control_frames and conditioning_items[0].media_frame_number == 0:
+            prefix_latent_frames = (conditioning_items[0].media_item.shape[2] - 1)// 8 + 1
+        else:
+            prefix_latent_frames = 0
        # pixel_coords = torch.cat([pixel_coords] * num_conds)
        orig_conditioning_mask = conditioning_mask
        if conditioning_mask is not None and is_video:
@@ -1096,6 +1146,12 @@ class LTXVideoPipeline(DiffusionPipeline):
        )
        cfg_star_rescale = True

+        if apg_switch != 0:
+            apg_momentum = -0.75
+            apg_norm_threshold = 55
+            text_momentumbuffer = MomentumBuffer(apg_momentum)
+            audio_momentumbuffer = MomentumBuffer(apg_momentum)

        if callback != None:
            callback(-1, None, True, override_num_inference_steps = num_inference_steps, pass_no =pass_no)
@@ -1186,6 +1242,14 @@ class LTXVideoPipeline(DiffusionPipeline):
                )[-2:]
                if do_classifier_free_guidance and guidance_scale[i] !=0 and guidance_scale[i] !=1 :
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_conds)[:2]

+                    if apg_switch != 0:
+                        noise_pred = noise_pred_text + (guidance_scale[i] - 1) * adaptive_projected_guidance(noise_pred_text - noise_pred_uncond,
+                                                                                                             noise_pred_text,
+                                                                                                             momentum_buffer=text_momentumbuffer,
+                                                                                                             norm_threshold=apg_norm_threshold)
+                    else:
                        if cfg_star_rescale:
                            batch_size = noise_pred_text.shape[0]

@@ -1242,7 +1306,7 @@ class LTXVideoPipeline(DiffusionPipeline):

                if callback is not None:
                    # callback(i, None, False, pass_no =pass_no)
-                    preview_latents= latents.squeeze(0).transpose(0, 1)
+                    preview_latents= latents[:, num_cond_latents:].squeeze(0).transpose(0, 1)
                    preview_latents= preview_latents.reshape(preview_latents.shape[0], latent_num_frames, latent_height, latent_width)
                    callback(i, preview_latents, False, pass_no =pass_no)
                    preview_latents = None
@@ -1285,8 +1349,9 @@ class LTXVideoPipeline(DiffusionPipeline):
                )
            else:
                decode_timestep = None
-            torch.save(latents, "lala.pt")
+            # torch.save(latents, "lala.pt")
            # latents = torch.load("lala.pt")
+            latents = self.tone_map_latents(latents, tone_map_compression_ratio, start = prefix_latent_frames)
            image = vae_decode(
                latents,
                self.vae,
@@ -1306,6 +1371,57 @@ class LTXVideoPipeline(DiffusionPipeline):

        return image

+    @staticmethod
+    def tone_map_latents(
+        latents: torch.Tensor,
+        compression: float,
+        start: int = 0
+    ) -> torch.Tensor:
+        """
+        Applies a non-linear tone-mapping function to latent values to reduce their dynamic range
+        in a perceptually smooth way using a sigmoid-based compression.
+
+        This is useful for regularizing high-variance latents or for conditioning outputs
+        during generation, especially when controlling dynamic behavior with a `compression` factor.
+
+        Parameters:
+        ----------
+        latents : torch.Tensor
+            Input latent tensor with arbitrary shape. Expected to be roughly in [-1, 1] or [0, 1] range.
+        compression : float
+            Compression strength in the range [0, 1].
+            - 0.0: No tone-mapping (identity transform)
+            - 1.0: Full compression effect
+
+        Returns:
+        -------
+        torch.Tensor
+            The tone-mapped latent tensor of the same shape as input.
+        """
+        if compression ==0:
+            return latents
+        if not (0 <= compression <= 1):
+            raise ValueError("Compression must be in the range [0, 1]")
+
+        # Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot
+        scale_factor = compression * 0.75
+        abs_latents = torch.abs(latents)
+
+        # Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0
+        # When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect
+        sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0))
+        # DeepBeepMeep special touch to allow a smooth transition with tone mapping
+        if start > 0:
+            gradient_tensor = torch.linspace(0, 1, latents.shape[2])
+            gradient_tensor = gradient_tensor ** 0.5
+            gradient_tensor = gradient_tensor[ None, None, :, None, None ]
+            sigmoid_term *= gradient_tensor
+        scales = 1.0 - 0.8 * scale_factor * sigmoid_term
+
+        filtered = latents * scales
+        return filtered
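A standalone sketch (not part of the pipeline) that evaluates the same curve, handy for seeing what a given compression setting does to latent magnitudes:

# Illustrative: how tone_map_latents scales latent magnitudes, mirroring the formula above.
import torch

compression = 0.6                       # the value the distilled 0.9.8 config sets for the second pass
scale_factor = compression * 0.75
x = torch.tensor([0.0, 0.5, 1.0, 2.0, 4.0])            # example |latent| magnitudes
sigmoid_term = torch.sigmoid(4.0 * scale_factor * (x - 1.0))
scales = 1.0 - 0.8 * scale_factor * sigmoid_term
print(scales)   # roughly [0.95, 0.90, 0.82, 0.69, 0.64]: large values get compressed, small ones barely change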
    def denoising_step(
        self,
        latents: torch.Tensor,
@@ -1405,18 +1521,18 @@ class LTXVideoPipeline(DiffusionPipeline):
            media_item = conditioning_item.media_item
            media_frame_number = conditioning_item.media_frame_number
            strength = conditioning_item.conditioning_strength
+            control_frames = conditioning_item.control_frames
            assert media_item.ndim == 5  # (b, c, f, h, w)
            b, c, n_frames, h, w = media_item.shape
            assert (
                height == h and width == w
            ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0"
-            assert n_frames % 8 == 1
-            assert (
-                media_frame_number >= 0
-                and media_frame_number + n_frames <= num_frames
-            )
+            # assert n_frames % 8 == 1
+            # assert (
+            #     media_frame_number >= 0
+            #     and media_frame_number + n_frames <= num_frames
+            # )

-            # Encode the provided conditioning media item
            media_item_latents = vae_encode(
                media_item.to(dtype=self.vae.dtype, device=self.vae.device),
                self.vae,
@@ -1424,7 +1540,33 @@ class LTXVideoPipeline(DiffusionPipeline):
            ).to(dtype=init_latents.dtype)

            # Handle the different conditioning cases
-            if media_frame_number == 0:
+            if control_frames:
+                #control frames sequence is assumed to start one frame before the actual location so that we can properly insert the prefix latent
+                if media_frame_number > 0:
+                    media_frame_number = media_frame_number -1
+                media_item_latents, media_latent_coords = self.patchifier.patchify(
+                    latents=media_item_latents
+                )
+                media_pixel_coords = latent_to_pixel_coords(
+                    media_latent_coords,
+                    self.vae,
+                    causal_fix=self.transformer.config.causal_temporal_positioning,
+                )
+
+                media_conditioning_mask = torch.full(
+                    media_item_latents.shape[:2],
+                    strength,
+                    dtype=torch.float32,
+                    device=init_latents.device,
+                )
+
+                # Update the frame numbers to match the target frame number
+                media_pixel_coords[:, 0] += media_frame_number
+                extra_conditioning_num_latents += media_item_latents.shape[1]
+                extra_conditioning_latents.append(media_item_latents)
+                extra_conditioning_pixel_coords.append(media_pixel_coords)
+                extra_conditioning_mask.append(media_conditioning_mask)
+            elif media_frame_number == 0:
                # Get the target spatial position of the latent conditioning item
                media_item_latents, l_x, l_y = self._get_latent_spatial_position(
                    media_item_latents,
@@ -23,6 +23,23 @@ Note any changes or sudden events
Do not exceed the 150 word limit!
Output the enhanced prompt only.
"""

+T2I_VISUAL_PROMPT = """You are an expert visual artist and photographer with award-winning compositions. When writing prompts based on the user input, focus on detailed, precise descriptions of visual elements and composition.
+Include specific poses, appearances, framing, and environmental details - all in a single flowing paragraph.
+Start directly with the main subject, and keep descriptions literal and precise.
+Think like a photographer describing the perfect shot.
+Do not change the user input intent, just enhance it.
+Keep within 150 words.
+For best results, build your prompts using this structure:
+Start with main subject and pose in a single sentence
+Add specific details about expressions and positioning
+Describe character/object appearances precisely
+Include background and environment details
+Specify framing, composition and perspective
+Describe lighting, colors, and mood
+Note any atmospheric or stylistic elements
+Do not exceed the 150 word limit!
+Output the enhanced prompt only.
+"""

I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
@@ -43,6 +60,24 @@ Do not exceed the 150 word limit!
Output the enhanced prompt only.
"""

+I2I_VISUAL_PROMPT = """You are an expert visual artist and photographer with award-winning compositions. When writing prompts based on the user input, focus on detailed, precise descriptions of visual elements and composition.
+Include specific poses, appearances, framing, and environmental details - all in a single flowing paragraph.
+Start directly with the main subject, and keep descriptions literal and precise.
+Think like a photographer describing the perfect shot.
+Do not change the user input intent, just enhance it.
+Keep within 150 words.
+For best results, build your prompts using this structure:
+Start with main subject and pose in a single sentence
+Add specific details about expressions and positioning
+Describe character/object appearances precisely
+Include background and environment details
+Specify framing, composition and perspective
+Describe lighting, colors, and mood
+Note any atmospheric or stylistic elements
+Do not exceed the 150 word limit!
+Output the enhanced prompt only.
+"""


def tensor_to_pil(tensor):
    # Ensure tensor is in range [-1, 1]
@@ -68,6 +103,7 @@ def generate_cinematic_prompt(
    prompt_enhancer_tokenizer,
    prompt: Union[str, List[str]],
    images: Optional[List] = None,
+    video_prompt= True,
    max_new_tokens: int = 256,
) -> List[str]:
    prompts = [prompt] if isinstance(prompt, str) else prompt
@@ -78,7 +114,7 @@ def generate_cinematic_prompt(
            prompt_enhancer_tokenizer,
            prompts,
            max_new_tokens,
-            T2V_CINEMATIC_PROMPT,
+            T2V_CINEMATIC_PROMPT if video_prompt else T2I_VISUAL_PROMPT,
        )
    else:

@@ -90,7 +126,7 @@ def generate_cinematic_prompt(
            prompts,
            images,
            max_new_tokens,
-            I2V_CINEMATIC_PROMPT,
+            I2V_CINEMATIC_PROMPT if video_prompt else I2I_VISUAL_PROMPT,
        )

    return prompts

preprocessing/canny.py (new file, 151 lines)
@@ -0,0 +1,151 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from PIL import Image


norm_layer = nn.InstanceNorm2d

def convert_to_torch(image):
    if isinstance(image, Image.Image):
        image = torch.from_numpy(np.array(image)).float()
    elif isinstance(image, torch.Tensor):
        image = image.clone()
    elif isinstance(image, np.ndarray):
        image = torch.from_numpy(image.copy()).float()
    else:
        raise f'Unsurpport datatype{type(image)}, only surpport np.ndarray, torch.Tensor, Pillow Image.'
    return image

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features)
        ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)


class ContourInference(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(ContourInference, self).__init__()

        # Initial convolution block
        model0 = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True)
        ]
        self.model0 = nn.Sequential(*model0)

        # Downsampling
        model1 = []
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model1 += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                norm_layer(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features * 2
        self.model1 = nn.Sequential(*model1)

        model2 = []
        # Residual blocks
        for _ in range(n_residual_blocks):
            model2 += [ResidualBlock(in_features)]
        self.model2 = nn.Sequential(*model2)

        # Upsampling
        model3 = []
        out_features = in_features // 2
        for _ in range(2):
            model3 += [
                nn.ConvTranspose2d(in_features,
                                   out_features,
                                   3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                norm_layer(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features // 2
        self.model3 = nn.Sequential(*model3)

        # Output layer
        model4 = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            model4 += [nn.Sigmoid()]

        self.model4 = nn.Sequential(*model4)

    def forward(self, x, cond=None):
        out = self.model0(x)
        out = self.model1(out)
        out = self.model2(out)
        out = self.model3(out)
        out = self.model4(out)

        return out


class CannyAnnotator:
    def __init__(self, cfg, device=None):
        input_nc = cfg.get('INPUT_NC', 3)
        output_nc = cfg.get('OUTPUT_NC', 1)
        n_residual_blocks = cfg.get('N_RESIDUAL_BLOCKS', 3)
        sigmoid = cfg.get('SIGMOID', True)
        pretrained_model = cfg['PRETRAINED_MODEL']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.model = ContourInference(input_nc, output_nc, n_residual_blocks,
                                      sigmoid)
        self.model.load_state_dict(torch.load(pretrained_model, weights_only=True))
        self.model = self.model.eval().requires_grad_(False).to(self.device)

    @torch.no_grad()
    @torch.inference_mode()
    @torch.autocast('cuda', enabled=False)
    def forward(self, image):
        is_batch = False if len(image.shape) == 3 else True
        image = convert_to_torch(image)
        if len(image.shape) == 3:
            image = rearrange(image, 'h w c -> 1 c h w')
        image = image.float().div(255).to(self.device)
        contour_map = self.model(image)
        contour_map = (contour_map.squeeze(dim=1) * 255.0).clip(
            0, 255).cpu().numpy().astype(np.uint8)
        contour_map = contour_map[..., None].repeat(3, -1)
        contour_map = 255 - contour_map #.where( image >= 127.5,0,1)
        contour_map[ contour_map > 127.5] = 255
        contour_map[ contour_map <= 127.5] = 0
        if not is_batch:
            contour_map = contour_map.squeeze()
        return contour_map


class CannyVideoAnnotator(CannyAnnotator):
    def forward(self, frames):
        ret_frames = []
        for frame in frames:
            anno_frame = super().forward(np.array(frame))
            ret_frames.append(anno_frame)
        return ret_frames
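A hypothetical usage sketch of the annotator above; the checkpoint path is an assumption, as WanGP resolves the real weights through its own preprocessing configuration.

# Illustrative only: run the contour/canny annotator on a single frame.
import numpy as np

cfg = {"PRETRAINED_MODEL": "ckpts/canny_annotator.pth"}   # hypothetical path to the pretrained weights
annotator = CannyAnnotator(cfg)                            # defaults: 3 input channels, 1 output, 3 residual blocks
frame = np.zeros((480, 832, 3), dtype=np.uint8)            # H x W x C uint8 frame
edge_map = annotator.forward(frame)                        # binarized 3-channel contour map, same H x W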
@@ -10,13 +10,14 @@ from PIL import Image

import cv2
import torch
+import torch.nn.functional as F
import numpy as np
import gradio as gr
from .tools.painter import mask_painter
from .tools.interact_tools import SamControler
from .tools.misc import get_device
from .tools.download_util import load_file_from_url
+from segment_anything.modeling.image_encoder import window_partition, window_unpartition, get_rel_pos, Block as image_encoder_block
from .utils.get_default_model import get_matanyone_model
from .matanyone.inference.inference_core import InferenceCore
from .matanyone_wrapper import matanyone
@@ -83,8 +84,11 @@ def get_frames_from_image(image_input, image_state):
        "fps": None
    }
    image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size)
+    set_image_encoder_patch()
+    torch.cuda.empty_cache()
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(image_state["origin_images"][0])
+    torch.cuda.empty_cache()
    return image_state, image_info, image_state["origin_images"][0], \
        gr.update(visible=True, maximum=10, value=10), gr.update(visible=False, maximum=len(frames), value=len(frames)), \
        gr.update(visible=True), gr.update(visible=True), \
@@ -163,8 +167,11 @@ def get_frames_from_video(video_input, video_state):
        "audio": audio_path
    }
    video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size)
+    set_image_encoder_patch()
+    torch.cuda.empty_cache()
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
+    torch.cuda.empty_cache()
    return video_state, video_info, video_state["origin_images"][0], \
        gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), gr.update(visible=False, maximum=len(frames), value=len(frames)), \
        gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
@@ -203,6 +210,70 @@ def get_end_number(track_pause_number_slider, video_state, interactive_state):
 
     return video_state["painted_images"][track_pause_number_slider],interactive_state
 
+
+def patched_forward(self, x: torch.Tensor) -> torch.Tensor:
+    def split_mlp(mlp, x, divide = 4):
+        x_shape = x.shape
+        x = x.view(-1, x.shape[-1])
+        chunk_size = int(x.shape[0]/divide)
+        x_chunks = torch.split(x, chunk_size)
+        for i, x_chunk in enumerate(x_chunks):
+            mlp_chunk = mlp.lin1(x_chunk)
+            mlp_chunk = mlp.act(mlp_chunk)
+            x_chunk[...] = mlp.lin2(mlp_chunk)
+        return x.reshape(x_shape)
+
+    def get_decomposed_rel_pos( q, rel_pos_h, rel_pos_w, q_size, k_size) -> torch.Tensor:
+        q_h, q_w = q_size
+        k_h, k_w = k_size
+        Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+        Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+        B, _, dim = q.shape
+        r_q = q.reshape(B, q_h, q_w, dim)
+        rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+        rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+        attn = torch.zeros(B, q_h, q_w, k_h, k_w, dtype=q.dtype, device=q.device)
+        attn += rel_h[:, :, :, :, None]
+        attn += rel_w[:, :, :, None, :]
+        return attn.view(B, q_h * q_w, k_h * k_w)
+
+    def pay_attention(self, x: torch.Tensor) -> torch.Tensor:
+        B, H, W, _ = x.shape
+        # qkv with shape (3, B, nHead, H * W, C)
+        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        # q, k, v with shape (B * nHead, H * W, C)
+        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+        attn_mask = None
+        if self.use_rel_pos:
+            attn_mask = get_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=self.scale)
+        del q, k, v, attn_mask
+        x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+        return self.proj(x)
+
+    shortcut = x
+    x = self.norm1(x)
+    # Window partition
+    if self.window_size > 0:
+        H, W = x.shape[1], x.shape[2]
+        x, pad_hw = window_partition(x, self.window_size)
+
+    x = pay_attention(self.attn,x)
+    # Reverse window partition
+    if self.window_size > 0:
+        x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+    x += shortcut
+    shortcut[...] = self.norm2(x)
+    # x += self.mlp(shortcut)
+    x += split_mlp(self.mlp, shortcut)
+
+    return x
+
+def set_image_encoder_patch():
+    if not hasattr(image_encoder_block, "patched"):
+        image_encoder_block.forward = patched_forward
+        image_encoder_block.patched = True
+
 # use sam to get the mask
 def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData ): #
     """
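The hunk above monkey-patches the SAM ViT encoder block: attention is rewritten around `F.scaled_dot_product_attention` (with the decomposed relative-position bias passed as `attn_mask`), intermediate tensors are deleted as soon as they are consumed, and the block's MLP is applied to the flattened tokens in chunks that overwrite their input in place, so only one chunk of hidden activations is alive at a time. A minimal, self-contained sketch of that chunked-MLP idea follows; `TinyMLP` and the tensor sizes are illustrative stand-ins, not code from this commit:

# A minimal sketch of the chunked-MLP trick used by patched_forward above.
# TinyMLP and the tensor sizes are illustrative stand-ins, not code from this commit.
import torch
import torch.nn as nn

class TinyMLP(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.lin1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.lin2 = nn.Linear(hidden, dim)

def chunked_mlp(mlp, x, divide=4):
    # Flatten every leading dimension into one token axis, push the tokens
    # through the MLP one chunk at a time and write each result back in place,
    # so only one chunk of hidden activations exists at any moment.
    x_shape = x.shape
    x = x.view(-1, x_shape[-1])
    chunk_size = max(1, x.shape[0] // divide)
    for chunk in torch.split(x, chunk_size):
        chunk[...] = mlp.lin2(mlp.act(mlp.lin1(chunk)))
    return x.reshape(x_shape)

mlp = TinyMLP(64, 256)
tokens = torch.randn(2, 32, 32, 64)   # (B, H, W, C) windowed ViT tokens
out = chunked_mlp(mlp, tokens)        # numerically the same as applying the MLP to all tokens at once
assert out.shape == tokens.shape

The trade-off is extra Python-level looping and in-place mutation of the input tensor, which is only safe for inference-style, no-grad execution, since the in-place writes would otherwise corrupt the autograd graph.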
@@ -218,7 +289,9 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
         coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
         interactive_state["negative_click_times"] += 1
 
+    torch.cuda.empty_cache()
     # prompt for sam model
+    set_image_encoder_patch()
     model.samcontroler.sam_controler.reset_image()
     model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
     prompt = get_prompt(click_state=click_state, click_input=coordinate)
@@ -233,6 +306,7 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
     video_state["logits"][video_state["select_frame_number"]] = logit
     video_state["painted_images"][video_state["select_frame_number"]] = painted_image
 
+    torch.cuda.empty_cache()
     return painted_image, video_state, interactive_state
 
 def add_multi_mask(video_state, interactive_state, mask_dropdown):
@@ -313,7 +387,9 @@ def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_si
     # operation error
     if len(np.unique(template_mask))==1:
         template_mask[0][0]=1
+    torch.cuda.empty_cache()
     foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size, n_warmup=refine_iter)
+    torch.cuda.empty_cache()
 
 
     foreground_mat = False
@@ -376,7 +452,9 @@ def video_matting(video_state, end_slider, matting_type, interactive_state, mask
     # operation error
     if len(np.unique(template_mask))==1:
         template_mask[0][0]=1
+    torch.cuda.empty_cache()
     foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size)
+    torch.cuda.empty_cache()
     output_frames = []
     foreground_mat = matting_type == "Foreground"
     if not foreground_mat:
@@ -31,6 +31,7 @@ from .utils.vace_preprocessor import VaceVideoProcessor
 from wan.utils.basic_flowmatch import FlowMatchScheduler
 from wan.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions
 from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance
+from mmgp import safetensors2
 
 def optimized_scale(positive_flat, negative_flat):
 
@@ -101,14 +102,13 @@ class WanAny2V:
             vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
             device=self.device)
 
-        # xmodel_filename = "c:/ml/multitalk/multitalk.safetensors"
+        xmodel_filename = "c:/temp/wan2.1_text2video_1.3B_bf16.safetensors"
-        # config_filename= "configs/multitalk.json"
+        # config_filename= "configs/t2v_1.3B.json"
         # import json
         # with open(config_filename, 'r', encoding='utf-8') as f:
         # config = json.load(f)
-        # from mmgp import safetensors2
         # sd = safetensors2.torch_load_file(xmodel_filename)
-        # model_filename = "c:/temp/flf/diffusion_pytorch_model-00001-of-00007.safetensors"
+        # model_filename = "c:/temp/vace1_3B.safetensors"
         base_config_file = f"configs/{base_model_type}.json"
         forcedConfigPath = base_config_file if len(model_filename) > 1 else None
         # forcedConfigPath = base_config_file = f"configs/flf2v_720p.json"
@@ -118,12 +118,7 @@ class WanAny2V:
         # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
         self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
         offload.change_dtype(self.model, dtype, True)
-        # offload.save_model(self.model, "flf2v_720p.safetensors", config_file_path=base_config_file)
+        # offload.save_model(self.model, "wan2.1_text2video_1.3B_mbf16.safetensors", do_quantize= False, config_file_path=base_config_file, filter_sd=sd)
-        # offload.save_model(self.model, "flf2v_quanto_int8_fp16_720p.safetensors", do_quantize= True, config_file_path=base_config_file)
-        # offload.save_model(self.model, "multitalk_quanto_fp16.safetensors", do_quantize= True, config_file_path=base_config_file, filter_sd=sd)
-
-        # offload.save_model(self.model, "wan2.1_selforcing_fp16.safetensors", config_file_path=base_config_file)
-        # offload.save_model(self.model, "wan2.1_text2video_14B_mbf16.safetensors", config_file_path=base_config_file)
         # offload.save_model(self.model, "wan2.1_text2video_14B_quanto_mfp16_int8.safetensors", do_quantize=True, config_file_path=base_config_file)
         self.model.eval().requires_grad_(False)
         if save_quantized:
@@ -867,3 +862,6 @@ class WanAny2V:
             target = modules_dict[f"blocks.{model_layer}"]
             setattr(target, "vace", module )
         delattr(model, "vace_blocks")
+
+def query_model_def(model_type, model_def):
+    return None
@@ -446,3 +446,6 @@ class DTT2V:
         videos = videos[0] # return only first video
 
         return videos
+
+def query_model_def(model_type, model_def):
+    return None
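Both pipelines above gain the same module-level `query_model_def(model_type, model_def)` stub returning `None`; presumably a hook the WanGP front end can call uniformly across pipelines to ask a model family for extra definition data, with `None` meaning there is nothing to add.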
@@ -552,7 +552,7 @@ class WanAttentionBlock(nn.Module):
 
         y_shape = y.shape
         y = y.view(-1, y_shape[-1])
-        chunk_size = int(y_shape[1]/2.7)
+        chunk_size = int(y.shape[0]/2.7)
         chunks =torch.split(y, chunk_size)
         for y_chunk in chunks:
             mlp_chunk = ffn(y_chunk)
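The one-line change above fixes the chunk-size computation in the attention block's feed-forward path: once `y` has been flattened with `view(-1, y_shape[-1])`, the tokens to split live along dimension 0 of the flattened tensor, so the chunk size must come from `y.shape[0]`; the old `y_shape[1]` (the pre-flatten sequence length) only matches that when the leading batch dimension is 1.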
@@ -128,7 +128,7 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli
 
     _, seq_lens, heads, _ = visual_q.shape
     class_num, _ = ref_target_masks.shape
-    x_ref_attn_maps = torch.zeros(class_num, seq_lens, dtype=visual_q.device, device=visual_q.dtype)
+    x_ref_attn_maps = torch.zeros(class_num, seq_lens, dtype=visual_q.dtype, device=visual_q.device)
 
     split_chunk = heads // split_num
 
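The change above fixes swapped keyword arguments: the old call passed `visual_q.device` as `dtype` and `visual_q.dtype` as `device`, which makes `torch.zeros` raise a type error as soon as this multitalk attention-map path runs; the new call passes them in the correct order.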
@@ -99,7 +99,7 @@ def get_video_info(video_path):
    cap = cv2.VideoCapture(video_path)

    # Get FPS
-    fps = cap.get(cv2.CAP_PROP_FPS)
+    fps = round(cap.get(cv2.CAP_PROP_FPS))

    # Get resolution
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))