Merge branch 'main' into feature_add-cuda-docker-runner

This commit is contained in:
Ciprian Mandache 2025-08-30 06:56:10 +03:00
commit 5940aa023f
80 changed files with 5963 additions and 725 deletions

View File

@ -22,6 +22,28 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates
### July 15 2025: WanGP v7.0 is an AI Powered Photoshop
This release turns the Wan models into Image Generators. This goes far beyond simply generating a video made of a single frame:
- Multiple images can be generated at the same time so that you can choose the one you like best. It is highly VRAM optimized: you can for instance generate four 720p images at the same time with less than 10 GB
- With *image2image* the original text2video WanGP model becomes an image upsampler / restorer
- *Vace image2image* comes out of the box with image outpainting, person / object replacement, ...
- You can use in one click a newly generated image as the Start Image or Reference Image for a Video generation
And to complete the full suite of AI Image Generators, Ladies and Gentlemen please welcome for the first time in WanGP : **Flux Kontext**.\
As a reminder Flux Kontext is an image editor : give it an image and a prompt and it will do the change for you.\
This highly optimized version of Flux Kontext will make you feel that you have been cheated all this time as WanGP Flux Kontext requires only 8 GB of VRAM to generate 4 images at the same time with no need for quantization.
WanGP v7 comes with *Image2image* vanilla and *Vace FusioniX*. However you can build your own finetune where you combine a text2video or Vace model with any combination of Loras.
Also in the news:
- You can now enter the *Bbox* for each speaker in *Multitalk* to precisely locate who is speaking. And to save some headaches the *Image Mask generator* will give you the *Bbox* coordinates of an area you have selected.
- *Film Grain* post processing to add a vintage look to your video
- *First Last Frame to Video* model should work much better now, as I recently discovered its implementation was incomplete
- More power for the finetuners, you can now embed Loras directly in the finetune definition. You can also override the default models (titles, visibility, ...) with your own finetunes. Check the doc that has been updated.
### July 10 2025: WanGP v6.7, is NAG a game changer ? you tell me
Maybe you knew that already but most *Loras accelerators* we use today (Causvid, FusioniX) don't use *Guidance* at all (that is, *CFG* is set to 1). This helps to get much faster generations but the downside is that *Negative Prompts* are completely ignored (including the default ones set by the models). **NAG** (https://github.com/ChenDarYen/Normalized-Attention-Guidance) aims to solve that by injecting the *Negative Prompt* during the *attention* processing phase.
### July 10 2025: WanGP v6.7, is NAG a game changer ? you tell me

13
defaults/ReadMe.txt Normal file
View File

@ -0,0 +1,13 @@
Please do not modify any file in this folder.
If you want to change a property of a default model, copy the corresponding model file into the ./finetunes folder and modify the properties you want to change in the new file.
If a property is not in the new file, it will be inherited automatically from the default file with the same name.
For instance to hide a model:
{
"model":
{
"visible": false
}
}
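For illustration, such an override file can also be written with a few lines of Python. This is only a sketch: the file name "flf2v_720p.json" is an assumption and must match the name of the default definition you want to override.
```
# Sketch only: create finetunes/flf2v_720p.json to hide the matching default model.
# The file name is an assumption; use the name of the default file you want to override.
import json, os

override = {"model": {"visible": False}}
os.makedirs("finetunes", exist_ok=True)
with open(os.path.join("finetunes", "flf2v_720p.json"), "w") as f:
    json.dump(override, f, indent=4)
```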

View File

@ -1,14 +1,14 @@
{
"model":
{
"name": "First Last Frame to Video 720p (FLF2V)14B",
"name": "First Last Frame to Video 720p (FLF2V) 14B",
"architecture" : "flf2v_720p",
"visible" : false,
"visible" : true,
"description": "The First Last Frame 2 Video model is the official model Image 2 Video model that supports Start and End frames.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_quanto_fp16_int8.safetensors"
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_mbf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_quanto_mbf16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_quanto_mfp16_int8.safetensors"
],
"auto_quantize": true
},

View File

@ -0,0 +1,16 @@
{
"model": {
"name": "Flux Dev Kontext 12B",
"architecture": "flux_dev_kontext",
"description": "FLUX.1 Kontext is a 12 billion parameter rectified flow transformer capable of editing images based on instructions stored in the Prompt. Please be aware that Flux Kontext is picky on the resolution of the input image the output dimensions may not match the dimensions of the input image.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors"
]
},
"prompt": "add a hat",
"resolution": "1280x720",
"video_length": 1
}

13
defaults/fun_inp.json Normal file
View File

@ -0,0 +1,13 @@
{
"model":
{
"name": "Fun InP image2video 14B",
"architecture" : "fun_inp",
"description": "The Fun model is an alternative image 2 video that supports out the box End Image fixing (contrary to the original Wan image 2 video model).",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Fun_InP_14B_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Fun_InP_14B_quanto_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Fun_InP_14B_quanto_fp16_int8.safetensors"
]
}
}

View File

@ -0,0 +1,11 @@
{
"model":
{
"name": "Fun InP image2video 1.3B",
"architecture" : "fun_inp_1.3B",
"description": "The Fun model is an alternative image 2 video that supports out the box End Image fixing (contrary to the original Wan image 2 video model). The 1.3B adds also image 2 to video capability to the 1.3B model.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Fun_InP_1.3B_bf16.safetensors"
]
}
}

12
defaults/hunyuan.json Normal file
View File

@ -0,0 +1,12 @@
{
"model":
{
"name": "Hunyuan Video text2video 720p 13B",
"architecture" : "hunyuan",
"description": "Probably the best text 2 video model available.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_720_bf16.safetensors.safetensors",
"https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_720_quanto_int8.safetensors"
]
}
}

View File

@ -0,0 +1,12 @@
{
"model":
{
"name": "Hunyuan Video Avatar 720p 13B",
"architecture" : "hunyuan_avatar",
"description": "With the Hunyuan Video Avatar model you can animate a person based on the content of an audio input. Please note that the video generator works by processing 128 frames segment at a time (even if you ask less). The good news is that it will concatenate multiple segments for long video generation (max 3 segments recommended as the quality will get worse).",
"URLs": [
"https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_avatar_720_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_avatar_720_quanto_bf16_int8.safetensors"
]
}
}

View File

@ -0,0 +1,12 @@
{
"model":
{
"name": "Hunyuan Video Custom 720p 13B",
"architecture" : "hunyuan_custom",
"description": "The Hunyuan Video Custom model is probably the best model to transfer people (only people for the moment) as it is quite good to keep their identity. However it is slow as to get good results, you need to generate 720p videos with 30 steps.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_720_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_720_quanto_bf16_int8.safetensors"
]
}
}

View File

@ -0,0 +1,12 @@
{
"model":
{
"name": "Hunyuan Video Custom Audio 720p 13B",
"architecture" : "hunyuan_custom_audio",
"description": "The Hunyuan Video Custom Audio model can be used to generate scenes of a person speaking given a Reference Image and a Recorded Voice or Song. The reference image is not a start image and therefore one can represent the person in a different context.The video length can be anything up to 10s. It is also quite good to generate no sound Video based on a person.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_audio_720_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_audio_720_quanto_bf16_int8.safetensors"
]
}
}

View File

@ -0,0 +1,12 @@
{
"model":
{
"name": "Hunyuan Video Custom Edit 720p 13B",
"architecture" : "hunyuan_custom_edit",
"description": "The Hunyuan Video Custom Edit model can be used to do Video inpainting on a person (add accessories or completely replace the person). You will need in any case to define a Video Mask which will indicate which area of the Video should be edited.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_edit_720_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_edit_720_quanto_bf16_int8.safetensors"
]
}
}

12
defaults/hunyuan_i2v.json Normal file
View File

@ -0,0 +1,12 @@
{
"model":
{
"name": "Hunyuan Video image2video 720p 13B",
"architecture" : "hunyuan_i2v",
"description": "A good looking image 2 video model, but not so good in prompt adherence.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_i2v_720_bf16v2.safetensors",
"https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_i2v_720_quanto_int8v2.safetensors"
]
}
}

13
defaults/i2v.json Normal file
View File

@ -0,0 +1,13 @@
{
"model":
{
"name": "Wan2.1 image2video 480p 14B",
"architecture" : "i2v",
"description": "The standard Wan Image 2 Video specialized to generate 480p images. It also offers Start and End Image support (End Image is not supported in the original model but seems to work well)",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_480p_14B_mbf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_480p_14B_quanto_mfp16_int8.safetensors"
]
}
}

14
defaults/i2v_720p.json Normal file
View File

@ -0,0 +1,14 @@
{
"model":
{
"name": "Wan2.1 image2video 720p 14B",
"architecture" : "i2v",
"description": "The standard Wan Image 2 Video specialized to generate 720p images. It also offers Start and End Image support (End Image is not supported in the original model but seems to work well).",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_720p_14B_mbf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors"
]
},
"resolution": "1280x720"
}

View File

@ -0,0 +1,10 @@
{
"model":
{
"name": "Wan2.1 image2video 480p FusioniX 14B",
"architecture" : "i2v",
"description": "A powerful merged image-to-video model based on the original WAN 2.1 I2V model, enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
"URLs": "i2v",
"loras": ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_I2V_14B_FusionX_LoRA.safetensors"]
}
}

14
defaults/ltxv_13B.json Normal file
View File

@ -0,0 +1,14 @@
{
"model":
{
"name": "LTX Video 0.9.7 13B",
"architecture" : "ltxv_13B",
"description": "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).It is recommended to keep the number of steps to 30 or you will need to update the file 'ltxv_video/configs/ltxv-13b-0.9.7-dev.yaml'.The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.7_13B_dev_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.7_13B_dev_quanto_bf16_int8.safetensors"
],
"LTXV_config": "ltx_video/configs/ltxv-13b-0.9.7-dev.yaml"
},
"num_inference_steps": 30
}

View File

@ -0,0 +1,14 @@
{
"model":
{
"name": "LTX Video 0.9.7 Distilled 13B",
"architecture" : "ltxv_13B",
"description": "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).This distilled version is a very fast version and retains a high level of quality. The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer.",
"URLs": "ltxv_13B",
"loras": ["https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"],
"loras_multipliers": [ 1 ],
"lock_inference_steps": true,
"LTXV_config": "ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml"
},
"num_inference_steps": 6
}

View File

@ -0,0 +1,11 @@
{
"model":
{
"name": "Phantom 1.3B",
"architecture" : "phantom_1.3B",
"description": "The Phantom model is specialized in transferring people or objects of your choice into a generated Video. It produces very nice results when used at 720p.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2_1_phantom_1.3B_mbf16.safetensors"
]
}
}

13
defaults/phantom_14B.json Normal file
View File

@ -0,0 +1,13 @@
{
"model":
{
"name": "Phantom 14B",
"architecture" : "phantom_14B",
"description": "The Phantom model is specialized in transferring people or objects of your choice into a generated Video. It produces very nice results when used at 720p.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_phantom_14B_mbf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_phantom_14B_quanto_mbf16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_phantom_14B_quanto_mfp16_int8.safetensors"
]
}
}

11
defaults/recam_1.3B.json Normal file
View File

@ -0,0 +1,11 @@
{
"model":
{
"name": "ReCamMaster 1.3B",
"architecture" : "recam_1.3B",
"description": "The Recam Master in theory should allow you to replay a video by applying a different camera movement. The model supports only video that are at least 81 frames long (any frame beyond will be ignored)",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_recammaster_1.3B_bf16.safetensors"
]
}
}

11
defaults/sky_df_1.3B.json Normal file
View File

@ -0,0 +1,11 @@
{
"model":
{
"name": "SkyReels2 Diffusion Forcing 1.3B",
"architecture" : "sky_df_1.3B",
"description": "The SkyReels 2 Diffusion Forcing model has been designed to generate very long videos that exceeds the usual 5s limit. You can also use this model to extend any existing video.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors"
]
}
}

13
defaults/sky_df_14B.json Normal file
View File

@ -0,0 +1,13 @@
{
"model":
{
"name": "SkyReels2 Diffusion Forcing 540p 14B",
"architecture" : "sky_df_14B",
"description": "The SkyReels 2 Diffusion Forcing model has been designed to generate very long videos that exceeds the usual 5s limit. You can also use this model to extend any existing video.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_14B_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_14B_quanto_fp16_int8.safetensors"
]
}
}

View File

@ -0,0 +1,14 @@
{
"model":
{
"name": "SkyReels2 Diffusion Forcing 720p 14B",
"architecture" : "sky_df_14B",
"description": "The SkyReels 2 Diffusion Forcing model has been designed to generate very long videos that exceeds the usual 5s limit. You can also use this model to extend any existing video.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_720p_14B_mbf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_720p_14B_quanto_mfp16_int8.safetensors"
]
},
"resolution": "1280x720"
}

13
defaults/t2i.json Normal file
View File

@ -0,0 +1,13 @@
{
"model": {
"name": "Wan2.1 text2image 14B",
"architecture": "t2v",
"description": "The original Wan Text 2 Video model configured to generate an image instead of a video.",
"image_outputs": true,
"URLs": "t2v"
},
"video_length": 1,
"resolution": "1280x720"
}

13
defaults/t2v.json Normal file
View File

@ -0,0 +1,13 @@
{
"model":
{
"name": "Wan2.1 text2video 14B",
"architecture" : "t2v",
"description": "The original Wan Text 2 Video model. Most other models have been built on top of it",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_text2video_14B_mbf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_text2video_14B_quanto_mbf16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_text2video_14B_quanto_mfp16_int8.safetensors"
]
}
}

11
defaults/t2v_1.3B.json Normal file
View File

@ -0,0 +1,11 @@
{
"model":
{
"name": "Wan2.1 text2video 1.3B",
"architecture" : "t2v_1.3B",
"description": "The light version of the original Wan Text 2 Video model. Most other models have been built on top of it",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_text2video_1.3B_bf16.safetensors"
]
}
}

11
defaults/vace_1.3B.json Normal file
View File

@ -0,0 +1,11 @@
{
"model":
{
"name": "Vace ControlNet 1.3B",
"architecture" : "vace_1.3B",
"description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Vace_1.3B_mbf16.safetensors"
]
}
}

View File

@ -15,7 +15,7 @@
"seed": -1,
"num_inference_steps": 10,
"guidance_scale": 1,
"flow_shift": 5,
"flow_shift": 2,
"embedded_guidance_scale": 6,
"repeat_generation": 1,
"multi_images_gen_type": 0,

View File

@ -0,0 +1,16 @@
{
"model": {
"name": "Vace FusioniX image2image 14B",
"architecture": "vace_14B",
"modules": [
"vace_14B"
],
"image_outputs": true,
"description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
"URLs": "t2v_fusionix"
},
"resolution": "1280x720",
"guidance_scale": 1,
"num_inference_steps": 10,
"video_length": 1
}

View File

@ -2,22 +2,30 @@
A Finetuned model is a model that shares the same architecture as one specific model but has weights derived from this model. Some finetuned models have been created by combining multiple finetuned models.
As there are potentially an infinite number of finetunes, specific finetuned models are not known by default by WanGP, however you can create a finetuned model definition that will tell WanGP about the existence of this finetuned model and WanGP will do as usual all the work for you: autodownload the model and build the user interface.
As there are potentially an infinite number of finetunes, specific finetuned models are not known by default by WanGP. However you can create a finetuned model definition that tells WanGP about the existence of this finetuned model, and WanGP will, as usual, do all the work for you: autodownload the model and build the user interface.
The WanGP finetune system can also be used to tweak default models: for instance you can add, on top of an existing model, some loras that will always be applied transparently.
Finetune model definitions are lightweight json files that can be easily shared. You can find some of them on the WanGP *discord* server https://discord.gg/g7efUW9jGV
All the finetune definition files should be stored in the *finetunes/* subfolder.
Finetuned models have been tested so far with Wan2.1 text2video, Wan2.1 image2video, Hunyuan Video text2video. There isn't currently any support for LTX Video finetunes.
## Create a new Finetune Model Definition
All the finetune models definitions are json files stored in the **finetunes** sub folder. All the corresponding finetune model weights will be stored in the *ckpts* subfolder and will sit next to the base models.
WanGP comes with a few prebuilt finetune models that you can use as starting points and to get an idea of the structure of the definition file.
## Create a new Finetune Model Definition
All the finetune model definitions are json files stored in the **finetunes/** sub folder. Once downloaded, all the corresponding finetune model weights will be stored in the *ckpts/* subfolder and will sit next to the base models.
All the models used by WanGP are also described using the finetunes json format and can be found in the **defaults/** subfolder. Please don't modify any file in the **defaults/** folder.
However you can use these files as starting points for new definition files and to get an idea of the structure of a definition file. If you want to change how a base model is handled (title, default settings, path to model weights, …) you may override any property of the default definition file by creating a new file in the finetunes folder with the same name. Everything will happen as if the two definitions were merged property by property, with higher priority given to the finetune definition.
A definition is built from a *settings file* that contains all the default parameters for a video generation. On top of this file, a subtree named **model** contains all the information regarding the finetune (URLs to download the model, corresponding base model id, ...).
You can obtain a settings file in several ways:
- In the subfolder **settings**, get the json file that corresponds to the base model of your finetune (see the next section for the list of ids of base models)
- From the user interface, go to the base model and click **export settings**
- From the user interface, select the base model for which you want to create a finetune and click **export settings**
Here are steps:
1) Create a *settings file*
@ -26,45 +34,60 @@ Here are steps:
4) Restart WanGP
## Architecture Models Ids
A finetune is derived from a base model and will inherit all the user interface and corresponding model capabilities, here are Architecture Ids:
- *t2v*: Wan 2.1 Video text 2
- *i2v*: Wan 2.1 Video image 2 480p
- *i2v_720p*: Wan 2.1 Video image 2 720p
A finetune is derived from a base model and will inherit all the user interface and corresponding model capabilities, here are some Architecture Ids:
- *t2v*: Wan 2.1 Video text 2 video
- *i2v*: Wan 2.1 Video image 2 video 480p and 720p
- *vace_14B*: Wan 2.1 Vace 14B
- *hunyuan*: Hunyuan Video text 2 video
- *hunyuan_i2v*: Hunyuan Video image 2 video
Any file name in the defaults subfolder (without the json extension) corresponds to an architecture id.
Please note that the weights of some architectures correspond to a combination of the weights of one architecture completed by the weights of one or more modules.
A module is a set of weights that is insufficient to be a model by itself but that can be added to an existing model to extend its capabilities.
For instance if one adds the module *vace_14B* on top of a model with architecture *t2v* one gets a model with the *vace_14B* architecture. Here *vace_14B* stands for both an architecture name and a module name. The module system allows you to reuse shared weights between models.
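Conceptually (this is not WanGP's actual loader, just an illustrative sketch with a hypothetical module file name), combining a base checkpoint with a module amounts to merging the two tensor dictionaries before loading:
```
# Conceptual sketch only: merge a base checkpoint with a module checkpoint.
# The module file name below is hypothetical; this is not the code WanGP uses.
from safetensors.torch import load_file, save_file

base = load_file("ckpts/wan2.1_text2video_14B_mbf16.safetensors")      # t2v weights
module = load_file("ckpts/wan2.1_Vace_14B_module_mbf16.safetensors")   # extra Vace layers (hypothetical name)

merged = dict(base)
merged.update(module)  # module tensors extend (or override) the base ones
save_file(merged, "ckpts/wan2.1_Vace_14B_combined.safetensors")
```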
## The Model Subtree
- *name*: name of the finetune, used to select it
- *architecture* : architecture Id of the base model of the finetune (see previous section)
- *description*: description of the finetune that will appear at the top
- *URLs*: URLs of all the finetune versions (quantized / non quantized). WanGP will pick the version that is the closest to the user preferences. You will need to follow a naming convention to help WanGP identify the content of each version (see next section). Right now WanGP supports only 8-bit quantized models that have been quantized using **quanto**. WanGP offers a command switch to easily build such a quantized model (see below). *URLs* can also contain paths to local files to allow testing.
- *modules*: this a list of modules to be combined with the models referenced by the URLs. A module is a model extension that is merged with a model to expand its capabilities. So far the only module supported is Vace 14B (its id is *vace_14B*). For instance the full Vace model is the fusion of a Wan text 2 video and the Vace module.
- *modules*: this is a list of modules to be combined with the model referenced by the URLs. A module is a model extension that is merged with a model to expand its capabilities. Supported modules so far are: *vace_14B* and *multitalk*. For instance the full Vace model is the fusion of a Wan text 2 video and the Vace module.
- *preload_URLs* : URLs of files to download no matter what (used to load quantization maps for instance)
- *loras*: URLs of Loras that will be applied before any other Lora specified by the user. These loras will quite often be Lora accelerators. For instance if you specify here the FusioniX Lora you will be able to reduce the number of generation steps.
- *loras_multipliers*: a list of float numbers that defines the weight of each Lora mentioned above.
- *auto_quantize*: if set to True and no quantized model URL is provided, WanGP will perform on the fly quantization if the user expects a quantized model
- *visible*: assumed to be true by default. If set to false the model will no longer be visible. This can be useful if you create a finetune that overrides a default model in order to hide it.
- *image_outputs*: turns any model that generates a video into a model that generates images. In practice it adapts the user interface for image generation and asks the model to generate a video with a single frame.
To favor reusability, the *URLs*, *modules*, *loras* and *preload_URLs* properties can contain, instead of a list of URLs, a single string which corresponds to the id of a finetune or default model to reuse.
For example, let's say you have defined a *t2v_fusionix.json* file which contains the URLs to download the finetune. In *vace_fusionix.json* you can then write "URLs": "fusionix" to automatically reuse the URLs already defined in the corresponding file.
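A hedged sketch of how such a reference could be resolved when a definition is read (the helper below is illustrative, not WanGP's actual code): if the property holds a string instead of a list, look up the definition with that id in *finetunes/* or *defaults/* and take its value.
```
# Illustrative only: resolve "URLs": "<model_id>" style references.
import json, os

def load_definition(model_id):
    for folder in ("finetunes", "defaults"):
        path = os.path.join(folder, f"{model_id}.json")
        if os.path.isfile(path):
            with open(path) as f:
                return json.load(f)
    raise FileNotFoundError(model_id)

def resolve_urls(definition):
    urls = definition["model"]["URLs"]
    if isinstance(urls, str):  # e.g. "URLs": "t2v"
        urls = load_definition(urls)["model"]["URLs"]
    return urls
```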
Example of **model** subtree
```
"model":
{
"name": "Wan text2video FusioniX 14B",
"architecture" : "t2v",
"description": "A powerful merged text-to-video model based on the original WAN 2.1 T2V model, enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail. multiple open-source models and LoRAs to boost temporal quality, expressiveness, and motion realism.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors"
],
"model":
{
"name": "Wan text2video FusioniX 14B",
"architecture" : "t2v",
"description": "A powerful merged text-to-video model based on the original WAN 2.1 T2V model, enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail. multiple open-source models and LoRAs to boost temporal quality, expressiveness, and motion realism.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors"
],
"preload_URLs": [
],
"auto_quantize": true
},
"auto_quantize": true
},
```
## Finetune Model Naming Convention
If a model is not quantized, it is assumed to be mostly 16 bits (with maybe a few 32 bits weights), so *bf16* or *fp16* should appear somewhere in the name. If you need examples just look at the **ckpts** subfolder, the naming convention for the base models is the same.
If a model is quantized the term *quanto* should also be included since WanGP supports for the moment only *quanto* quantized models; more specifically you should replace *fp16* by *quanto_fp16_int8* or *bf16* by *quanto_bf16_int8*.
Please note it is important that *bf16*, *fp16* and *quanto* are all in lower case letters.
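For illustration, here is a hedged sketch of how a file name that follows this convention could be classified (purely illustrative, not WanGP's actual selection logic):
```
# Illustrative only: classify a checkpoint from its file name.
def classify_checkpoint(filename):
    name = filename.lower()
    return {
        "quantized": "quanto" in name,
        "dtype": "bf16" if "bf16" in name else "fp16" if "fp16" in name else "unknown",
    }

print(classify_checkpoint("Wan14BT2VFusioniX_quanto_bf16_int8.safetensors"))
# {'quantized': True, 'dtype': 'bf16'}
```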
@ -82,4 +105,4 @@ If you launch the app with the *--save-quantized* switch, WanGP will create a qu
You need to create a quantized model specifically for *bf16* or *fp16* as they can not be converted on the fly. However there is no need for a non quantized model as it can be converted on the fly while being loaded.
Wan models support both *fp16* and *bf16* data types, albeit *fp16* delivers in theory better quality. On the contrary, Hunyuan and LTXV support only *bf16*.
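For reference, the *--save-quantized* switch is the recommended way to produce such files; below is a minimal standalone sketch of what 8-bit *quanto* quantization looks like, assuming the optimum-quanto package is installed.
```
# Standalone sketch, assuming the optimum-quanto package is installed;
# WanGP's --save-quantized switch does this for you on a real model.
import torch
from torch import nn
from optimum.quanto import quantize, freeze, qint8

model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64)).to(torch.bfloat16)
quantize(model, weights=qint8)  # mark the Linear weights for int8 quantization
freeze(model)                   # materialize the quantized (int8) weights
# The resulting checkpoint would be saved under a name containing "quanto_bf16_int8".
```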

View File

@ -6,18 +6,19 @@ Loras (Low-Rank Adaptations) allow you to customize video generation models by a
Loras are organized in different folders based on the model they're designed for:
### Text-to-Video Models
### Wan Text-to-Video Models
- `loras/` - General t2v loras
- `loras/1.3B/` - Loras specifically for 1.3B models
- `loras/14B/` - Loras specifically for 14B models
### Image-to-Video Models
### Wan Image-to-Video Models
- `loras_i2v/` - Image-to-video loras
### Other Models
- `loras_hunyuan/` - Hunyuan Video t2v loras
- `loras_hunyuan_i2v/` - Hunyuan Video i2v loras
- `loras_ltxv/` - LTX Video loras
- `loras_flux/` - Flux loras
## Custom Lora Directory
@ -64,7 +65,7 @@ For dynamic effects over generation steps, use comma-separated values:
## Lora Presets
Presets are combinations of loras with predefined multipliers and prompts.
Lora Presets are combinations of loras with predefined multipliers and prompts.
### Creating Presets
1. Configure your loras and multipliers
@ -95,17 +96,37 @@ WanGP supports multiple lora formats:
- **Replicate** format
- **Standard PyTorch** (.pt, .pth)
## Safe-Forcing lightx2v Lora (Video Generation Accelerator)
Safeforcing Lora has been created by Kijai from the Safe-Forcing lightx2v distilled Wan model and can generate videos with only 2 steps and offers also a 2x speed improvement since it doesnt require classifier free guidance. It works on both t2v and i2v models
## Loras Accelerators
Most Loras are used to apply a specific style or to alter the content of the generated video.
However some Loras have been designed to transform a model into a distilled model, which requires fewer steps to generate a video.
You will find most *Loras Accelerators* here:
https://huggingface.co/DeepBeepMeep/Wan2.1/tree/main/loras_accelerators
### Setup Instructions
1. Download the Lora:
```
https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors
```
2. Place in your `loras/` directory
1. Download the Lora
2. Place it in your `loras/` directory if it is a t2v lora, or in the `loras_i2v/` directory if it is an i2v lora (see the download sketch below)
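A small sketch of fetching one of these accelerator Loras programmatically (using huggingface_hub is an assumption; downloading the file manually from the page above works just as well):
```
# Sketch only: download a Lora accelerator and copy it into the local loras folder.
import os, shutil
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="DeepBeepMeep/Wan2.1",
    filename="loras_accelerators/Wan2.1_I2V_14B_FusionX_LoRA.safetensors",
)
os.makedirs("loras_i2v", exist_ok=True)
shutil.copy(path, "loras_i2v/")  # use "loras/" instead for a t2v accelerator
```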
## FusioniX (or FusionX) Lora
If you need just one Lora accelerator, use this one. It is a combination of multiple Lora accelerators (including Causvid below) and style loras. It will not only accelerate the video generation but will also improve the quality. There are two versions of this lora depending on whether you use it for t2v or i2v.
### Usage
1. Select a Wan t2v model (e.g., Wan 2.1 text2video 14B or Vace 14B)
2. Enable Advanced Mode
3. In Advanced Generation Tab:
- Set Guidance Scale = 1
- Set Shift Scale = 2
4. In Advanced Lora Tab:
- Select the FusioniX Lora
- Set multiplier to 1
5. Set generation steps from 8-10
6. Generate!
## Safe-Forcing lightx2v Lora (Video Generation Accelerator)
The Safe-Forcing Lora has been created by Kijai from the Safe-Forcing lightx2v distilled Wan model. It can generate videos with only 2 steps and also offers a 2x speed improvement since it doesn't require classifier free guidance. It works on both t2v and i2v models.
You will find it under the name of *Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors*
### Usage
1. Select a Wan t2v or i2v model (e.g., Wan 2.1 text2video 14B or Vace 14B)
2. Enable Advanced Mode
@ -118,17 +139,10 @@ Safeforcing Lora has been created by Kijai from the Safe-Forcing lightx2v distil
5. Set generation steps to 2-8
6. Generate!
## CausVid Lora (Video Generation Accelerator)
CausVid is a distilled Wan model that generates videos in 4-12 steps with 2x speed improvement.
### Setup Instructions
1. Download the CausVid Lora:
```
https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors
```
2. Place in your `loras/` directory
### Usage
1. Select a Wan t2v model (e.g., Wan 2.1 text2video 14B or Vace 14B)
2. Enable Advanced Mode
@ -149,25 +163,10 @@ CausVid is a distilled Wan model that generates videos in 4-12 steps with 2x spe
*Note: Lower steps = lower quality (especially motion)*
## AccVid Lora (Video Generation Accelerator)
AccVid is a distilled Wan model that generates videos with a 2x speed improvement since classifier free guidance is no longer needed (that is cfg = 1).
### Setup Instructions
1. Download the AccVid Lora:
- for t2v models:
```
https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors
```
- for i2v models:
```
https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_AccVid_I2V_480P_14B_lora_rank32_fp16.safetensors
```
2. Place in your `loras/` directory or `loras_i2v/` directory
### Usage
1. Select a Wan t2v model (e.g., Wan 2.1 text2video 14B or Vace 14B) or a Wan i2v model
@ -268,6 +267,7 @@ In the video, a man is presented. The man is in a city and looks at his watch.
--lora-dir-hunyuan path # Path to Hunyuan t2v loras
--lora-dir-hunyuan-i2v path # Path to Hunyuan i2v loras
--lora-dir-ltxv path # Path to LTX Video loras
--lora-dir-flux path # Path to Flux loras
--lora-preset preset # Load preset on startup
--check-loras # Filter incompatible loras
```

View File

@ -2,6 +2,8 @@
WanGP supports multiple video generation models, each optimized for different use cases and hardware configurations.
Most models can be combined with Loras Accelerators (check the Lora guide) to accelerate video generation by 2x or 3x with little quality loss.
## Wan 2.1 Text2Video Models
Please note that the term *Text2Video* refers to the underlying Wan architecture; as it has been greatly improved over time, many derived Text2Video models can now generate videos using images.
@ -65,6 +67,12 @@ Please note that that the term *Text2Video* refers to the underlying Wan archite
## Wan 2.1 Specialized Models
#### Multitalk
- **Type**: Multi-person talking head animation
- **Input**: Voice track + image
- **Works on**: People
- **Use case**: Lip-sync and voice-driven animation for up to two people
#### FantasySpeaking
- **Type**: Talking head animation
- **Input**: Voice track + image
@ -82,7 +90,7 @@ Please note that that the term *Text2Video* refers to the underlying Wan archite
- **Requirements**: 81+ frame input videos, 15+ denoising steps
- **Use case**: View same scene from different angles
#### Sky Reels v2
#### Sky Reels v2 Diffusion
- **Type**: Diffusion Forcing model
- **Specialty**: "Infinite length" videos
- **Features**: High quality continuous generation
@ -107,22 +115,6 @@ Please note that that the term *Text2Video* refers to the underlying Wan archite
<BR>
## Wan Special Loras
### Safe-Forcing lightx2v Lora
- **Type**: Distilled model (Lora implementation)
- **Speed**: 4-8 steps generation, 2x faster (no classifier free guidance)
- **Compatible**: Works with t2v and i2v Wan 14B models
- **Setup**: Requires Safe-Forcing lightx2v Lora (see [LORAS.md](LORAS.md))
### Causvid Lora
- **Type**: Distilled model (Lora implementation)
- **Speed**: 4-12 steps generation, 2x faster (no classifier free guidance)
- **Compatible**: Works with Wan 14B models
- **Setup**: Requires CausVid Lora (see [LORAS.md](LORAS.md))
<BR>
## Hunyuan Video Models

View File

13
flux/__init__.py Normal file
View File

@ -0,0 +1,13 @@
try:
from ._version import (
version as __version__, # type: ignore
version_tuple,
)
except ImportError:
__version__ = "unknown (no version information available)"
version_tuple = (0, 0, "unknown", "noinfo")
from pathlib import Path
PACKAGE = __package__.replace("_", "-")
PACKAGE_ROOT = Path(__file__).parent

18
flux/__main__.py Normal file
View File

@ -0,0 +1,18 @@
from fire import Fire
from .cli import main as cli_main
from .cli_control import main as control_main
from .cli_fill import main as fill_main
from .cli_kontext import main as kontext_main
from .cli_redux import main as redux_main
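# Fire exposes each key of the dict below as a CLI sub-command,
# e.g. `python -m flux kontext --help` or `python -m flux t2i --help`.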
if __name__ == "__main__":
Fire(
{
"t2i": cli_main,
"control": control_main,
"fill": fill_main,
"kontext": kontext_main,
"redux": redux_main,
}
)

21
flux/_version.py Normal file
View File

@ -0,0 +1,21 @@
# file generated by setuptools-scm
# don't change, don't track in version control
__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple
from typing import Union
VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
VERSION_TUPLE = object
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
__version__ = version = '0.0.post58+g1371b2b'
__version_tuple__ = version_tuple = (0, 0, 'post58', 'g1371b2b')

109
flux/flux_main.py Normal file
View File

@ -0,0 +1,109 @@
import os
import re
import time
from dataclasses import dataclass
from glob import iglob
from mmgp import offload as offload
import torch
from wan.utils.utils import calculate_new_dimensions
from flux.sampling import denoise, get_schedule, prepare_kontext, unpack
from flux.modules.layers import get_linear_split_map
from flux.util import (
aspect_ratio_to_height_width,
load_ae,
load_clip,
load_flow_model,
load_t5,
save_image,
)
class model_factory:
def __init__(
self,
checkpoint_dir,
model_filename = None,
model_type = None,
base_model_type = None,
text_encoder_filename = None,
quantizeTransformer = False,
save_quantized = False,
dtype = torch.bfloat16,
VAE_dtype = torch.float32,
mixed_precision_transformer = False
):
self.device = torch.device(f"cuda")
self.VAE_dtype = VAE_dtype
self.dtype = dtype
torch_device = "cpu"
self.t5 = load_t5(torch_device, text_encoder_filename, max_length=512)
self.clip = load_clip(torch_device)
self.name= "flux-dev-kontext"
self.model = load_flow_model(self.name, model_filename[0], torch_device)
self.vae = load_ae(self.name, device=torch_device)
# offload.change_dtype(self.model, dtype, True)
if save_quantized:
from wgp import save_quantized_model
save_quantized_model(self.model, model_type, model_filename[0], dtype, None)
split_linear_modules_map = get_linear_split_map()
self.model.split_linear_modules_map = split_linear_modules_map
offload.split_linear_modules(self.model, split_linear_modules_map )
def generate(
self,
seed: int | None = None,
input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
sampling_steps: int = 20,
input_ref_images = None,
width= 832,
height=480,
guide_scale: float = 2.5,
fit_into_canvas = None,
callback = None,
loras_slists = None,
batch_size = 1,
**bbargs
):
if self._interrupt:
return None
device="cuda"
if input_ref_images is not None and len(input_ref_images) > 0:
image_ref = input_ref_images[0]
w, h = image_ref.size
height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
inp, height, width = prepare_kontext(
t5=self.t5,
clip=self.clip,
prompt=input_prompt,
ae=self.vae,
img_cond=image_ref,
target_width=width,
target_height=height,
bs=batch_size,
seed=seed,
device=device,
)
inp.pop("img_cond_orig")
timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
def unpack_latent(x):
return unpack(x.float(), height, width)
# denoise initial noise
x = denoise(self.model, **inp, timesteps=timesteps, guidance=guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent)
if x is None: return None
# decode latents to pixel space
x = unpack_latent(x)
with torch.autocast(device_type=device, dtype=torch.bfloat16):
x = self.vae.decode(x)
x = x.clamp(-1, 1)
x = x.transpose(0, 1)
return x

54
flux/math.py Normal file
View File

@ -0,0 +1,54 @@
import torch
from einops import rearrange
from torch import Tensor
from wan.modules.attention import pay_attention
def attention(qkv_list, pe: Tensor) -> Tensor:
q, k, v = qkv_list
qkv_list.clear()
q_list = [q]
q = None
q = apply_rope_(q_list, pe)
k_list = [k]
k = None
k = apply_rope_(k_list, pe)
qkv_list = [q.transpose(1,2), k.transpose(1,2) ,v.transpose(1,2)]
del q,k, v
x = pay_attention(qkv_list).transpose(1,2)
# x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")
return x
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()
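# In-place / memory-conscious RoPE variant: the query or key tensor is passed
# inside a single-element list so the caller's reference can be dropped
# (q_list.clear()) before the rotated copy is materialized.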
def apply_rope_(q_list, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
xq= q_list[0]
xqshape = xq.shape
xqdtype= xq.dtype
q_list.clear()
xq = xq.float().reshape(*xqshape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq[..., 0]
xq = freqs_cis[..., 1] * xq[..., 1]
xq_out.add_(xq)
# xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
return xq_out.reshape(*xqshape).to(xqdtype)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

168
flux/model.py Normal file
View File

@ -0,0 +1,168 @@
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from flux.modules.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
from flux.modules.lora import LinearLora, replace_linear_with_lora
@dataclass
class FluxParams:
in_channels: int
out_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list[int]
theta: int
qkv_bias: bool
guidance_embed: bool
class Flux(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, params: FluxParams):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = params.out_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(params.depth_single_blocks)
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
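# Remap LoRA state-dict keys from the diffusers "transformer.*" layout to this
# model's "diffusion_model.*" naming so externally trained LoRAs can be loaded.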
def preprocess_loras(self, model_type, sd):
new_sd = {}
if len(sd) == 0: return sd
first_key= next(iter(sd))
if first_key.startswith("transformer."):
src_list = [".attn.to_q.", ".attn.to_k.", ".attn.to_v."]
tgt_list = [".linear1_attn_q.", ".linear1_attn_k.", ".linear1_attn_v."]
for k,v in sd.items():
k = k.replace("transformer.single_transformer_blocks", "diffusion_model.single_blocks")
k = k.replace("transformer.double_transformer_blocks", "diffusion_model.double_blocks")
for src, tgt in zip(src_list, tgt_list):
k = k.replace(src, tgt)
new_sd[k] = v
return new_sd
def forward(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
callback= None,
pipeline =None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec += self.guidance_in(timestep_embedding(guidance, 256))
vec += self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
if callback != None:
callback(-1, None, False, True)
if pipeline._interrupt:
return None
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
class FluxLoraWrapper(Flux):
def __init__(
self,
lora_rank: int = 128,
lora_scale: float = 1.0,
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.lora_rank = lora_rank
replace_linear_with_lora(
self,
max_rank=lora_rank,
scale=lora_scale,
)
def set_lora_scale(self, scale: float) -> None:
for module in self.modules():
if isinstance(module, LinearLora):
module.set_scale(scale=scale)

320
flux/modules/autoencoder.py Normal file
View File

@ -0,0 +1,320 @@
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
@dataclass
class AutoEncoderParams:
resolution: int
in_channels: int
ch: int
out_ch: int
ch_mult: list[int]
num_res_blocks: int
z_channels: int
scale_factor: float
shift_factor: float
def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)
class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
def attention(self, h_: Tensor) -> Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
def forward(self, x: Tensor) -> Tensor:
return x + self.proj_out(self.attention(x))
class ResnetBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h = x
h = self.norm1(h)
h = swish(h)
h = self.conv1(h)
h = self.norm2(h)
h = swish(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
# no asymmetric padding in torch conv, must do it ourselves
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x: Tensor):
pad = (0, 1, 0, 1)
x = nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: Tensor):
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
block_in = self.ch
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: Tensor) -> Tensor:
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(
self,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z: Tensor) -> Tensor:
# get dtype for proper tracing
upscale_dtype = next(self.up.parameters()).dtype
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# cast to proper dtype
h = h.to(upscale_dtype)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
class DiagonalGaussian(nn.Module):
def __init__(self, sample: bool = True, chunk_dim: int = 1):
super().__init__()
self.sample = sample
self.chunk_dim = chunk_dim
def forward(self, z: Tensor) -> Tensor:
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
if self.sample:
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)
else:
return mean
class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams, sample_z: bool = False):
super().__init__()
self.params = params
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.decoder = Decoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
out_ch=params.out_ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.reg = DiagonalGaussian(sample=sample_z)
self.scale_factor = params.scale_factor
self.shift_factor = params.shift_factor
def get_VAE_tile_size(*args, **kwargs):
return []
def encode(self, x: Tensor) -> Tensor:
z = self.reg(self.encoder(x))
z = self.scale_factor * (z - self.shift_factor)
return z
def decode(self, z: Tensor) -> Tensor:
z = z / self.scale_factor + self.shift_factor
return self.decoder(z)
def forward(self, x: Tensor) -> Tensor:
return self.decode(self.encode(x))

View File

@ -0,0 +1,38 @@
from torch import Tensor, nn
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
import os
class HFEmbedder(nn.Module):
def __init__(self, version: str, text_encoder_filename, max_length: int, is_clip = False, **hf_kwargs):
super().__init__()
self.is_clip = is_clip
self.max_length = max_length
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
if is_clip:
self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
else:
from mmgp import offload as offloadobj
self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(os.path.dirname(text_encoder_filename), max_length=max_length)
self.hf_module: T5EncoderModel = offloadobj.fast_load_transformers_model(text_encoder_filename)
self.hf_module = self.hf_module.eval().requires_grad_(False)
def forward(self, text: list[str]) -> Tensor:
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=False,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
outputs = self.hf_module(
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
attention_mask=None,
output_hidden_states=False,
)
return outputs[self.output_key].bfloat16()
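# Usage sketch (illustration only, hedged): the CLIP branch pools a prompt into a
# single vector; the T5 branch is analogous but loads its weights from a local
# safetensors file through mmgp. The checkpoint name below is an assumption about
# what callers pass as `version`.
def _demo_clip_embedder():
    clip = HFEmbedder("openai/clip-vit-large-patch14", None, max_length=77, is_clip=True)
    vec = clip(["a cat sitting on a windowsill"])  # pooler_output, cast to bfloat16
    return vec.shape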

View File

@@ -0,0 +1,99 @@
import cv2
import numpy as np
import torch
from einops import rearrange, repeat
from PIL import Image
from safetensors.torch import load_file as load_sft
from torch import nn
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
from flux.util import print_load_warning
class DepthImageEncoder:
depth_model_name = "LiheYoung/depth-anything-large-hf"
def __init__(self, device):
self.device = device
self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
self.processor = AutoProcessor.from_pretrained(self.depth_model_name)
def __call__(self, img: torch.Tensor) -> torch.Tensor:
hw = img.shape[-2:]
img = torch.clamp(img, -1.0, 1.0)
img_byte = ((img + 1.0) * 127.5).byte()
img = self.processor(img_byte, return_tensors="pt")["pixel_values"]
depth = self.depth_model(img.to(self.device)).predicted_depth
depth = repeat(depth, "b h w -> b 3 h w")
depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True)
depth = depth / 127.5 - 1.0
return depth
class CannyImageEncoder:
def __init__(
self,
device,
min_t: int = 50,
max_t: int = 200,
):
self.device = device
self.min_t = min_t
self.max_t = max_t
def __call__(self, img: torch.Tensor) -> torch.Tensor:
assert img.shape[0] == 1, "Only batch size 1 is supported"
img = rearrange(img[0], "c h w -> h w c")
img = torch.clamp(img, -1.0, 1.0)
img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8)
# Apply Canny edge detection
canny = cv2.Canny(img_np, self.min_t, self.max_t)
# Convert back to torch tensor and reshape
canny = torch.from_numpy(canny).float() / 127.5 - 1.0
canny = rearrange(canny, "h w -> 1 1 h w")
canny = repeat(canny, "b 1 ... -> b 3 ...")
return canny.to(self.device)
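# Usage sketch (illustration only): CannyImageEncoder maps one image in [-1, 1]
# to a 3-channel edge map in the same value range.
def _demo_canny_encoder():
    enc = CannyImageEncoder(device="cpu", min_t=50, max_t=200)
    img = torch.rand(1, 3, 256, 256) * 2.0 - 1.0  # a single RGB image scaled to [-1, 1]
    edges = enc(img)                              # (1, 3, 256, 256) edge map
    return edges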
class ReduxImageEncoder(nn.Module):
siglip_model_name = "google/siglip-so400m-patch14-384"
def __init__(
self,
device,
redux_path: str,
redux_dim: int = 1152,
txt_in_features: int = 4096,
dtype=torch.bfloat16,
) -> None:
super().__init__()
self.redux_dim = redux_dim
self.device = device if isinstance(device, torch.device) else torch.device(device)
self.dtype = dtype
with self.device:
self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
sd = load_sft(redux_path, device=str(device))
missing, unexpected = self.load_state_dict(sd, strict=False, assign=True)
print_load_warning(missing, unexpected)
self.siglip = SiglipVisionModel.from_pretrained(self.siglip_model_name).to(dtype=dtype)
self.normalize = SiglipImageProcessor.from_pretrained(self.siglip_model_name)
def __call__(self, x: Image.Image) -> torch.Tensor:
imgs = self.normalize.preprocess(images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True)
_encoded_x = self.siglip(**imgs.to(device=self.device, dtype=self.dtype)).last_hidden_state
projected_x = self.redux_down(nn.functional.silu(self.redux_up(_encoded_x)))
return projected_x

327
flux/modules/layers copy.py Normal file
View File

@@ -0,0 +1,327 @@
import math
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
from flux.math import attention, rope
def get_linear_split_map():
hidden_size = 3072
split_linear_modules_map = {
"qkv" : {"mapped_modules" : ["q", "k", "v"] , "split_sizes": [hidden_size, hidden_size, hidden_size]},
"linear1" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"] , "split_sizes": [hidden_size, hidden_size, hidden_size, 7*hidden_size- 3*hidden_size]}
}
return split_linear_modules_map
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: Tensor) -> Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
t.device
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.scale = nn.Parameter(torch.ones(dim))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * self.scale
class QKNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.query_norm = RMSNorm(dim)
self.key_norm = RMSNorm(dim)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
x = self.proj(x)
return x
@dataclass
class ModulationOut:
shift: Tensor
scale: Tensor
gate: Tensor
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return (
ModulationOut(*out[:3]),
ModulationOut(*out[3:]) if self.is_double else None,
)
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated.mul_(1 + img_mod1.scale)
img_modulated.add_(img_mod1.shift)
shape = (*img_modulated.shape[:2], self.num_heads, int(img_modulated.shape[-1] / self.num_heads) )
img_q = self.img_attn.q(img_modulated).view(*shape).transpose(1,2)
img_k = self.img_attn.k(img_modulated).view(*shape).transpose(1,2)
img_v = self.img_attn.v(img_modulated).view(*shape).transpose(1,2)
del img_modulated
# img_qkv = self.img_attn.qkv(img_modulated)
# img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated.mul_(1 + txt_mod1.scale)
txt_modulated.add_(txt_mod1.shift)
# txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
shape = (*txt_modulated.shape[:2], self.num_heads, int(txt_modulated.shape[-1] / self.num_heads) )
txt_q = self.txt_attn.q(txt_modulated).view(*shape).transpose(1,2)
txt_k = self.txt_attn.k(txt_modulated).view(*shape).transpose(1,2)
txt_v = self.txt_attn.v(txt_modulated).view(*shape).transpose(1,2)
del txt_modulated
# txt_qkv = self.txt_attn.qkv(txt_modulated)
# txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
qkv_list = [q, k, v]
del q, k, v
attn = attention(qkv_list, pe=pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img blocks
img.addcmul_(self.img_attn.proj(img_attn), img_mod1.gate)
img.addcmul_(self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift), img_mod2.gate)
# img = img + img_mod1.gate * self.img_attn.proj(img_attn)
# img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
# calculate the txt blocks
txt.addcmul_(self.txt_attn.proj(txt_attn), txt_mod1.gate)
txt.addcmul_(self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift), txt_mod2.gate)
# txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
# txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float | None = None,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = self.pre_norm(x)
x_mod.mul_(1 + mod.scale)
x_mod.add_(mod.shift)
##### More spaghetti VRAM optimizations done by DeepBeepMeep!
# I am sure you are a nice person and as you copy this code, you will give me proper credits:
# Please link to https://github.com/deepbeepmeep/Wan2GP and @deepbeepmeep on twitter
# x_mod = (1 + mod.scale) * x + mod.shift
shape = (*x_mod.shape[:2], self.num_heads, int(x_mod.shape[-1] / self.num_heads) )
q = self.linear1_attn_q(x_mod).view(*shape).transpose(1,2)
k = self.linear1_attn_k(x_mod).view(*shape).transpose(1,2)
v = self.linear1_attn_v(x_mod).view(*shape).transpose(1,2)
# shape = (*txt_mod.shape[:2], self.heads_num, int(txt_mod.shape[-1] / self.heads_num) )
# txt_q = self.linear1_attn_q(txt_mod).view(*shape)
# txt_k = self.linear1_attn_k(txt_mod).view(*shape)
# txt_v = self.linear1_attn_v(txt_mod).view(*shape)
# qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
# compute attention
qkv_list = [q, k, v]
del q, k, v
attn = attention(qkv_list, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer
x_mod_shape = x_mod.shape
x_mod = x_mod.view(-1, x_mod.shape[-1])
chunk_size = int(x_mod_shape[1]/6)
x_chunks = torch.split(x_mod, chunk_size)
attn = attn.view(-1, attn.shape[-1])
attn_chunks = torch.split(attn, chunk_size)
for x_chunk, attn_chunk in zip(x_chunks, attn_chunks):
mlp_chunk = self.linear1_mlp(x_chunk)
mlp_chunk = self.mlp_act(mlp_chunk)
attn_mlp_chunk = torch.cat((attn_chunk, mlp_chunk), -1)
del attn_chunk, mlp_chunk
x_chunk[...] = self.linear2(attn_mlp_chunk)
del attn_mlp_chunk
x_mod = x_mod.view(x_mod_shape)
x.addcmul_(x_mod, mod.gate)
return x
# output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
# return x + mod.gate * output
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x

328
flux/modules/layers.py Normal file
View File

@@ -0,0 +1,328 @@
import math
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
from flux.math import attention, rope
def get_linear_split_map():
hidden_size = 3072
split_linear_modules_map = {
"qkv" : {"mapped_modules" : ["q", "k", "v"] , "split_sizes": [hidden_size, hidden_size, hidden_size]},
"linear1" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"] , "split_sizes": [hidden_size, hidden_size, hidden_size, 7*hidden_size- 3*hidden_size]}
}
return split_linear_modules_map
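# Illustration (not the project's actual splitting helper): the map above describes
# how a fused projection is carved into separate modules. For the fused qkv of a
# 3072-wide block, splitting the weight rows recovers the q/k/v weights:
def _demo_split_qkv():
    fused = nn.Linear(3072, 3 * 3072, bias=True)
    sizes = get_linear_split_map()["qkv"]["split_sizes"]      # [3072, 3072, 3072]
    w_q, w_k, w_v = torch.split(fused.weight, sizes, dim=0)   # one weight block per mapped module
    return w_q.shape, w_k.shape, w_v.shape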
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: Tensor) -> Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
t.device
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
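# Illustration: three fractional timesteps embedded into 256-dimensional sinusoidal
# features (time_factor rescales t by 1000 before the frequencies are applied).
def _demo_timestep_embedding():
    t = torch.tensor([0.25, 0.5, 1.0])
    emb = timestep_embedding(t, dim=256)  # shape (3, 256)
    return emb.shape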
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.scale = nn.Parameter(torch.ones(dim))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * self.scale
class QKNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.query_norm = RMSNorm(dim)
self.key_norm = RMSNorm(dim)
def forward(self, q: Tensor | None, k: Tensor | None, v: Tensor) -> Tensor:
if k is not None:
return self.key_norm(k).to(v)
else:
return self.query_norm(q).to(v)
# q = self.query_norm(q)
# k = self.key_norm(k)
# return q.to(v), k.to(v)
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
raise Exception("not implemented")
@dataclass
class ModulationOut:
shift: Tensor
scale: Tensor
gate: Tensor
def split_mlp(mlp, x, divide = 4):
x_shape = x.shape
x = x.view(-1, x.shape[-1])
chunk_size = int(x_shape[1]/divide)
x_chunks = torch.split(x, chunk_size)
for i, x_chunk in enumerate(x_chunks):
mlp_chunk = mlp[0](x_chunk)
mlp_chunk = mlp[1](mlp_chunk)
x_chunk[...] = mlp[2](mlp_chunk)
return x.reshape(x_shape)
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return (
ModulationOut(*out[:3]),
ModulationOut(*out[3:]) if self.is_double else None,
)
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated.mul_(1 + img_mod1.scale)
img_modulated.add_(img_mod1.shift)
shape = (*img_modulated.shape[:2], self.num_heads, int(img_modulated.shape[-1] / self.num_heads) )
img_q = self.img_attn.q(img_modulated).view(*shape).transpose(1,2)
img_k = self.img_attn.k(img_modulated).view(*shape).transpose(1,2)
img_v = self.img_attn.v(img_modulated).view(*shape).transpose(1,2)
del img_modulated
img_q= self.img_attn.norm(img_q, None, img_v)
img_k = self.img_attn.norm(None, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated.mul_(1 + txt_mod1.scale)
txt_modulated.add_(txt_mod1.shift)
shape = (*txt_modulated.shape[:2], self.num_heads, int(txt_modulated.shape[-1] / self.num_heads) )
txt_q = self.txt_attn.q(txt_modulated).view(*shape).transpose(1,2)
txt_k = self.txt_attn.k(txt_modulated).view(*shape).transpose(1,2)
txt_v = self.txt_attn.v(txt_modulated).view(*shape).transpose(1,2)
del txt_modulated
txt_q = self.txt_attn.norm(txt_q, None, txt_v)
txt_k = self.txt_attn.norm(None, txt_k, txt_v)
# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
qkv_list = [q, k, v]
del q, k, v
attn = attention(qkv_list, pe=pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img blocks
img.addcmul_(self.img_attn.proj(img_attn), img_mod1.gate)
mod_img = self.img_norm2(img)
mod_img.mul_(1 + img_mod2.scale)
mod_img.add_(img_mod2.shift)
mod_img = split_mlp(self.img_mlp, mod_img)
# mod_img = self.img_mlp(mod_img)
img.addcmul_( mod_img, img_mod2.gate)
mod_img = None
# calculate the txt blocks
txt.addcmul_(self.txt_attn.proj(txt_attn), txt_mod1.gate)
txt.addcmul_(self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift), txt_mod2.gate)
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float | None = None,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = self.pre_norm(x)
x_mod.mul_(1 + mod.scale)
x_mod.add_(mod.shift)
##### More spaghetti VRAM optimizations done by DeepBeepMeep!
# I am sure you are a nice person and as you copy this code, you will give me proper credits:
# Please link to https://github.com/deepbeepmeep/Wan2GP and @deepbeepmeep on twitter
# x_mod = (1 + mod.scale) * x + mod.shift
shape = (*x_mod.shape[:2], self.num_heads, int(x_mod.shape[-1] / self.num_heads) )
q = self.linear1_attn_q(x_mod).view(*shape).transpose(1,2)
k = self.linear1_attn_k(x_mod).view(*shape).transpose(1,2)
v = self.linear1_attn_v(x_mod).view(*shape).transpose(1,2)
q = self.norm(q, None, v)
k = self.norm(None, k, v)
# compute attention
qkv_list = [q, k, v]
del q, k, v
attn = attention(qkv_list, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer
x_mod_shape = x_mod.shape
x_mod = x_mod.view(-1, x_mod.shape[-1])
chunk_size = int(x_mod_shape[1]/6)
x_chunks = torch.split(x_mod, chunk_size)
attn = attn.view(-1, attn.shape[-1])
attn_chunks = torch.split(attn, chunk_size)
for x_chunk, attn_chunk in zip(x_chunks, attn_chunks):
mlp_chunk = self.linear1_mlp(x_chunk)
mlp_chunk = self.mlp_act(mlp_chunk)
attn_mlp_chunk = torch.cat((attn_chunk, mlp_chunk), -1)
del attn_chunk, mlp_chunk
x_chunk[...] = self.linear2(attn_mlp_chunk)
del attn_mlp_chunk
x_mod = x_mod.view(x_mod_shape)
x.addcmul_(x_mod, mod.gate)
return x
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x

94
flux/modules/lora.py Normal file
View File

@@ -0,0 +1,94 @@
import torch
from torch import nn
def replace_linear_with_lora(
module: nn.Module,
max_rank: int,
scale: float = 1.0,
) -> None:
for name, child in module.named_children():
if isinstance(child, nn.Linear):
new_lora = LinearLora(
in_features=child.in_features,
out_features=child.out_features,
bias=child.bias,
rank=max_rank,
scale=scale,
dtype=child.weight.dtype,
device=child.weight.device,
)
new_lora.weight = child.weight
new_lora.bias = child.bias if child.bias is not None else None
setattr(module, name, new_lora)
else:
replace_linear_with_lora(
module=child,
max_rank=max_rank,
scale=scale,
)
class LinearLora(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
rank: int,
dtype: torch.dtype,
device: torch.device,
lora_bias: bool = True,
scale: float = 1.0,
*args,
**kwargs,
) -> None:
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias is not None,
device=device,
dtype=dtype,
*args,
**kwargs,
)
assert isinstance(scale, float), "scale must be a float"
self.scale = scale
self.rank = rank
self.lora_bias = lora_bias
self.dtype = dtype
self.device = device
if rank > (new_rank := min(self.out_features, self.in_features)):
self.rank = new_rank
self.lora_A = nn.Linear(
in_features=in_features,
out_features=self.rank,
bias=False,
dtype=dtype,
device=device,
)
self.lora_B = nn.Linear(
in_features=self.rank,
out_features=out_features,
bias=self.lora_bias,
dtype=dtype,
device=device,
)
def set_scale(self, scale: float) -> None:
assert isinstance(scale, float), "scalar value must be a float"
self.scale = scale
def forward(self, input: torch.Tensor) -> torch.Tensor:
base_out = super().forward(input)
_lora_out_B = self.lora_B(self.lora_A(input))
lora_update = _lora_out_B * self.scale
return base_out + lora_update
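# Usage sketch (illustration only): swap every nn.Linear in a small model for a
# LinearLora of rank 8. Note that lora_A/lora_B keep the default nn.Linear
# initialisation here, so the LoRA branch contributes a non-zero update immediately.
if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64))
    replace_linear_with_lora(model, max_rank=8, scale=0.5)
    y = model(torch.randn(2, 64))  # base output plus 0.5 * lora_B(lora_A(x)) per layer
    print(y.shape)                 # torch.Size([2, 64])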

392
flux/sampling.py Normal file
View File

@@ -0,0 +1,392 @@
import math
from typing import Callable
import numpy as np
import torch
from einops import rearrange, repeat
from PIL import Image
from torch import Tensor
from .model import Flux
from .modules.autoencoder import AutoEncoder
from .modules.conditioner import HFEmbedder
from .modules.image_embedders import CannyImageEncoder, DepthImageEncoder, ReduxImageEncoder
from .util import PREFERED_KONTEXT_RESOLUTIONS
from einops import rearrange, repeat
def get_noise(
num_samples: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
):
return torch.randn(
num_samples,
16,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
dtype=dtype,
generator=torch.Generator(device=device).manual_seed(seed),
)
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
bs, c, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if isinstance(prompt, str):
prompt = [prompt]
txt = t5(prompt)
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = torch.zeros(bs, txt.shape[1], 3)
vec = clip(prompt)
if vec.shape[0] == 1 and bs > 1:
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"txt": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
"vec": vec.to(img.device),
}
def prepare_control(
t5: HFEmbedder,
clip: HFEmbedder,
img: Tensor,
prompt: str | list[str],
ae: AutoEncoder,
encoder: DepthImageEncoder | CannyImageEncoder,
img_cond_path: str,
) -> dict[str, Tensor]:
# load and encode the conditioning image
bs, _, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img_cond = Image.open(img_cond_path).convert("RGB")
width = w * 8
height = h * 8
img_cond = img_cond.resize((width, height), Image.Resampling.LANCZOS)
img_cond = np.array(img_cond)
img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
img_cond = rearrange(img_cond, "h w c -> 1 c h w")
with torch.no_grad():
img_cond = encoder(img_cond)
img_cond = ae.encode(img_cond)
img_cond = img_cond.to(torch.bfloat16)
img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img_cond.shape[0] == 1 and bs > 1:
img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
return_dict = prepare(t5, clip, img, prompt)
return_dict["img_cond"] = img_cond
return return_dict
def prepare_fill(
t5: HFEmbedder,
clip: HFEmbedder,
img: Tensor,
prompt: str | list[str],
ae: AutoEncoder,
img_cond_path: str,
mask_path: str,
) -> dict[str, Tensor]:
# load and encode the conditioning image and the mask
bs, _, _, _ = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img_cond = Image.open(img_cond_path).convert("RGB")
img_cond = np.array(img_cond)
img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
img_cond = rearrange(img_cond, "h w c -> 1 c h w")
mask = Image.open(mask_path).convert("L")
mask = np.array(mask)
mask = torch.from_numpy(mask).float() / 255.0
mask = rearrange(mask, "h w -> 1 1 h w")
with torch.no_grad():
img_cond = img_cond.to(img.device)
mask = mask.to(img.device)
img_cond = img_cond * (1 - mask)
img_cond = ae.encode(img_cond)
mask = mask[:, 0, :, :]
mask = mask.to(torch.bfloat16)
mask = rearrange(
mask,
"b (h ph) (w pw) -> b (ph pw) h w",
ph=8,
pw=8,
)
mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if mask.shape[0] == 1 and bs > 1:
mask = repeat(mask, "1 ... -> bs ...", bs=bs)
img_cond = img_cond.to(torch.bfloat16)
img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img_cond.shape[0] == 1 and bs > 1:
img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
img_cond = torch.cat((img_cond, mask), dim=-1)
return_dict = prepare(t5, clip, img, prompt)
return_dict["img_cond"] = img_cond.to(img.device)
return return_dict
def prepare_redux(
t5: HFEmbedder,
clip: HFEmbedder,
img: Tensor,
prompt: str | list[str],
encoder: ReduxImageEncoder,
img_cond_path: str,
) -> dict[str, Tensor]:
bs, _, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img_cond = Image.open(img_cond_path).convert("RGB")
with torch.no_grad():
img_cond = encoder(img_cond)
img_cond = img_cond.to(torch.bfloat16)
if img_cond.shape[0] == 1 and bs > 1:
img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if isinstance(prompt, str):
prompt = [prompt]
txt = t5(prompt)
txt = torch.cat((txt, img_cond.to(txt)), dim=-2)
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = torch.zeros(bs, txt.shape[1], 3)
vec = clip(prompt)
if vec.shape[0] == 1 and bs > 1:
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"txt": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
"vec": vec.to(img.device),
}
def prepare_kontext(
t5: HFEmbedder,
clip: HFEmbedder,
prompt: str | list[str],
ae: AutoEncoder,
img_cond: str,
seed: int,
device: torch.device,
target_width: int | None = None,
target_height: int | None = None,
bs: int = 1,
) -> tuple[dict[str, Tensor], int, int]:
# load and encode the conditioning image
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
width, height = img_cond.size
aspect_ratio = width / height
# Kontext is trained on specific resolutions, using one of them is recommended
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
width = 2 * int(width / 16)
height = 2 * int(height / 16)
img_cond = img_cond.resize((8 * width, 8 * height), Image.Resampling.LANCZOS)
img_cond = np.array(img_cond)
img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
img_cond = rearrange(img_cond, "h w c -> 1 c h w")
img_cond_orig = img_cond.clone()
with torch.no_grad():
img_cond = ae.encode(img_cond.to(device))
img_cond = img_cond.to(torch.bfloat16)
img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img_cond.shape[0] == 1 and bs > 1:
img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
# image ids are the same as base image with the first dimension set to 1
# instead of 0
img_cond_ids = torch.zeros(height // 2, width // 2, 3)
img_cond_ids[..., 0] = 1
img_cond_ids[..., 1] = img_cond_ids[..., 1] + torch.arange(height // 2)[:, None]
img_cond_ids[..., 2] = img_cond_ids[..., 2] + torch.arange(width // 2)[None, :]
img_cond_ids = repeat(img_cond_ids, "h w c -> b (h w) c", b=bs)
if target_width is None:
target_width = 8 * width
if target_height is None:
target_height = 8 * height
img = get_noise(
bs,
target_height,
target_width,
device=device,
dtype=torch.bfloat16,
seed=seed,
)
return_dict = prepare(t5, clip, img, prompt)
return_dict["img_cond_seq"] = img_cond
return_dict["img_cond_seq_ids"] = img_cond_ids.to(device)
return_dict["img_cond_orig"] = img_cond_orig
return return_dict, target_height, target_width
def time_shift(mu: float, sigma: float, t: Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# estimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
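# Illustration: a 4-step schedule for a 1024-token packed image. With shift=True the
# intermediate timesteps are pushed towards 1.0, spending more of the budget at high
# noise levels; denoise() consumes consecutive (t_curr, t_prev) pairs from this list.
def _demo_schedule():
    ts = get_schedule(num_steps=4, image_seq_len=1024)  # 5 values from 1.0 down to 0.0
    return ts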
def denoise(
model: Flux,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
# sampling parameters
timesteps: list[float],
guidance: float = 4.0,
# extra img tokens (channel-wise)
img_cond: Tensor | None = None,
# extra img tokens (sequence-wise)
img_cond_seq: Tensor | None = None,
img_cond_seq_ids: Tensor | None = None,
callback=None,
pipeline=None,
loras_slists=None,
unpack_latent = None,
):
kwargs = {'pipeline': pipeline, 'callback': callback}
if callback is not None:
callback(-1, None, True)
updated_num_steps = len(timesteps) - 1
if callback is not None:
from wgp import update_loras_slists
update_loras_slists(model, loras_slists, updated_num_steps)
callback(-1, None, True, override_num_inference_steps=updated_num_steps)
from mmgp import offload
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
offload.set_step_no_for_lora(model, i)
if pipeline._interrupt:
return None
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
img_input = img
img_input_ids = img_ids
if img_cond is not None:
img_input = torch.cat((img, img_cond), dim=-1)
if img_cond_seq is not None:
assert (
img_cond_seq_ids is not None
), "You need to provide either both or neither of the sequence conditioning"
img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
pred = model(
img=img_input,
img_ids=img_input_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
**kwargs
)
if pred is None: return None
if img_input_ids is not None:
pred = pred[:, : img.shape[1]]
img += (t_prev - t_curr) * pred
if callback is not None:
preview = unpack_latent(img).transpose(0,1)
callback(i, preview, False)
return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)
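# Illustration of the packing convention that unpack() reverses: 2x2 latent patches
# are flattened into 64-wide tokens and restored to (b, c, h, w) here.
def _demo_pack_unpack():
    lat = torch.randn(1, 16, 64, 64)  # latents for a 512x512 image (the VAE downsamples by 8)
    packed = rearrange(lat, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    restored = unpack(packed, height=512, width=512)
    assert restored.shape == lat.shape
    return packed.shape               # torch.Size([1, 1024, 64])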

302
flux/to_remove/cli.py Normal file
View File

@@ -0,0 +1,302 @@
import os
import re
import time
from dataclasses import dataclass
from glob import iglob
import torch
from fire import Fire
from transformers import pipeline
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import (
check_onnx_access_for_trt,
configs,
load_ae,
load_clip,
load_flow_model,
load_t5,
save_image,
)
NSFW_THRESHOLD = 0.85
@dataclass
class SamplingOptions:
prompt: str
width: int
height: int
num_steps: int
guidance: float
seed: int | None
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write your prompt directly, leave this field empty "
"to repeat the prompt or write a command starting with a slash:\n"
"- '/w <width>' will set the width of the generated image\n"
"- '/h <height>' will set the height of the generated image\n"
"- '/s <seed>' sets the next seed\n"
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
"- '/n <steps>' sets the number of steps\n"
"- '/q' to quit"
)
while (prompt := input(user_question)).startswith("/"):
if prompt.startswith("/w"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, width = prompt.split()
options.width = 16 * (int(width) // 16)
print(
f"Setting resolution to {options.width} x {options.height} "
f"({options.height * options.width / 1e6:.2f}MP)"
)
elif prompt.startswith("/h"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, height = prompt.split()
options.height = 16 * (int(height) // 16)
print(
f"Setting resolution to {options.width} x {options.height} "
f"({options.height * options.width / 1e6:.2f}MP)"
)
elif prompt.startswith("/g"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, guidance = prompt.split()
options.guidance = float(guidance)
print(f"Setting guidance to {options.guidance}")
elif prompt.startswith("/s"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, seed = prompt.split()
options.seed = int(seed)
print(f"Setting seed to {options.seed}")
elif prompt.startswith("/n"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, steps = prompt.split()
options.num_steps = int(steps)
print(f"Setting number of steps to {options.num_steps}")
elif prompt.startswith("/q"):
print("Quitting")
return None
else:
if not prompt.startswith("/h"):
print(f"Got invalid command '{prompt}'\n{usage}")
print(usage)
if prompt != "":
options.prompt = prompt
return options
@torch.inference_mode()
def main(
name: str = "flux-schnell",
width: int = 1360,
height: int = 768,
seed: int | None = None,
prompt: str = (
"a photo of a forest with mist swirling around the tree trunks. The word "
'"FLUX" is painted over it in big, red brush strokes with visible texture'
),
device: str = "cuda" if torch.cuda.is_available() else "cpu",
num_steps: int | None = None,
loop: bool = False,
guidance: float = 2.5,
offload: bool = False,
output_dir: str = "output",
add_sampling_metadata: bool = True,
trt: bool = False,
trt_transformer_precision: str = "bf16",
track_usage: bool = False,
):
"""
Sample the flux model, either interactively (set `--loop`) or for a single image.
Args:
name: Name of the model to load
height: height of the sample in pixels (should be a multiple of 16)
width: width of the sample in pixels (should be a multiple of 16)
seed: Set a seed for sampling
output_dir: directory where the output images are saved as img_{idx}.jpg; `{idx}` will be replaced
by the index of the sample
prompt: Prompt used for sampling
device: Pytorch device
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
loop: start an interactive session and sample multiple times
guidance: guidance value used for guidance distillation
add_sampling_metadata: Add the prompt to the image Exif metadata
trt: use TensorRT backend for optimized inference
trt_transformer_precision: specify transformer precision for inference
track_usage: track usage of the model for licensing purposes
"""
prompt = prompt.split("|")
if len(prompt) == 1:
prompt = prompt[0]
additional_prompts = None
else:
additional_prompts = prompt[1:]
prompt = prompt[0]
assert not (
(additional_prompts is not None) and loop
), "Do not provide additional prompts and set loop to True"
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
if name not in configs:
available = ", ".join(configs.keys())
raise ValueError(f"Got unknown model name: {name}, chose from {available}")
torch_device = torch.device(device)
if num_steps is None:
num_steps = 4 if name == "flux-schnell" else 50
# allow for packing and conversion to latent space
height = 16 * (height // 16)
width = 16 * (width // 16)
output_name = os.path.join(output_dir, "img_{idx}.jpg")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
idx = 0
else:
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
if len(fns) > 0:
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
else:
idx = 0
if not trt:
t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
clip = load_clip(torch_device)
model = load_flow_model(name, device="cpu" if offload else torch_device)
ae = load_ae(name, device="cpu" if offload else torch_device)
else:
# lazy import to make install optional
from flux.trt.trt_manager import ModuleName, TRTManager
# Check if we need ONNX model access (which requires authentication for FLUX models)
onnx_dir = check_onnx_access_for_trt(name, trt_transformer_precision)
trt_ctx_manager = TRTManager(
trt_transformer_precision=trt_transformer_precision,
trt_t5_precision=os.getenv("TRT_T5_PRECISION", "bf16"),
)
engines = trt_ctx_manager.load_engines(
model_name=name,
module_names={
ModuleName.CLIP,
ModuleName.TRANSFORMER,
ModuleName.T5,
ModuleName.VAE,
},
engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
custom_onnx_paths=onnx_dir or os.environ.get("CUSTOM_ONNX_PATHS", ""),
trt_image_height=height,
trt_image_width=width,
trt_batch_size=1,
trt_timing_cache=os.getenv("TRT_TIMING_CACHE_FILE", None),
trt_static_batch=False,
trt_static_shape=False,
)
ae = engines[ModuleName.VAE].to(device="cpu" if offload else torch_device)
model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device)
clip = engines[ModuleName.CLIP].to(torch_device)
t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device)
rng = torch.Generator(device="cpu")
opts = SamplingOptions(
prompt=prompt,
width=width,
height=height,
num_steps=num_steps,
guidance=guidance,
seed=seed,
)
if loop:
opts = parse_prompt(opts)
while opts is not None:
if opts.seed is None:
opts.seed = rng.seed()
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
t0 = time.perf_counter()
# prepare input
x = get_noise(
1,
opts.height,
opts.width,
device=torch_device,
dtype=torch.bfloat16,
seed=opts.seed,
)
opts.seed = None
if offload:
ae = ae.cpu()
torch.cuda.empty_cache()
t5, clip = t5.to(torch_device), clip.to(torch_device)
inp = prepare(t5, clip, x, prompt=opts.prompt)
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
# offload TEs to CPU, load model to gpu
if offload:
t5, clip = t5.cpu(), clip.cpu()
torch.cuda.empty_cache()
model = model.to(torch_device)
# denoise initial noise
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
# offload model, load autoencoder to gpu
if offload:
model.cpu()
torch.cuda.empty_cache()
ae.decoder.to(x.device)
# decode latents to pixel space
x = unpack(x.float(), opts.height, opts.width)
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
x = ae.decode(x)
if torch.cuda.is_available():
torch.cuda.synchronize()
t1 = time.perf_counter()
fn = output_name.format(idx=idx)
print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
idx = save_image(
nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage
)
if loop:
print("-" * 80)
opts = parse_prompt(opts)
elif additional_prompts:
next_prompt = additional_prompts.pop(0)
opts.prompt = next_prompt
else:
opts = None
if trt:
trt_ctx_manager.stop_runtime()
if __name__ == "__main__":
Fire(main)
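# Usage note (illustration only): Fire exposes main()'s keyword arguments as CLI
# flags, so a typical invocation of this script would look like
#   python cli.py --name flux-schnell --width 1360 --height 768 --prompt "a fox" --loop
# (the script path and flag values are placeholders).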

View File

@@ -0,0 +1,390 @@
import os
import re
import time
from dataclasses import dataclass
from glob import iglob
import torch
from fire import Fire
from transformers import pipeline
from flux.modules.image_embedders import CannyImageEncoder, DepthImageEncoder
from flux.sampling import denoise, get_noise, get_schedule, prepare_control, unpack
from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
@dataclass
class SamplingOptions:
prompt: str
width: int
height: int
num_steps: int
guidance: float
seed: int | None
img_cond_path: str
lora_scale: float | None
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write your prompt directly, leave this field empty "
"to repeat the prompt or write a command starting with a slash:\n"
"- '/w <width>' will set the width of the generated image\n"
"- '/h <height>' will set the height of the generated image\n"
"- '/s <seed>' sets the next seed\n"
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
"- '/n <steps>' sets the number of steps\n"
"- '/q' to quit"
)
while (prompt := input(user_question)).startswith("/"):
if prompt.startswith("/w"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, width = prompt.split()
options.width = 16 * (int(width) // 16)
print(
f"Setting resolution to {options.width} x {options.height} "
f"({options.height * options.width / 1e6:.2f}MP)"
)
elif prompt.startswith("/h"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, height = prompt.split()
options.height = 16 * (int(height) // 16)
print(
f"Setting resolution to {options.width} x {options.height} "
f"({options.height * options.width / 1e6:.2f}MP)"
)
elif prompt.startswith("/g"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, guidance = prompt.split()
options.guidance = float(guidance)
print(f"Setting guidance to {options.guidance}")
elif prompt.startswith("/s"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, seed = prompt.split()
options.seed = int(seed)
print(f"Setting seed to {options.seed}")
elif prompt.startswith("/n"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, steps = prompt.split()
options.num_steps = int(steps)
print(f"Setting number of steps to {options.num_steps}")
elif prompt.startswith("/q"):
print("Quitting")
return None
else:
if not prompt.startswith("/h"):
print(f"Got invalid command '{prompt}'\n{usage}")
print(usage)
if prompt != "":
options.prompt = prompt
return options
def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
if options is None:
return None
user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write your prompt directly, leave this field empty "
"to repeat the conditioning image or write a command starting with a slash:\n"
"- '/q' to quit"
)
while True:
img_cond_path = input(user_question)
if img_cond_path.startswith("/"):
if img_cond_path.startswith("/q"):
print("Quitting")
return None
else:
if not img_cond_path.startswith("/h"):
print(f"Got invalid command '{img_cond_path}'\n{usage}")
print(usage)
continue
if img_cond_path == "":
break
if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
(".jpg", ".jpeg", ".png", ".webp")
):
print(f"File '{img_cond_path}' does not exist or is not a valid image file")
continue
options.img_cond_path = img_cond_path
break
return options
def parse_lora_scale(options: SamplingOptions | None) -> tuple[SamplingOptions | None, bool]:
changed = False
if options is None:
return None, changed
user_question = "Next lora scale (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write your prompt directly, leave this field empty "
"to repeat the lora scale or write a command starting with a slash:\n"
"- '/q' to quit"
)
while (prompt := input(user_question)).startswith("/"):
if prompt.startswith("/q"):
print("Quitting")
return None, changed
else:
if not prompt.startswith("/h"):
print(f"Got invalid command '{prompt}'\n{usage}")
print(usage)
if prompt != "":
options.lora_scale = float(prompt)
changed = True
return options, changed
@torch.inference_mode()
def main(
name: str,
width: int = 1024,
height: int = 1024,
seed: int | None = None,
prompt: str = "a robot made out of gold",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
num_steps: int = 50,
loop: bool = False,
guidance: float | None = None,
offload: bool = False,
output_dir: str = "output",
add_sampling_metadata: bool = True,
img_cond_path: str = "assets/robot.webp",
lora_scale: float | None = 0.85,
trt: bool = False,
trt_transformer_precision: str = "bf16",
track_usage: bool = False,
**kwargs: dict | None,
):
"""
Sample the flux model, either interactively (set `--loop`) or for a single image.
Args:
height: height of the sample in pixels (should be a multiple of 16)
width: width of the sample in pixels (should be a multiple of 16)
seed: Set a seed for sampling
output_dir: directory where the output images are saved as img_{idx}.jpg; `{idx}` will be replaced
by the index of the sample
prompt: Prompt used for sampling
device: Pytorch device
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
loop: start an interactive session and sample multiple times
guidance: guidance value used for guidance distillation
add_sampling_metadata: Add the prompt to the image Exif metadata
img_cond_path: path to conditioning image (jpeg/png/webp)
trt: use TensorRT backend for optimized inference
trt_transformer_precision: specify transformer precision for inference
track_usage: track usage of the model for licensing purposes
"""
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
if "lora" in name:
assert not trt, "TRT does not support LORA"
assert name in [
"flux-dev-canny",
"flux-dev-depth",
"flux-dev-canny-lora",
"flux-dev-depth-lora",
], f"Got unknown model name: {name}"
if guidance is None:
if name in ["flux-dev-canny", "flux-dev-canny-lora"]:
guidance = 30.0
elif name in ["flux-dev-depth", "flux-dev-depth-lora"]:
guidance = 10.0
else:
raise NotImplementedError()
if name not in configs:
available = ", ".join(configs.keys())
raise ValueError(f"Got unknown model name: {name}, chose from {available}")
torch_device = torch.device(device)
output_name = os.path.join(output_dir, "img_{idx}.jpg")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
idx = 0
else:
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
if len(fns) > 0:
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
else:
idx = 0
if name in ["flux-dev-depth", "flux-dev-depth-lora"]:
img_embedder = DepthImageEncoder(torch_device)
elif name in ["flux-dev-canny", "flux-dev-canny-lora"]:
img_embedder = CannyImageEncoder(torch_device)
else:
raise NotImplementedError()
if not trt:
# init all components
t5 = load_t5(torch_device, max_length=512)
clip = load_clip(torch_device)
model = load_flow_model(name, device="cpu" if offload else torch_device)
ae = load_ae(name, device="cpu" if offload else torch_device)
else:
# lazy import to make install optional
from flux.trt.trt_manager import ModuleName, TRTManager
trt_ctx_manager = TRTManager(
trt_transformer_precision=trt_transformer_precision,
trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"),
)
engines = trt_ctx_manager.load_engines(
model_name=name,
module_names={
ModuleName.CLIP,
ModuleName.TRANSFORMER,
ModuleName.T5,
ModuleName.VAE,
ModuleName.VAE_ENCODER,
},
engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
custom_onnx_paths=os.environ.get("CUSTOM_ONNX_PATHS", ""),
trt_image_height=height,
trt_image_width=width,
trt_batch_size=1,
trt_static_batch=kwargs.get("static_batch", True),
trt_static_shape=kwargs.get("static_shape", True),
)
ae = engines[ModuleName.VAE].to(device="cpu" if offload else torch_device)
model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device)
clip = engines[ModuleName.CLIP].to(torch_device)
t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device)
# set lora scale
if "lora" in name and lora_scale is not None:
for _, module in model.named_modules():
if hasattr(module, "set_scale"):
module.set_scale(lora_scale)
rng = torch.Generator(device="cpu")
opts = SamplingOptions(
prompt=prompt,
width=width,
height=height,
num_steps=num_steps,
guidance=guidance,
seed=seed,
img_cond_path=img_cond_path,
lora_scale=lora_scale,
)
if loop:
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
if "lora" in name:
opts, changed = parse_lora_scale(opts)
if changed:
# update the lora scale:
for _, module in model.named_modules():
if hasattr(module, "set_scale"):
module.set_scale(opts.lora_scale)
while opts is not None:
if opts.seed is None:
opts.seed = rng.seed()
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
t0 = time.perf_counter()
# prepare input
x = get_noise(
1,
opts.height,
opts.width,
device=torch_device,
dtype=torch.bfloat16,
seed=opts.seed,
)
opts.seed = None
if offload:
t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
inp = prepare_control(
t5,
clip,
x,
prompt=opts.prompt,
ae=ae,
encoder=img_embedder,
img_cond_path=opts.img_cond_path,
)
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
# offload TEs and AE to CPU, load model to gpu
if offload:
t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
torch.cuda.empty_cache()
model = model.to(torch_device)
# denoise initial noise
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
# offload model, load autoencoder to gpu
if offload:
model.cpu()
torch.cuda.empty_cache()
ae.decoder.to(x.device)
# decode latents to pixel space
x = unpack(x.float(), opts.height, opts.width)
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
x = ae.decode(x)
if torch.cuda.is_available():
torch.cuda.synchronize()
t1 = time.perf_counter()
print(f"Done in {t1 - t0:.1f}s")
idx = save_image(
nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage
)
if loop:
print("-" * 80)
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
if "lora" in name:
opts, changed = parse_lora_scale(opts)
if changed:
# update the lora scale:
for _, module in model.named_modules():
if hasattr(module, "set_scale"):
module.set_scale(opts.lora_scale)
else:
opts = None
if trt:
trt_ctx_manager.stop_runtime()
if __name__ == "__main__":
Fire(main)

334
flux/to_remove/cli_fill.py Normal file
View File

@@ -0,0 +1,334 @@
import os
import re
import time
from dataclasses import dataclass
from glob import iglob
import torch
from fire import Fire
from PIL import Image
from transformers import pipeline
from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack
from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
@dataclass
class SamplingOptions:
prompt: str
width: int
height: int
num_steps: int
guidance: float
seed: int | None
img_cond_path: str
img_mask_path: str
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write your prompt directly, leave this field empty "
"to repeat the prompt or write a command starting with a slash:\n"
"- '/s <seed>' sets the next seed\n"
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
"- '/n <steps>' sets the number of steps\n"
"- '/q' to quit"
)
while (prompt := input(user_question)).startswith("/"):
if prompt.startswith("/g"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, guidance = prompt.split()
options.guidance = float(guidance)
print(f"Setting guidance to {options.guidance}")
elif prompt.startswith("/s"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, seed = prompt.split()
options.seed = int(seed)
print(f"Setting seed to {options.seed}")
elif prompt.startswith("/n"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, steps = prompt.split()
options.num_steps = int(steps)
print(f"Setting number of steps to {options.num_steps}")
elif prompt.startswith("/q"):
print("Quitting")
return None
else:
if not prompt.startswith("/h"):
print(f"Got invalid command '{prompt}'\n{usage}")
print(usage)
if prompt != "":
options.prompt = prompt
return options
def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
if options is None:
return None
user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write your prompt directly, leave this field empty "
"to repeat the conditioning image or write a command starting with a slash:\n"
"- '/q' to quit"
)
while True:
img_cond_path = input(user_question)
if img_cond_path.startswith("/"):
if img_cond_path.startswith("/q"):
print("Quitting")
return None
else:
if not img_cond_path.startswith("/h"):
print(f"Got invalid command '{img_cond_path}'\n{usage}")
print(usage)
continue
if img_cond_path == "":
break
if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
(".jpg", ".jpeg", ".png", ".webp")
):
print(f"File '{img_cond_path}' does not exist or is not a valid image file")
continue
else:
with Image.open(img_cond_path) as img:
width, height = img.size
if width % 32 != 0 or height % 32 != 0:
print(f"Image dimensions must be divisible by 32, got {width}x{height}")
continue
options.img_cond_path = img_cond_path
break
return options
def parse_img_mask_path(options: SamplingOptions | None) -> SamplingOptions | None:
if options is None:
return None
user_question = "Next conditioning mask (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write your prompt directly, leave this field empty "
"to repeat the conditioning mask or write a command starting with a slash:\n"
"- '/q' to quit"
)
while True:
img_mask_path = input(user_question)
if img_mask_path.startswith("/"):
if img_mask_path.startswith("/q"):
print("Quitting")
return None
else:
if not img_mask_path.startswith("/h"):
print(f"Got invalid command '{img_mask_path}'\n{usage}")
print(usage)
continue
if img_mask_path == "":
break
if not os.path.isfile(img_mask_path) or not img_mask_path.lower().endswith(
(".jpg", ".jpeg", ".png", ".webp")
):
print(f"File '{img_mask_path}' does not exist or is not a valid image file")
continue
else:
with Image.open(img_mask_path) as img:
width, height = img.size
if width % 32 != 0 or height % 32 != 0:
print(f"Image dimensions must be divisible by 32, got {width}x{height}")
continue
else:
with Image.open(options.img_cond_path) as img_cond:
img_cond_width, img_cond_height = img_cond.size
if width != img_cond_width or height != img_cond_height:
print(
f"Mask dimensions must match conditioning image, got {width}x{height} and {img_cond_width}x{img_cond_height}"
)
continue
options.img_mask_path = img_mask_path
break
return options
@torch.inference_mode()
def main(
seed: int | None = None,
prompt: str = "a white paper cup",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
num_steps: int = 50,
loop: bool = False,
guidance: float = 30.0,
offload: bool = False,
output_dir: str = "output",
add_sampling_metadata: bool = True,
img_cond_path: str = "assets/cup.png",
img_mask_path: str = "assets/cup_mask.png",
track_usage: bool = False,
):
"""
Sample the flux model. Either interactively (set `--loop`) or run for a
single image. This demo assumes that the conditioning image and mask have
the same shape and that height and width are divisible by 32.
Args:
seed: Set a seed for sampling
output_dir: where to save the output images, `{idx}` in the file name will
be replaced by the index of the sample
offload: offload models to CPU when not in use
prompt: Prompt used for sampling
device: Pytorch device
num_steps: number of sampling steps (default 50)
loop: start an interactive session and sample multiple times
guidance: guidance value used for guidance distillation
add_sampling_metadata: Add the prompt to the image Exif metadata
img_cond_path: path to conditioning image (jpeg/png/webp)
img_mask_path: path to conditioning mask (jpeg/png/webp)
track_usage: track usage of the model for licensing purposes
"""
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
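# Minimal non-interactive invocation sketch (assuming the flux package is
# importable and the bundled example assets exist; adjust paths as needed):
#   python flux/to_remove/cli_fill.py --prompt "a white paper cup" \
#       --img_cond_path assets/cup.png --img_mask_path assets/cup_mask.png \
#       --num_steps 50 --output_dir output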
name = "flux-dev-fill"
if name not in configs:
available = ", ".join(configs.keys())
raise ValueError(f"Got unknown model name: {name}, choose from {available}")
torch_device = torch.device(device)
output_name = os.path.join(output_dir, "img_{idx}.jpg")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
idx = 0
else:
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
if len(fns) > 0:
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
else:
idx = 0
# init all components
t5 = load_t5(torch_device, max_length=128)
clip = load_clip(torch_device)
model = load_flow_model(name, device="cpu" if offload else torch_device)
ae = load_ae(name, device="cpu" if offload else torch_device)
rng = torch.Generator(device="cpu")
with Image.open(img_cond_path) as img:
width, height = img.size
opts = SamplingOptions(
prompt=prompt,
width=width,
height=height,
num_steps=num_steps,
guidance=guidance,
seed=seed,
img_cond_path=img_cond_path,
img_mask_path=img_mask_path,
)
if loop:
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
with Image.open(opts.img_cond_path) as img:
width, height = img.size
opts.height = height
opts.width = width
opts = parse_img_mask_path(opts)
while opts is not None:
if opts.seed is None:
opts.seed = rng.seed()
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
t0 = time.perf_counter()
# prepare input
x = get_noise(
1,
opts.height,
opts.width,
device=torch_device,
dtype=torch.bfloat16,
seed=opts.seed,
)
opts.seed = None
if offload:
t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
inp = prepare_fill(
t5,
clip,
x,
prompt=opts.prompt,
ae=ae,
img_cond_path=opts.img_cond_path,
mask_path=opts.img_mask_path,
)
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
# offload TEs and AE to CPU, load model to gpu
if offload:
t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
torch.cuda.empty_cache()
model = model.to(torch_device)
# denoise initial noise
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
# offload model, load autoencoder to gpu
if offload:
model.cpu()
torch.cuda.empty_cache()
ae.decoder.to(x.device)
# decode latents to pixel space
x = unpack(x.float(), opts.height, opts.width)
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
x = ae.decode(x)
if torch.cuda.is_available():
torch.cuda.synchronize()
t1 = time.perf_counter()
print(f"Done in {t1 - t0:.1f}s")
idx = save_image(
nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage
)
if loop:
print("-" * 80)
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
with Image.open(opts.img_cond_path) as img:
width, height = img.size
opts.height = height
opts.width = width
opts = parse_img_mask_path(opts)
else:
opts = None
if __name__ == "__main__":
Fire(main)

View File

@ -0,0 +1,368 @@
import os
import re
import time
from dataclasses import dataclass
from glob import iglob
import torch
from fire import Fire
from flux.content_filters import PixtralContentFilter
from flux.sampling import denoise, get_schedule, prepare_kontext, unpack
from flux.util import (
aspect_ratio_to_height_width,
check_onnx_access_for_trt,
load_ae,
load_clip,
load_flow_model,
load_t5,
save_image,
)
@dataclass
class SamplingOptions:
prompt: str
width: int | None
height: int | None
num_steps: int
guidance: float
seed: int | None
img_cond_path: str
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write your prompt directly, leave this field empty "
"to repeat the prompt or write a command starting with a slash:\n"
"- '/ar <width>:<height>' will set the aspect ratio of the generated image\n"
"- '/s <seed>' sets the next seed\n"
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
"- '/n <steps>' sets the number of steps\n"
"- '/q' to quit"
)
while (prompt := input(user_question)).startswith("/"):
if prompt.startswith("/ar"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, ratio_prompt = prompt.split()
if ratio_prompt == "auto":
options.width = None
options.height = None
print("Setting resolution to input image resolution.")
else:
options.width, options.height = aspect_ratio_to_height_width(ratio_prompt)
print(f"Setting resolution to {options.width} x {options.height}.")
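# Note: '/h <height>' is treated as a height override here, so a bare '/h'
# is reported as an invalid command (the usage text is still printed with that message).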
elif prompt.startswith("/h"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, height = prompt.split()
if height == "auto":
options.height = None
else:
options.height = 16 * (int(height) // 16)
if options.height is not None and options.width is not None:
print(
f"Setting resolution to {options.width} x {options.height} "
f"({options.height * options.width / 1e6:.2f}MP)"
)
else:
print(f"Setting resolution to {options.width} x {options.height}.")
elif prompt.startswith("/g"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, guidance = prompt.split()
options.guidance = float(guidance)
print(f"Setting guidance to {options.guidance}")
elif prompt.startswith("/s"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, seed = prompt.split()
options.seed = int(seed)
print(f"Setting seed to {options.seed}")
elif prompt.startswith("/n"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, steps = prompt.split()
options.num_steps = int(steps)
print(f"Setting number of steps to {options.num_steps}")
elif prompt.startswith("/q"):
print("Quitting")
return None
else:
if not prompt.startswith("/h"):
print(f"Got invalid command '{prompt}'\n{usage}")
print(usage)
if prompt != "":
options.prompt = prompt
return options
def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
if options is None:
return None
user_question = "Next input image (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write a path to an image directly, leave this field empty "
"to repeat the last input image or write a command starting with a slash:\n"
"- '/q' to quit\n\n"
"The input image will be edited by FLUX.1 Kontext creating a new image based "
"on your instruction prompt."
)
while True:
img_cond_path = input(user_question)
if img_cond_path.startswith("/"):
if img_cond_path.startswith("/q"):
print("Quitting")
return None
else:
if not img_cond_path.startswith("/h"):
print(f"Got invalid command '{img_cond_path}'\n{usage}")
print(usage)
continue
if img_cond_path == "":
break
if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
(".jpg", ".jpeg", ".png", ".webp")
):
print(f"File '{img_cond_path}' does not exist or is not a valid image file")
continue
options.img_cond_path = img_cond_path
break
return options
@torch.inference_mode()
def main(
name: str = "flux-dev-kontext",
aspect_ratio: str | None = None,
seed: int | None = None,
prompt: str = "replace the logo with the text 'Black Forest Labs'",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
num_steps: int = 30,
loop: bool = False,
guidance: float = 2.5,
offload: bool = False,
output_dir: str = "output",
add_sampling_metadata: bool = True,
img_cond_path: str = "assets/cup.png",
trt: bool = False,
trt_transformer_precision: str = "bf16",
track_usage: bool = False,
):
"""
Sample the flux model. Either interactively (set `--loop`) or run for a
single image.
Args:
name: Name of the model to load (only 'flux-dev-kontext' is supported)
aspect_ratio: aspect ratio '<width>:<height>' of the sample; None defaults
to the resolution of the conditioning image
seed: Set a seed for sampling
output_dir: where to save the output images, `{idx}` in the file name will
be replaced by the index of the sample
prompt: Prompt used for sampling
device: Pytorch device
num_steps: number of sampling steps (default 30)
loop: start an interactive session and sample multiple times
guidance: guidance value used for guidance distillation
add_sampling_metadata: Add the prompt to the image Exif metadata
img_cond_path: path to conditioning image (jpeg/png/webp)
trt: use TensorRT backend for optimized inference
track_usage: track usage of the model for licensing purposes
"""
assert name == "flux-dev-kontext", f"Got unknown model name: {name}"
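# Minimal non-interactive invocation sketch (script path shown is a placeholder;
# the Fire CLI exposes the keyword arguments above as flags):
#   python cli_kontext.py --prompt "replace the logo with the text 'Black Forest Labs'" \
#       --img_cond_path assets/cup.png --aspect_ratio 1:1 --num_steps 30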
torch_device = torch.device(device)
output_name = os.path.join(output_dir, "img_{idx}.jpg")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
idx = 0
else:
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
if len(fns) > 0:
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
else:
idx = 0
if aspect_ratio is None:
width = None
height = None
else:
width, height = aspect_ratio_to_height_width(aspect_ratio)
if not trt:
t5 = load_t5(torch_device, max_length=512)
clip = load_clip(torch_device)
model = load_flow_model(name, device="cpu" if offload else torch_device)
else:
# lazy import to make install optional
from flux.trt.trt_manager import ModuleName, TRTManager
# Check if we need ONNX model access (which requires authentication for FLUX models)
onnx_dir = check_onnx_access_for_trt(name, trt_transformer_precision)
trt_ctx_manager = TRTManager(
trt_transformer_precision=trt_transformer_precision,
trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"),
)
engines = trt_ctx_manager.load_engines(
model_name=name,
module_names={
ModuleName.CLIP,
ModuleName.TRANSFORMER,
ModuleName.T5,
},
engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
custom_onnx_paths=onnx_dir or os.environ.get("CUSTOM_ONNX_PATHS", ""),
trt_image_height=height,
trt_image_width=width,
trt_batch_size=1,
trt_timing_cache=os.getenv("TRT_TIMING_CACHE_FILE", None),
trt_static_batch=False,
trt_static_shape=False,
)
model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device)
clip = engines[ModuleName.CLIP].to(torch_device)
t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device)
ae = load_ae(name, device="cpu" if offload else torch_device)
content_filter = PixtralContentFilter(torch.device("cpu"))
rng = torch.Generator(device="cpu")
opts = SamplingOptions(
prompt=prompt,
width=width,
height=height,
num_steps=num_steps,
guidance=guidance,
seed=seed,
img_cond_path=img_cond_path,
)
if loop:
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
while opts is not None:
if opts.seed is None:
opts.seed = rng.seed()
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
t0 = time.perf_counter()
if content_filter.test_txt(opts.prompt):
print("Your prompt has been automatically flagged. Please choose another prompt.")
if loop:
print("-" * 80)
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
else:
opts = None
continue
if content_filter.test_image(opts.img_cond_path):
print("Your input image has been automatically flagged. Please choose another image.")
if loop:
print("-" * 80)
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
else:
opts = None
continue
if offload:
t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
inp, height, width = prepare_kontext(
t5=t5,
clip=clip,
prompt=opts.prompt,
ae=ae,
img_cond_path=opts.img_cond_path,
target_width=opts.width,
target_height=opts.height,
bs=1,
seed=opts.seed,
device=torch_device,
)
from safetensors.torch import save_file
save_file({k: v.cpu().contiguous() for k, v in inp.items()}, "output/noise.sft")
inp.pop("img_cond_orig")
opts.seed = None
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
# offload TEs and AE to CPU, load model to gpu
if offload:
t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
torch.cuda.empty_cache()
model = model.to(torch_device)
# denoise initial noise
t00 = time.time()
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
torch.cuda.synchronize()
t01 = time.time()
print(f"Denoising took {t01 - t00:.3f}s")
# offload model, load autoencoder to gpu
if offload:
model.cpu()
torch.cuda.empty_cache()
ae.decoder.to(x.device)
# decode latents to pixel space
x = unpack(x.float(), height, width)
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
ae_dev_t0 = time.perf_counter()
x = ae.decode(x)
torch.cuda.synchronize()
ae_dev_t1 = time.perf_counter()
print(f"AE decode took {ae_dev_t1 - ae_dev_t0:.3f}s")
if content_filter.test_image(x.cpu()):
print(
"Your output image has been automatically flagged. Choose another prompt/image or try again."
)
if loop:
print("-" * 80)
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
else:
opts = None
continue
if torch.cuda.is_available():
torch.cuda.synchronize()
t1 = time.perf_counter()
print(f"Done in {t1 - t0:.1f}s")
idx = save_image(
None, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage
)
if loop:
print("-" * 80)
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
else:
opts = None
if __name__ == "__main__":
Fire(main)

290 flux/to_remove/cli_redux.py Normal file
View File

@ -0,0 +1,290 @@
import os
import re
import time
from dataclasses import dataclass
from glob import iglob
import torch
from fire import Fire
from transformers import pipeline
from flux.modules.image_embedders import ReduxImageEncoder
from flux.sampling import denoise, get_noise, get_schedule, prepare_redux, unpack
from flux.util import (
get_checkpoint_path,
load_ae,
load_clip,
load_flow_model,
load_t5,
save_image,
)
@dataclass
class SamplingOptions:
prompt: str
width: int
height: int
num_steps: int
guidance: float
seed: int | None
img_cond_path: str
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
user_question = "Write /h for help, /q to quit and leave empty to repeat:\n"
usage = (
"Usage: Leave this field empty to do nothing "
"or write a command starting with a slash:\n"
"- '/w <width>' will set the width of the generated image\n"
"- '/h <height>' will set the height of the generated image\n"
"- '/s <seed>' sets the next seed\n"
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
"- '/n <steps>' sets the number of steps\n"
"- '/q' to quit"
)
while (prompt := input(user_question)).startswith("/"):
if prompt.startswith("/w"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, width = prompt.split()
options.width = 16 * (int(width) // 16)
print(
f"Setting resolution to {options.width} x {options.height} "
f"({options.height * options.width / 1e6:.2f}MP)"
)
elif prompt.startswith("/h"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, height = prompt.split()
options.height = 16 * (int(height) // 16)
print(
f"Setting resolution to {options.width} x {options.height} "
f"({options.height * options.width / 1e6:.2f}MP)"
)
elif prompt.startswith("/g"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, guidance = prompt.split()
options.guidance = float(guidance)
print(f"Setting guidance to {options.guidance}")
elif prompt.startswith("/s"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, seed = prompt.split()
options.seed = int(seed)
print(f"Setting seed to {options.seed}")
elif prompt.startswith("/n"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, steps = prompt.split()
options.num_steps = int(steps)
print(f"Setting number of steps to {options.num_steps}")
elif prompt.startswith("/q"):
print("Quitting")
return None
else:
if not prompt.startswith("/h"):
print(f"Got invalid command '{prompt}'\n{usage}")
print(usage)
return options
def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
if options is None:
return None
user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write your prompt directly, leave this field empty "
"to repeat the conditioning image or write a command starting with a slash:\n"
"- '/q' to quit"
)
while True:
img_cond_path = input(user_question)
if img_cond_path.startswith("/"):
if img_cond_path.startswith("/q"):
print("Quitting")
return None
else:
if not img_cond_path.startswith("/h"):
print(f"Got invalid command '{img_cond_path}'\n{usage}")
print(usage)
continue
if img_cond_path == "":
break
if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
(".jpg", ".jpeg", ".png", ".webp")
):
print(f"File '{img_cond_path}' does not exist or is not a valid image file")
continue
options.img_cond_path = img_cond_path
break
return options
@torch.inference_mode()
def main(
name: str = "flux-dev",
width: int = 1360,
height: int = 768,
seed: int | None = None,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
num_steps: int | None = None,
loop: bool = False,
guidance: float = 2.5,
offload: bool = False,
output_dir: str = "output",
add_sampling_metadata: bool = True,
img_cond_path: str = "assets/robot.webp",
track_usage: bool = False,
):
"""
Sample the flux model. Either interactively (set `--loop`) or run for a
single image.
Args:
name: Name of the base model to use (either 'flux-dev' or 'flux-schnell')
height: height of the sample in pixels (should be a multiple of 16)
width: width of the sample in pixels (should be a multiple of 16)
seed: Set a seed for sampling
device: Pytorch device
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
loop: start an interactive session and sample multiple times
guidance: guidance value used for guidance distillation
offload: offload models to CPU when not in use
output_dir: where to save the output images
add_sampling_metadata: Add the prompt to the image Exif metadata
img_cond_path: path to conditioning image (jpeg/png/webp)
track_usage: track usage of the model for licensing purposes
"""
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
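# Minimal non-interactive invocation sketch (values shown are the defaults from
# the signature above; adjust paths as needed):
#   python flux/to_remove/cli_redux.py --img_cond_path assets/robot.webp \
#       --width 1360 --height 768 --guidance 2.5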
if name not in (available := ["flux-dev", "flux-schnell"]):
raise ValueError(f"Got unknown model name: {name}, choose from {available}")
torch_device = torch.device(device)
if num_steps is None:
num_steps = 4 if name == "flux-schnell" else 50
output_name = os.path.join(output_dir, "img_{idx}.jpg")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
idx = 0
else:
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
if len(fns) > 0:
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
else:
idx = 0
# init all components
t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
clip = load_clip(torch_device)
model = load_flow_model(name, device="cpu" if offload else torch_device)
ae = load_ae(name, device="cpu" if offload else torch_device)
# Download and initialize the Redux adapter
redux_path = str(
get_checkpoint_path("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "FLUX_REDUX")
)
img_embedder = ReduxImageEncoder(torch_device, redux_path=redux_path)
rng = torch.Generator(device="cpu")
prompt = ""
opts = SamplingOptions(
prompt=prompt,
width=width,
height=height,
num_steps=num_steps,
guidance=guidance,
seed=seed,
img_cond_path=img_cond_path,
)
if loop:
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
while opts is not None:
if opts.seed is None:
opts.seed = rng.seed()
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
t0 = time.perf_counter()
# prepare input
x = get_noise(
1,
opts.height,
opts.width,
device=torch_device,
dtype=torch.bfloat16,
seed=opts.seed,
)
opts.seed = None
if offload:
ae = ae.cpu()
torch.cuda.empty_cache()
t5, clip = t5.to(torch_device), clip.to(torch_device)
inp = prepare_redux(
t5,
clip,
x,
prompt=opts.prompt,
encoder=img_embedder,
img_cond_path=opts.img_cond_path,
)
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
# offload TEs to CPU, load model to gpu
if offload:
t5, clip = t5.cpu(), clip.cpu()
torch.cuda.empty_cache()
model = model.to(torch_device)
# denoise initial noise
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
# offload model, load autoencoder to gpu
if offload:
model.cpu()
torch.cuda.empty_cache()
ae.decoder.to(x.device)
# decode latents to pixel space
x = unpack(x.float(), opts.height, opts.width)
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
x = ae.decode(x)
if torch.cuda.is_available():
torch.cuda.synchronize()
t1 = time.perf_counter()
print(f"Done in {t1 - t0:.1f}s")
idx = save_image(
nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage
)
if loop:
print("-" * 80)
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
else:
opts = None
if __name__ == "__main__":
Fire(main)

View File

@ -0,0 +1,171 @@
import torch
from einops import rearrange
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, pipeline
PROMPT_IMAGE_INTEGRITY = """
Task: Analyze an image to identify potential copyright concerns or depictions of public figures.
Output: Respond with only "yes" or "no"
Criteria for "yes":
- The image contains a recognizable character from copyrighted media (movies, TV, comics, games, etc.)
- The image displays a trademarked logo or brand
- The image depicts a recognizable public figure (celebrities, politicians, athletes, influencers, historical figures, etc.)
Criteria for "no":
- All other cases
- When you cannot identify the specific copyrighted work or named individual
Critical Requirements:
1. You must be able to name the exact copyrighted work or specific person depicted
2. General references to demographics or characteristics are not sufficient
3. Base your decision solely on visual content, not interpretation
4. Provide only the one-word answer: "yes" or "no"
""".strip()
PROMPT_IMAGE_INTEGRITY_FOLLOW_UP = "Does this image have copyright concerns or includes public figures?"
PROMPT_TEXT_INTEGRITY = """
Task: Analyze a text prompt to identify potential copyright concerns or requests to depict living public figures.
Output: Respond with only "yes" or "no"
Criteria for "Yes":
- The prompt explicitly names a character from copyrighted media (movies, TV, comics, games, etc.)
- The prompt explicitly mentions a trademarked logo or brand
- The prompt names or describes a specific living public figure (celebrities, politicians, athletes, influencers, etc.)
Criteria for "No":
- All other cases
- When you cannot identify the specific copyrighted work or named individual
Critical Requirements:
1. You must be able to name the exact copyrighted work or specific person referenced
2. General demographic descriptions or characteristics are not sufficient
3. Analyze only the prompt text, not potential image outcomes
4. Provide only the one-word answer: "yes" or "no"
The prompt to check is:
-----
{prompt}
-----
Does this prompt have copyright concerns or includes public figures?
""".strip()
class PixtralContentFilter(torch.nn.Module):
def __init__(
self,
device: torch.device = torch.device("cpu"),
nsfw_threshold: float = 0.85,
):
super().__init__()
model_id = "mistral-community/pixtral-12b"
self.processor = AutoProcessor.from_pretrained(model_id)
self.model = LlavaForConditionalGeneration.from_pretrained(model_id, device_map=device)
self.yes_token, self.no_token = self.processor.tokenizer.encode(["yes", "no"])
self.nsfw_classifier = pipeline(
"image-classification", model="Falconsai/nsfw_image_detection", device=device
)
self.nsfw_threshold = nsfw_threshold
def yes_no_logit_processor(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
"""
Sets all tokens but yes/no to the minimum.
"""
scores_yes_token = scores[:, self.yes_token].clone()
scores_no_token = scores[:, self.no_token].clone()
scores_min = scores.min()
scores[:, :] = scores_min - 1
scores[:, self.yes_token] = scores_yes_token
scores[:, self.no_token] = scores_no_token
return scores
def test_image(self, image: Image.Image | str | torch.Tensor) -> bool:
if isinstance(image, torch.Tensor):
image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c")
image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy())
elif isinstance(image, str):
image = Image.open(image)
classification = next(c for c in self.nsfw_classifier(image) if c["label"] == "nsfw")
if classification["score"] > self.nsfw_threshold:
return True
# 512^2 pixels are enough for checking
w, h = image.size
f = (512**2 / (w * h)) ** 0.5
image = image.resize((int(f * w), int(f * h)))
chat = [
{
"role": "user",
"content": [
{
"type": "text",
"content": PROMPT_IMAGE_INTEGRITY,
},
{
"type": "image",
"image": image,
},
{
"type": "text",
"content": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
},
],
}
]
inputs = self.processor.apply_chat_template(
chat,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(self.model.device)
generate_ids = self.model.generate(
**inputs,
max_new_tokens=1,
logits_processor=[self.yes_no_logit_processor],
do_sample=False,
)
return generate_ids[0, -1].item() == self.yes_token
def test_txt(self, txt: str) -> bool:
chat = [
{
"role": "user",
"content": [
{
"type": "text",
"content": PROMPT_TEXT_INTEGRITY.format(prompt=txt),
},
],
}
]
inputs = self.processor.apply_chat_template(
chat,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(self.model.device)
generate_ids = self.model.generate(
**inputs,
max_new_tokens=1,
logits_processor=[self.yes_no_logit_processor],
do_sample=False,
)
return generate_ids[0, -1].item() == self.yes_token
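# Minimal usage sketch (assuming the Pixtral-12B weights can be loaded on the
# chosen device; `prompt` and `image_path` are placeholders):
#   content_filter = PixtralContentFilter(torch.device("cpu"))
#   flagged = content_filter.test_txt(prompt) or content_filter.test_image(image_path)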

702 flux/util.py Normal file
View File

@ -0,0 +1,702 @@
import getpass
import math
import os
from dataclasses import dataclass
from pathlib import Path
import requests
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download, login
from PIL import ExifTags, Image
from safetensors.torch import load_file as load_sft
from flux.model import Flux, FluxLoraWrapper, FluxParams
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
from flux.modules.conditioner import HFEmbedder
CHECKPOINTS_DIR = Path("checkpoints")
CHECKPOINTS_DIR.mkdir(exist_ok=True)
BFL_API_KEY = os.getenv("BFL_API_KEY")
os.environ.setdefault("TRT_ENGINE_DIR", str(CHECKPOINTS_DIR / "trt_engines"))
(CHECKPOINTS_DIR / "trt_engines").mkdir(exist_ok=True)
def ensure_hf_auth():
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
print("Trying to authenticate to HuggingFace with the HF_TOKEN environment variable.")
try:
login(token=hf_token)
print("Successfully authenticated with HuggingFace using HF_TOKEN")
return True
except Exception as e:
print(f"Warning: Failed to authenticate with HF_TOKEN: {e}")
if os.path.exists(os.path.expanduser("~/.cache/huggingface/token")):
print("Already authenticated with HuggingFace")
return True
return False
def prompt_for_hf_auth():
try:
token = getpass.getpass("HF Token (hidden input): ").strip()
if not token:
print("No token provided. Aborting.")
return False
login(token=token)
print("Successfully authenticated!")
return True
except KeyboardInterrupt:
print("\nAuthentication cancelled by user.")
return False
except Exception as auth_e:
print(f"Authentication failed: {auth_e}")
print("Tip: You can also run 'huggingface-cli login' or set HF_TOKEN environment variable")
return False
def get_checkpoint_path(repo_id: str, filename: str, env_var: str) -> Path:
"""Get the local path for a checkpoint file, downloading if necessary."""
# if os.environ.get(env_var) is not None:
# local_path = os.environ[env_var]
# if os.path.exists(local_path):
# return Path(local_path)
# print(
# f"Trying to load model {repo_id}, {filename} from environment "
# f"variable {env_var}. But file {local_path} does not exist. "
# "Falling back to default location."
# )
# # Create a safe directory name from repo_id
# safe_repo_name = repo_id.replace("/", "_")
# checkpoint_dir = CHECKPOINTS_DIR / safe_repo_name
# checkpoint_dir.mkdir(exist_ok=True)
# local_path = checkpoint_dir / filename
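# Note: the hub-download logic below is disabled (`if False:`); `filename` is
# treated as a local checkpoint path and returned unchanged.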
local_path = filename
from mmgp import offload
if False:
print(f"Downloading {filename} from {repo_id} to {local_path}")
try:
ensure_hf_auth()
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=checkpoint_dir)
except Exception as e:
if "gated repo" in str(e).lower() or "restricted" in str(e).lower():
print(f"\nError: Cannot access {repo_id} -- this is a gated repository.")
# Try one more time to authenticate
if prompt_for_hf_auth():
# Retry the download after authentication
print("Retrying download...")
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=checkpoint_dir)
else:
print("Authentication failed or cancelled.")
print("You can also run 'huggingface-cli login' or set HF_TOKEN environment variable")
raise RuntimeError(f"Authentication required for {repo_id}")
else:
raise e
return local_path
def download_onnx_models_for_trt(model_name: str, trt_transformer_precision: str = "bf16") -> str | None:
"""Download ONNX models for TRT to our checkpoints directory"""
onnx_repo_map = {
"flux-dev": "black-forest-labs/FLUX.1-dev-onnx",
"flux-schnell": "black-forest-labs/FLUX.1-schnell-onnx",
"flux-dev-canny": "black-forest-labs/FLUX.1-Canny-dev-onnx",
"flux-dev-depth": "black-forest-labs/FLUX.1-Depth-dev-onnx",
"flux-dev-redux": "black-forest-labs/FLUX.1-Redux-dev-onnx",
"flux-dev-fill": "black-forest-labs/FLUX.1-Fill-dev-onnx",
"flux-dev-kontext": "black-forest-labs/FLUX.1-Kontext-dev-onnx",
}
if model_name not in onnx_repo_map:
return None # No ONNX repository required for this model
repo_id = onnx_repo_map[model_name]
safe_repo_name = repo_id.replace("/", "_")
onnx_dir = CHECKPOINTS_DIR / safe_repo_name
# Map of module names to their ONNX file paths (using specified precision)
onnx_file_map = {
"clip": "clip.opt/model.onnx",
"transformer": f"transformer.opt/{trt_transformer_precision}/model.onnx",
"transformer_data": f"transformer.opt/{trt_transformer_precision}/backbone.onnx_data",
"t5": "t5.opt/model.onnx",
"t5_data": "t5.opt/backbone.onnx_data",
"vae": "vae.opt/model.onnx",
}
# If all files exist locally, return the custom_onnx_paths format
if onnx_dir.exists():
all_files_exist = True
custom_paths = []
for module, onnx_file in onnx_file_map.items():
if module.endswith("_data"):
continue # Skip data files
local_path = onnx_dir / onnx_file
if not local_path.exists():
all_files_exist = False
break
custom_paths.append(f"{module}:{local_path}")
if all_files_exist:
print(f"ONNX models ready in {onnx_dir}")
return ",".join(custom_paths)
# If not all files exist, download them
print(f"Downloading ONNX models from {repo_id} to {onnx_dir}")
print(f"Using transformer precision: {trt_transformer_precision}")
onnx_dir.mkdir(exist_ok=True)
# Download all ONNX files
for module, onnx_file in onnx_file_map.items():
local_path = onnx_dir / onnx_file
if local_path.exists():
continue # Already downloaded
# Create parent directories
local_path.parent.mkdir(parents=True, exist_ok=True)
try:
print(f"Downloading {onnx_file}")
hf_hub_download(repo_id=repo_id, filename=onnx_file, local_dir=onnx_dir)
except Exception as e:
if "does not exist" in str(e).lower() or "not found" in str(e).lower():
continue
elif "gated repo" in str(e).lower() or "restricted" in str(e).lower():
print(f"Cannot access {repo_id} - requires license acceptance")
print("Please follow these steps:")
print(f" 1. Visit: https://huggingface.co/{repo_id}")
print(" 2. Log in to your HuggingFace account")
print(" 3. Accept the license terms and conditions")
print(" 4. Then retry this command")
raise RuntimeError(f"License acceptance required for {model_name}")
else:
# Re-raise other errors
raise
print(f"ONNX models ready in {onnx_dir}")
# Return the custom_onnx_paths format that TRT expects: "module1:path1,module2:path2"
# Note: Only return the actual module paths, not the data file
custom_paths = []
for module, onnx_file in onnx_file_map.items():
if module.endswith("_data"):
continue # Skip the data file in the return paths
full_path = onnx_dir / onnx_file
if full_path.exists():
custom_paths.append(f"{module}:{full_path}")
return ",".join(custom_paths)
def check_onnx_access_for_trt(model_name: str, trt_transformer_precision: str = "bf16") -> str | None:
"""Check ONNX access and download models for TRT - returns ONNX directory path"""
return download_onnx_models_for_trt(model_name, trt_transformer_precision)
def track_usage_via_api(name: str, n=1) -> None:
"""
Track usage of licensed models via the BFL API for commercial licensing compliance.
For more information on licensing BFL's models for commercial use and usage reporting,
see the README.md or visit: https://dashboard.bfl.ai/licensing/subscriptions?showInstructions=true
"""
assert BFL_API_KEY is not None, "BFL_API_KEY is not set"
model_slug_map = {
"flux-dev": "flux-1-dev",
"flux-dev-kontext": "flux-1-kontext-dev",
"flux-dev-fill": "flux-tools",
"flux-dev-depth": "flux-tools",
"flux-dev-canny": "flux-tools",
"flux-dev-canny-lora": "flux-tools",
"flux-dev-depth-lora": "flux-tools",
"flux-dev-redux": "flux-tools",
}
if name not in model_slug_map:
print(f"Skipping tracking usage for {name}, as it cannot be tracked. Please check the model name.")
return
model_slug = model_slug_map[name]
url = f"https://api.bfl.ai/v1/licenses/models/{model_slug}/usage"
headers = {"x-key": BFL_API_KEY, "Content-Type": "application/json"}
payload = {"number_of_generations": n}
response = requests.post(url, headers=headers, json=payload)
if response.status_code != 200:
raise Exception(f"Failed to track usage: {response.status_code} {response.text}")
else:
print(f"Successfully tracked usage for {name} with {n} generations")
def save_image(
nsfw_classifier,
name: str,
output_name: str,
idx: int,
x: torch.Tensor,
add_sampling_metadata: bool,
prompt: str,
nsfw_threshold: float = 0.85,
track_usage: bool = False,
) -> int:
fn = output_name.format(idx=idx)
print(f"Saving {fn}")
# bring into PIL format and save
x = x.clamp(-1, 1)
x = rearrange(x[0], "c h w -> h w c")
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
if nsfw_classifier is not None:
nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
else:
nsfw_score = nsfw_threshold - 1.0
if nsfw_score < nsfw_threshold:
exif_data = Image.Exif()
if name in ["flux-dev", "flux-schnell"]:
exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
else:
exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
exif_data[ExifTags.Base.Make] = "Black Forest Labs"
exif_data[ExifTags.Base.Model] = name
if add_sampling_metadata:
exif_data[ExifTags.Base.ImageDescription] = prompt
img.save(fn, exif=exif_data, quality=95, subsampling=0)
if track_usage:
track_usage_via_api(name, 1)
idx += 1
else:
print("Your generated image may contain NSFW content.")
return idx
@dataclass
class ModelSpec:
params: FluxParams
ae_params: AutoEncoderParams
repo_id: str
repo_flow: str
repo_ae: str
lora_repo_id: str | None = None
lora_filename: str | None = None
configs = {
"flux-dev": ModelSpec(
repo_id="",
repo_flow="",
repo_ae="ckpts/flux_vae.safetensors",
params=FluxParams(
in_channels=64,
out_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
"flux-schnell": ModelSpec(
repo_id="black-forest-labs/FLUX.1-schnell",
repo_flow="",
repo_ae="ckpts/flux_vae.safetensors",
params=FluxParams(
in_channels=64,
out_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=False,
),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
"flux-dev-canny": ModelSpec(
repo_id="black-forest-labs/FLUX.1-Canny-dev",
repo_flow="",
repo_ae="ckpts/flux_vae.safetensors",
params=FluxParams(
in_channels=128,
out_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
"flux-dev-canny-lora": ModelSpec(
repo_id="black-forest-labs/FLUX.1-dev",
repo_flow="",
repo_ae="ckpts/flux_vae.safetensors",
lora_repo_id="black-forest-labs/FLUX.1-Canny-dev-lora",
lora_filename="flux1-canny-dev-lora.safetensors",
params=FluxParams(
in_channels=128,
out_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
"flux-dev-depth": ModelSpec(
repo_id="black-forest-labs/FLUX.1-Depth-dev",
repo_flow="",
repo_ae="ckpts/flux_vae.safetensors",
params=FluxParams(
in_channels=128,
out_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
"flux-dev-depth-lora": ModelSpec(
repo_id="black-forest-labs/FLUX.1-dev",
repo_flow="",
repo_ae="ckpts/flux_vae.safetensors",
lora_repo_id="black-forest-labs/FLUX.1-Depth-dev-lora",
lora_filename="flux1-depth-dev-lora.safetensors",
params=FluxParams(
in_channels=128,
out_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
"flux-dev-redux": ModelSpec(
repo_id="black-forest-labs/FLUX.1-Redux-dev",
repo_flow="",
repo_ae="ckpts/flux_vae.safetensors",
params=FluxParams(
in_channels=64,
out_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
"flux-dev-fill": ModelSpec(
repo_id="black-forest-labs/FLUX.1-Fill-dev",
repo_flow="",
repo_ae="ckpts/flux_vae.safetensors",
params=FluxParams(
in_channels=384,
out_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
"flux-dev-kontext": ModelSpec(
repo_id="black-forest-labs/FLUX.1-Kontext-dev",
repo_flow="",
repo_ae="ckpts/flux_vae.safetensors",
params=FluxParams(
in_channels=64,
out_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
}
PREFERED_KONTEXT_RESOLUTIONS = [
(672, 1568),
(688, 1504),
(720, 1456),
(752, 1392),
(800, 1328),
(832, 1248),
(880, 1184),
(944, 1104),
(1024, 1024),
(1104, 944),
(1184, 880),
(1248, 832),
(1328, 800),
(1392, 752),
(1456, 720),
(1504, 688),
(1568, 672),
]
def aspect_ratio_to_height_width(aspect_ratio: str, area: int = 1024**2) -> tuple[int, int]:
width = float(aspect_ratio.split(":")[0])
height = float(aspect_ratio.split(":")[1])
ratio = width / height
width = round(math.sqrt(area * ratio))
height = round(math.sqrt(area / ratio))
return 16 * (width // 16), 16 * (height // 16)
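# For example, aspect_ratio_to_height_width("16:9") returns (1360, 768) with the
# default 1024**2 target area (both sides rounded down to multiples of 16).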
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
if len(missing) > 0 and len(unexpected) > 0:
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
print("\n" + "-" * 79 + "\n")
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
elif len(missing) > 0:
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
elif len(unexpected) > 0:
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
def load_flow_model(name: str, model_filename, device: str | torch.device = "cuda", verbose: bool = True) -> Flux:
# Loading Flux
config = configs[name]
ckpt_path = model_filename #config.repo_flow
with torch.device("meta"):
if config.lora_repo_id is not None and config.lora_filename is not None:
model = FluxLoraWrapper(params=config.params).to(torch.bfloat16)
else:
model = Flux(config.params).to(torch.bfloat16)
# print(f"Loading checkpoint: {ckpt_path}")
from mmgp import offload
offload.load_model_data(model, model_filename )
# # load_sft doesn't support torch.device
# sd = load_sft(ckpt_path, device=str(device))
# sd = optionally_expand_state_dict(model, sd)
# missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
# if verbose:
# print_load_warning(missing, unexpected)
# if config.lora_repo_id is not None and config.lora_filename is not None:
# print("Loading LoRA")
# lora_path = str(get_checkpoint_path(config.lora_repo_id, config.lora_filename, "FLUX_LORA"))
# lora_sd = load_sft(lora_path, device=str(device))
# # loading the lora params + overwriting scale values in the norms
# missing, unexpected = model.load_state_dict(lora_sd, strict=False, assign=True)
# if verbose:
# print_load_warning(missing, unexpected)
return model
def load_t5(device: str | torch.device = "cuda", text_encoder_filename = None, max_length: int = 512) -> HFEmbedder:
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
return HFEmbedder("",text_encoder_filename, max_length=max_length, torch_dtype=torch.bfloat16).to(device)
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
return HFEmbedder("ckpts/clip_vit_large_patch14", "", max_length=77, torch_dtype=torch.bfloat16, is_clip =True).to(device)
def load_ae(name: str, device: str | torch.device = "cuda") -> AutoEncoder:
config = configs[name]
ckpt_path = str(get_checkpoint_path(config.repo_id, config.repo_ae, "FLUX_AE"))
# Loading the autoencoder
with torch.device("meta"):
ae = AutoEncoder(config.ae_params)
# print(f"Loading AE checkpoint: {ckpt_path}")
sd = load_sft(ckpt_path, device=str(device))
missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
print_load_warning(missing, unexpected)
return ae
def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict:
"""
Optionally expand the state dict to match the model's parameters shapes.
"""
for name, param in model.named_parameters():
if name in state_dict:
if state_dict[name].shape != param.shape:
print(
f"Expanding '{name}' with shape {state_dict[name].shape} to model parameter with shape {param.shape}."
)
# expand with zeros:
expanded_state_dict_weight = torch.zeros_like(param, device=state_dict[name].device)
slices = tuple(slice(0, dim) for dim in state_dict[name].shape)
expanded_state_dict_weight[slices] = state_dict[name]
state_dict[name] = expanded_state_dict_weight
return state_dict

1 loras_flux/readme.txt Normal file
View File

@ -0,0 +1 @@
flux loras go here

View File

@ -149,6 +149,7 @@ class LTXV:
self,
model_filepath: str,
text_encoder_filepath: str,
model_def,
dtype = torch.bfloat16,
VAE_dtype = torch.bfloat16,
mixed_precision_transformer = False
@ -157,8 +158,8 @@ class LTXV:
# if dtype == torch.float16:
dtype = torch.bfloat16
self.mixed_precision_transformer = mixed_precision_transformer
self.distilled = any("lora" in name for name in model_filepath)
model_filepath = [name for name in model_filepath if not "lora" in name ]
self.model_def = model_def
self.pipeline_config = model_def["LTXV_config"]
# with safe_open(ckpt_path, framework="pt") as f:
# metadata = f.metadata()
# config_str = metadata.get("config")
@ -220,11 +221,11 @@ class LTXV:
prompt_enhancer_llm_model = None
prompt_enhancer_llm_tokenizer = None
if prompt_enhancer_image_caption_model != None:
pipe["prompt_enhancer_image_caption_model"] = prompt_enhancer_image_caption_model
prompt_enhancer_image_caption_model._model_dtype = torch.float
# if prompt_enhancer_image_caption_model != None:
# pipe["prompt_enhancer_image_caption_model"] = prompt_enhancer_image_caption_model
# prompt_enhancer_image_caption_model._model_dtype = torch.float
pipe["prompt_enhancer_llm_model"] = prompt_enhancer_llm_model
# pipe["prompt_enhancer_llm_model"] = prompt_enhancer_llm_model
# offload.profile(pipe, profile_no=5, extraModelsToQuantize = None, quantizeTransformer = False, budgets = { "prompt_enhancer_llm_model" : 10000, "prompt_enhancer_image_caption_model" : 10000, "vae" : 3000, "*" : 100 }, verboseLevel=2)
@ -299,14 +300,10 @@ class LTXV:
conditioning_media_paths = None
conditioning_start_frames = None
if self.distilled :
pipeline_config = "ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml"
else:
pipeline_config = "ltx_video/configs/ltxv-13b-0.9.7-dev.yaml"
# check if pipeline_config is a file
if not os.path.isfile(pipeline_config):
raise ValueError(f"Pipeline config file {pipeline_config} does not exist")
with open(pipeline_config, "r") as f:
if not os.path.isfile(self.pipeline_config):
raise ValueError(f"Pipeline config file {self.pipeline_config} does not exist")
with open(self.pipeline_config, "r") as f:
pipeline_config = yaml.safe_load(f)
@ -520,7 +517,7 @@ def get_media_num_frames(media_path: str) -> int:
return media_path.shape[1]
elif isinstance(media_path, str) and any( media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]):
reader = imageio.get_reader(media_path)
return min(reader.count_frames(), max_frames)
return min(reader.count_frames(), 0) # to do
else:
raise Exception("video format not supported")
@ -564,6 +561,3 @@ def load_media_file(
raise Exception("video format not supported")
return media_tensor
if __name__ == "__main__":
main()

View File

@ -0,0 +1,21 @@
# Thanks to https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/film_grain.py
import torch
def add_film_grain(images: torch.Tensor, grain_intensity: float = 0, saturation: float = 0.5):
device = images.device
images = images.permute(1, 2, 3, 0)
images.add_(1.).div_(2.)
grain = torch.randn_like(images, device=device)
grain[:, :, :, 0] *= 2
grain[:, :, :, 2] *= 3
grain = grain * saturation + grain[:, :, :, 1].unsqueeze(3).repeat(
1, 1, 1, 3
) * (1 - saturation)
# Blend the grain with the image
noised_images = images + grain_intensity * grain
noised_images.clamp_(0, 1)
noised_images.sub_(.5).mul_(2.)
noised_images = noised_images.permute(3, 0, 1, 2)
return noised_images
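# Minimal usage sketch (assumption: `video` is a float tensor shaped
# (channels, frames, height, width) with values in [-1, 1], as implied by the
# permutes and rescaling above):
#   grainy = add_film_grain(video, grain_intensity=0.05, saturation=0.5)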

View File

@ -15,13 +15,7 @@ from torch import nn
logger = logging.getLogger("dinov2")
try:
from xformers.ops import memory_efficient_attention, unbind, fmha
XFORMERS_AVAILABLE = True
except ImportError:
logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
XFORMERS_AVAILABLE = False
class Attention(nn.Module):

View File

@ -23,14 +23,7 @@ from .mlp import Mlp
logger = logging.getLogger("dinov2")
try:
from xformers.ops import fmha
from xformers.ops import scaled_index_add, index_select_cat
XFORMERS_AVAILABLE = True
except ImportError:
# logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
XFORMERS_AVAILABLE = False
class Block(nn.Module):

View File

@ -65,6 +65,7 @@ def get_frames_from_image(image_input, image_state):
Return
[[0:nearest_frame], [nearest_frame:], nearest_frame]
"""
load_sam()
user_name = time.time()
frames = [image_input] * 2 # hardcode: mimic a video with 2 frames
@ -89,7 +90,7 @@ def get_frames_from_image(image_input, image_state):
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True),\
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=False), \
gr.update(visible=True), gr.update(value="", visible=True), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=True), \
gr.update(visible=True)
@ -103,6 +104,8 @@ def get_frames_from_video(video_input, video_state):
[[0:nearest_frame], [nearest_frame:], nearest_frame]
"""
load_sam()
while model is None:
time.sleep(1)
@ -273,6 +276,20 @@ def save_video(frames, output_path, fps):
return output_path
def mask_to_xyxy_box(mask):
rows, cols = np.where(mask == 255)
xmin = min(cols)
xmax = max(cols) + 1
ymin = min(rows)
ymax = max(rows) + 1
xmin = max(xmin, 0)
ymin = max(ymin, 0)
xmax = min(xmax, mask.shape[1])
ymax = min(ymax, mask.shape[0])
box = [xmin, ymin, xmax, ymax]
box = [int(x) for x in box]
return box
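# Example: a mask whose 255-valued pixels span columns 10..19 and rows 5..14
# yields [10, 5, 20, 15] (xmin, ymin, xmax, ymax with exclusive max edges).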
# image matting
def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, refine_iter):
matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
@ -320,9 +337,17 @@ def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_si
foreground = output_frames
foreground_output = Image.fromarray(foreground[-1])
alpha_output = Image.fromarray(alpha[-1][:,:,0])
return foreground_output, gr.update(visible=True)
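# Invert the alpha matte into a binary mask and report the bounding box of its
# 255-valued region as "xmin:ymin:xmax:ymax" integer percentages of the frame size.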
alpha_output = alpha[-1][:,:,0]
frame_temp = alpha_output.copy()
alpha_output[frame_temp > 127] = 0
alpha_output[frame_temp <= 127] = 255
bbox_info = mask_to_xyxy_box(alpha_output)
h = alpha_output.shape[0]
w = alpha_output.shape[1]
bbox_info = [str(int(bbox_info[0]/ w * 100 )), str(int(bbox_info[1]/ h * 100 )), str(int(bbox_info[2]/ w * 100 )), str(int(bbox_info[3]/ h * 100 )) ]
bbox_info = ":".join(bbox_info)
alpha_output = Image.fromarray(alpha_output)
return foreground_output, alpha_output, bbox_info, gr.update(visible=True), gr.update(visible=True)
# video matting
def video_matting(video_state, end_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size):
@ -469,6 +494,13 @@ def restart():
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False, choices=[], value=[]), "", gr.update(visible=False)
def load_sam():
global model_loaded
global model
global matanyone_model
model.samcontroler.sam_controler.model.to(arg_device)
matanyone_model.to(arg_device)
def load_unload_models(selected):
global model_loaded
global model
@ -476,8 +508,7 @@ def load_unload_models(selected):
if selected:
# print("Matanyone Tab Selected")
if model_loaded:
model.samcontroler.sam_controler.model.to(arg_device)
matanyone_model.to(arg_device)
load_sam()
else:
# args, defined in track_anything.py
sam_checkpoint_url_dict = {
@ -522,12 +553,16 @@ def export_to_vace_video_input(foreground_video_output):
def export_image(image_refs, image_output):
gr.Info("Masked Image transferred to Current Video")
# return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
if image_refs == None:
image_refs =[]
image_refs.append( image_output)
return image_refs
def export_image_mask(image_input, image_mask):
gr.Info("Input Image & Mask transferred to Current Video")
return Image.fromarray(image_input), image_mask
def export_to_current_video_engine(model_type, foreground_video_output, alpha_video_output):
gr.Info("Original Video and Full Mask have been transferred")
# return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
@ -543,7 +578,7 @@ def teleport_to_video_tab(tab_state):
return gr.Tabs(selected="video_gen")
def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, vace_image_refs):
def display(tabs, tab_state, model_choice, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs):
# my_tab.select(fn=load_unload_models, inputs=[], outputs=[])
media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/"
@ -677,7 +712,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
foreground_output_button = gr.Button(value="Black & White Video Output", visible=False, elem_classes="new_button")
with gr.Column(scale=2):
alpha_video_output = gr.Video(label="B & W Mask Video Output", visible=False, elem_classes="video")
alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
export_image_mask_btn = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
with gr.Row():
with gr.Row(visible= False):
export_to_vace_video_14B_btn = gr.Button("Export to current Video Input Video For Inpainting", visible= False)
@ -696,7 +731,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
],
outputs=[video_state, video_info, template_frame,
image_selection_slider, end_selection_slider, track_pause_number_slider, point_prompt, matting_type, clear_button_click, add_mask_button, matting_button, template_frame,
foreground_video_output, alpha_video_output, foreground_output_button, alpha_output_button, mask_dropdown, step2_title]
foreground_video_output, alpha_video_output, foreground_output_button, export_image_mask_btn, mask_dropdown, step2_title]
)
# second step: select images from slider
@ -755,7 +790,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
foreground_video_output, alpha_video_output,
template_frame,
image_selection_slider, end_selection_slider, track_pause_number_slider,point_prompt, export_to_vace_video_14B_btn, export_to_current_video_engine_btn, matting_type, clear_button_click,
add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title
add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, export_image_mask_btn, mask_dropdown, video_info, step2_title
],
queue=False,
show_progress=False)
@ -770,7 +805,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
foreground_video_output, alpha_video_output,
template_frame,
image_selection_slider , end_selection_slider, track_pause_number_slider,point_prompt, export_to_vace_video_14B_btn, export_to_current_video_engine_btn, matting_type, clear_button_click,
add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title
add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, export_image_mask_btn, mask_dropdown, video_info, step2_title
],
queue=False,
show_progress=False)
@ -872,15 +907,19 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
# output image
with gr.Row(equal_height=True):
foreground_image_output = gr.Image(type="pil", label="Foreground Output", visible=False, elem_classes="image")
alpha_image_output = gr.Image(type="pil", label="Mask", visible=False, elem_classes="image")
with gr.Row(equal_height=True):
bbox_info = gr.Text(label ="Mask BBox Info (Left:Top:Right:Bottom)", interactive= False)
with gr.Row():
with gr.Row():
export_image_btn = gr.Button(value="Add to current Reference Images", visible=False, elem_classes="new_button")
with gr.Column(scale=2, visible= False):
alpha_image_output = gr.Image(type="pil", label="Alpha Output", visible=False, elem_classes="image")
alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
# with gr.Row():
export_image_btn = gr.Button(value="Add to current Reference Images", visible=False, elem_classes="new_button")
# with gr.Column(scale=2, visible= True):
export_image_mask_btn = gr.Button(value="Set to Control Image & Mask", visible=False, elem_classes="new_button")
export_image_btn.click( fn=export_image, inputs= [vace_image_refs, foreground_image_output], outputs= [vace_image_refs]).then( #video_prompt_video_guide_trigger,
fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])
export_image_mask_btn.click( fn=export_image_mask, inputs= [image_input, alpha_image_output], outputs= [vace_image_input, vace_image_mask]).then( #video_prompt_video_guide_trigger,
fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])
# first step: get the image information
extract_frames_button.click(
@ -890,9 +929,17 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
],
outputs=[image_state, image_info, template_frame,
image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, add_mask_button, matting_button, template_frame,
foreground_image_output, alpha_image_output, export_image_btn, alpha_output_button, mask_dropdown, step2_title]
foreground_image_output, alpha_image_output, bbox_info, export_image_btn, export_image_mask_btn, mask_dropdown, step2_title]
)
# points clear
clear_button_click.click(
fn = clear_click,
inputs = [image_state, click_state,],
outputs = [template_frame,click_state],
)
# second step: select images from slider
image_selection_slider.release(fn=select_image_template,
inputs=[image_selection_slider, image_state, interactive_state],
@ -925,7 +972,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va
matting_button.click(
fn=image_matting,
inputs=[image_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, image_selection_slider],
outputs=[foreground_image_output, export_image_btn]
outputs=[foreground_image_output, alpha_image_output,bbox_info, export_image_btn, export_image_mask_btn]
)

View File

@ -61,6 +61,7 @@ class WanAny2V:
checkpoint_dir,
model_filename = None,
model_type = None,
model_def = None,
base_model_type = None,
text_encoder_filename = None,
quantizeTransformer = False,
@ -75,7 +76,8 @@ class WanAny2V:
self.dtype = dtype
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
self.model_def = model_def
self.image_outputs = model_def.get("image_outputs", False)
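# image_outputs is read from the finetune definition; when set, the decode path at the end of
# generate appears to concatenate all generated results along the frame axis (image mode)
# instead of returning only the first video.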
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
@ -106,18 +108,18 @@ class WanAny2V:
# config = json.load(f)
# from mmgp import safetensors2
# sd = safetensors2.torch_load_file(xmodel_filename)
# model_filename = "c:/temp/flf/diffusion_pytorch_model-00001-of-00007.safetensors"
base_config_file = f"configs/{base_model_type}.json"
forcedConfigPath = base_config_file if len(model_filename) > 1 or base_model_type in ["flf2v_720p"] else None
forcedConfigPath = base_config_file if len(model_filename) > 1 else None
# forcedConfigPath = base_config_file = f"configs/flf2v_720p.json"
# model_filename[1] = xmodel_filename
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
# self.model = offload.load_model_data(self.model, xmodel_filename )
# offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
# self.model.to(torch.bfloat16)
# self.model.cpu()
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
offload.change_dtype(self.model, dtype, True)
# offload.save_model(self.model, "multitalkbf16.safetensors", config_file_path=base_config_file, filter_sd=sd)
# offload.save_model(self.model, "flf2v_720p.safetensors", config_file_path=base_config_file)
# offload.save_model(self.model, "flf2v_quanto_int8_fp16_720p.safetensors", do_quantize= True, config_file_path=base_config_file)
# offload.save_model(self.model, "multitalk_quanto_fp16.safetensors", do_quantize= True, config_file_path=base_config_file, filter_sd=sd)
# offload.save_model(self.model, "wan2.1_selforcing_fp16.safetensors", config_file_path=base_config_file)
@ -126,7 +128,7 @@ class WanAny2V:
self.model.eval().requires_grad_(False)
if save_quantized:
from wgp import save_quantized_model
save_quantized_model(self.model, model_type, model_filename[1 if base_model_type=="fantasy" else 0], dtype, base_config_file)
save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file)
self.sample_neg_prompt = config.sample_neg_prompt
@ -208,7 +210,7 @@ class WanAny2V:
if refs is not None:
length = len(refs)
mask_pad = torch.zeros_like(mask[:, :length, :, :])
mask_pad = torch.zeros(mask.shape[0], length, *mask.shape[-2:], dtype=mask.dtype, device=mask.device)
mask = torch.cat((mask_pad, mask), dim=1)
result_masks.append(mask)
return result_masks
@ -327,20 +329,6 @@ class WanAny2V:
self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # duplicate the background mask when using a double control net, since the first controlnet image ref is modified by the ref
return src_video, src_mask, src_ref_images
def decode_latent(self, zs, ref_images=None, tile_size= 0 ):
if ref_images is None:
ref_images = [None] * len(zs)
# else:
# assert len(zs) == len(ref_images)
trimed_zs = []
for z, refs in zip(zs, ref_images):
if refs is not None:
z = z[:, len(refs):, :, :]
trimed_zs.append(z)
return self.vae.decode(trimed_zs, tile_size= tile_size)
def get_vae_latents(self, ref_images, device, tile_size= 0):
ref_vae_latents = []
for ref_image in ref_images:
@ -366,6 +354,7 @@ class WanAny2V:
height = 720,
fit_into_canvas = True,
frame_num=81,
batch_size = 1,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
@ -397,6 +386,7 @@ class WanAny2V:
NAG_alpha = 0.5,
offloadobj = None,
apg_switch = False,
speakers_bboxes = None,
**bbargs
):
@ -477,8 +467,8 @@ class WanAny2V:
any_end_frame = False
if input_frames != None:
_ , preframes_count, height, width = input_frames.shape
lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
clip_context = self.clip.visual([input_frames[:, -1:]]) #.to(self.param_dtype)
lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
clip_context = self.clip.visual([input_frames[:, -1:]]) if model_type != "flf2v_720p" else self.clip.visual([input_frames[:, -1:], input_frames[:, -1:]])
input_frames = input_frames.to(device=self.device).to(dtype= self.VAE_dtype)
enc = torch.concat( [input_frames, torch.zeros( (3, frame_num-preframes_count, height, width),
device=self.device, dtype= self.VAE_dtype)],
@ -488,7 +478,7 @@ class WanAny2V:
preframes_count = 1
image_start = TF.to_tensor(image_start)
any_end_frame = image_end != None
add_frames_for_end_image = any_end_frame and model_type not in ["fun_inp_1.3B", "fun_inp", "i2v_720p"]
add_frames_for_end_image = any_end_frame and model_type == "i2v"
if any_end_frame:
image_end = TF.to_tensor(image_end)
if add_frames_for_end_image:
@ -517,8 +507,8 @@ class WanAny2V:
img_interpolated2 = resize_lanczos(image_end, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
image_end = resize_lanczos(image_end, clip_image_size, clip_image_size)
image_end = image_end.sub_(0.5).div_(0.5).to(self.device) #, self.dtype
if image_end != None and model_type == "flf2v_720p":
clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :]])
if model_type == "flf2v_720p":
clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end != None else image_start[:, None, :, :]])
else:
clip_context = self.clip.visual([image_start[:, None, :, :]])
@ -554,8 +544,8 @@ class WanAny2V:
overlapped_latents_frames_num = int(1 + (preframes_count-1) // 4)
if overlapped_latents != None:
# disabled because it looks worse
if False and overlapped_latents_frames_num > 1: lat_y[:, 1:overlapped_latents_frames_num] = overlapped_latents[:, 1:]
extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone()
if False and overlapped_latents_frames_num > 1: lat_y[:, :, 1:overlapped_latents_frames_num] = overlapped_latents[:, 1:]
extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone().unsqueeze(0)
y = torch.concat([msk, lat_y])
lat_y = None
kwargs.update({'clip_fea': clip_context, 'y': y})
@ -586,7 +576,7 @@ class WanAny2V:
overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1
else:
overlapped_latents_frames_num = overlapped_frames_num = 0
if len(keep_frames_parsed) == 0 or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = []
if len(keep_frames_parsed) == 0 or self.image_outputs or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = []
injection_denoising_step = int(sampling_steps * (1. - denoising_strength) )
latent_keep_frames = []
if source_latents.shape[1] < lat_frames or len(keep_frames_parsed) > 0:
@ -609,6 +599,7 @@ class WanAny2V:
input_ref_images = self.get_vae_latents(input_ref_images, self.device)
input_ref_images_neg = torch.zeros_like(input_ref_images)
ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0
trim_frames = input_ref_images.shape[1]
# Vace
if vace :
@ -633,8 +624,8 @@ class WanAny2V:
context_scale = context_scale if context_scale != None else [1.0] * len(z)
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count })
if overlapped_latents != None :
overlapped_latents_size = overlapped_latents.shape[1]
extended_overlapped_latents = z[0][0:16, 0:overlapped_latents_size + ref_images_count].clone()
overlapped_latents_size = overlapped_latents.shape[2]
extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0)
target_shape = list(z0[0].shape)
target_shape[0] = int(target_shape[0] / 2)
@ -649,7 +640,7 @@ class WanAny2V:
from wan.multitalk.multitalk import get_target_masks
audio_proj = [audio.to(self.dtype) for audio in audio_proj]
human_no = len(audio_proj[0])
token_ref_target_masks = get_target_masks(human_no, lat_h, lat_w, height, width, face_scale = 0.05, bbox = None).to(self.dtype) if human_no > 1 else None
token_ref_target_masks = get_target_masks(human_no, lat_h, lat_w, height, width, face_scale = 0.05, bbox = speakers_bboxes).to(self.dtype) if human_no > 1 else None
if fantasy and audio_proj != None:
kwargs.update({ "audio_proj": audio_proj.to(self.dtype), "audio_context_lens": audio_context_lens, })
@ -658,8 +649,8 @@ class WanAny2V:
if self._interrupt:
return None
expand_shape = [batch_size] + [-1] * len(target_shape)
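# expand_shape lets per-sample conditioning tensors (source latents, reference images) be
# broadcast to the new batch dimension with Tensor.expand, i.e. without copying memory.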
# Ropes
batch_size = 1
if target_camera != None:
shape = list(target_shape[1:])
shape[0] *= 2
@ -692,20 +683,20 @@ class WanAny2V:
# init denoising
updated_num_steps= len(timesteps)
if callback != None:
from wgp import update_loras_slists
from wan.utils.utils import update_loras_slists
update_loras_slists(self.model, loras_slists, updated_num_steps)
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
if sample_scheduler != None:
scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g}
latents = torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
# b, c, lat_f, lat_h, lat_w
latents = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
if apg_switch != 0:
apg_momentum = -0.75
apg_norm_threshold = 55
text_momentumbuffer = MomentumBuffer(apg_momentum)
audio_momentumbuffer = MomentumBuffer(apg_momentum)
# self.image_outputs = False
# denoising
for i, t in enumerate(tqdm(timesteps)):
offload.set_step_no_for_lora(self.model, i)
@ -715,36 +706,36 @@ class WanAny2V:
if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step:
sigma = t / 1000
noise = torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
if inject_from_start:
new_latents = latents.clone()
new_latents[:, :source_latents.shape[1] ] = noise[:, :source_latents.shape[1] ] * sigma + (1 - sigma) * source_latents
new_latents[:,:, :source_latents.shape[1] ] = noise[:, :, :source_latents.shape[1] ] * sigma + (1 - sigma) * source_latents.unsqueeze(0)
for latent_no, keep_latent in enumerate(latent_keep_frames):
if not keep_latent:
new_latents[:, latent_no:latent_no+1 ] = latents[:, latent_no:latent_no+1]
new_latents[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1]
latents = new_latents
new_latents = None
else:
latents = noise * sigma + (1 - sigma) * source_latents
latents = noise * sigma + (1 - sigma) * source_latents.unsqueeze(0)
noise = None
if extended_overlapped_latents != None:
latent_noise_factor = t / 1000
latents[:, 0:extended_overlapped_latents.shape[1]] = extended_overlapped_latents * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor
latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor
if vace:
overlap_noise_factor = overlap_noise / 1000
for zz in z:
zz[0:16, ref_images_count:extended_overlapped_latents.shape[1] ] = extended_overlapped_latents[:, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[:, ref_images_count:] ) * overlap_noise_factor
zz[0:16, ref_images_count:extended_overlapped_latents.shape[2] ] = extended_overlapped_latents[0, :, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[0, :, ref_images_count:] ) * overlap_noise_factor
if target_camera != None:
latent_model_input = torch.cat([latents, source_latents], dim=1)
latent_model_input = torch.cat([latents, source_latents.unsqueeze(0).expand(*expand_shape)], dim=2) # !!!!
else:
latent_model_input = latents
if phantom:
gen_args = {
"x" : ([ torch.cat([latent_model_input[:,:-ref_images_count], input_ref_images], dim=1) ] * 2 +
[ torch.cat([latent_model_input[:,:-ref_images_count], input_ref_images_neg], dim=1)]),
"x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 +
[ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]),
"context": [context, context_null, context_null] ,
}
elif fantasy:
@ -753,7 +744,7 @@ class WanAny2V:
"context" : [context, context_null, context_null],
"audio_scale": [audio_scale, None, None ]
}
elif multitalk:
elif multitalk and audio_proj != None:
gen_args = {
"x" : [latent_model_input, latent_model_input, latent_model_input],
"context" : [context, context_null, context_null],
@ -832,38 +823,41 @@ class WanAny2V:
if sample_solver == "euler":
dt = timesteps[i] if i == len(timesteps)-1 else (timesteps[i] - timesteps[i + 1])
dt = dt / self.num_timesteps
latents = latents - noise_pred * dt[:, None, None, None]
latents = latents - noise_pred * dt[:, None, None, None, None]
else:
temp_x0 = sample_scheduler.step(
noise_pred[:, :target_shape[1]].unsqueeze(0),
latents = sample_scheduler.step(
noise_pred[:, :, :target_shape[1]],
t,
latents.unsqueeze(0),
latents,
**scheduler_kwargs)[0]
latents = temp_x0.squeeze(0)
del temp_x0
if callback is not None:
callback(i, latents, False)
latents_preview = latents
if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames]
if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2)
callback(i, latents_preview[0], False)
latents_preview = None
x0 = [latents]
if vace and ref_images_count > 0: latents = latents[:, :, ref_images_count:]
if trim_frames > 0: latents= latents[:, :,:-trim_frames]
if return_latent_slice != None:
latent_slice = latents[:, :, return_latent_slice].clone()
x0 =latents.unbind(dim=0)
if chipmunk:
self.model.release_chipmunk() # need to add it at every exit when in prod
if return_latent_slice != None:
latent_slice = latents[:, return_latent_slice].clone()
if vace:
# vace post processing
videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
else:
if phantom and input_ref_images != None:
trim_frames = input_ref_images.shape[1]
if trim_frames > 0: x0 = [x0_[:,:-trim_frames] for x0_ in x0]
videos = self.vae.decode(x0, VAE_tile_size)
videos = self.vae.decode(x0, VAE_tile_size)
if self.image_outputs:
videos = torch.cat(videos, dim=1) if len(videos) > 1 else videos[0]
else:
videos = videos[0] # return only first video
if return_latent_slice != None:
return { "x" : videos[0], "latent_slice" : latent_slice }
return videos[0]
return { "x" : videos, "latent_slice" : latent_slice }
return videos
def adapt_vace_model(self):
model = self.model

View File

@ -19,7 +19,7 @@ from wan.utils.utils import calculate_new_dimensions
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wgp import update_loras_slists
from wan.utils.utils import update_loras_slists
class DTT2V:
@ -31,6 +31,7 @@ class DTT2V:
rank=0,
model_filename = None,
model_type = None,
model_def = None,
base_model_type = None,
save_quantized = False,
text_encoder_filename = None,
@ -53,6 +54,8 @@ class DTT2V:
checkpoint_path=text_encoder_filename,
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn= None)
self.model_def = model_def
self.image_outputs = model_def.get("image_outputs", False)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
@ -202,6 +205,7 @@ class DTT2V:
width: int = 832,
fit_into_canvas = True,
frame_num: int = 97,
batch_size = 1,
sampling_steps: int = 50,
shift: float = 1.0,
guide_scale: float = 5.0,
@ -224,8 +228,9 @@ class DTT2V:
generator = torch.Generator(device=self.device)
generator.manual_seed(seed)
self._guidance_scale = guide_scale
frame_num = max(17, frame_num) # must match causal_block_size for value of 5
frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 )
if frame_num > 1:
frame_num = max(17, frame_num) # must match causal_block_size for value of 5
frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 )
if ar_step == 0:
causal_block_size = 1
@ -244,7 +249,7 @@ class DTT2V:
image_start = np.array(image_start.resize((width, height))).transpose(2, 0, 1)
latent_length = (frame_num - 1) // 4 + 1
latent_length = (frame_num - 1) // 4 + 1
latent_height = height // 8
latent_width = width // 8
@ -297,12 +302,12 @@ class DTT2V:
prefix_video = prefix_video[:, : predix_video_latent_length]
base_num_frames_iter = latent_length
latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
latent_shape = [batch_size, 16, base_num_frames_iter, latent_height, latent_width]
latents = self.prepare_latents(
latent_shape, dtype=torch.float32, device=self.device, generator=generator
)
if prefix_video is not None:
latents[:, :predix_video_latent_length] = prefix_video.to(torch.float32)
latents[:, :, :predix_video_latent_length] = prefix_video.to(torch.float32)
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
base_num_frames_iter,
init_timesteps,
@ -340,7 +345,7 @@ class DTT2V:
else:
self.model.enable_cache = None
from mmgp import offload
freqs = get_rotary_pos_embed(latents.shape[1 :], enable_RIFLEx= False)
freqs = get_rotary_pos_embed(latents.shape[2 :], enable_RIFLEx= False)
kwrags = {
"freqs" :freqs,
"fps" : fps_embeds,
@ -358,15 +363,15 @@ class DTT2V:
update_mask_i = step_update_mask[i]
valid_interval_start, valid_interval_end = valid_interval[i]
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].clone()
latent_model_input = latents[:, :, valid_interval_start:valid_interval_end, :, :].clone()
if overlap_noise > 0 and valid_interval_start < predix_video_latent_length:
noise_factor = 0.001 * overlap_noise
timestep_for_noised_condition = overlap_noise
latent_model_input[:, valid_interval_start:predix_video_latent_length] = (
latent_model_input[:, valid_interval_start:predix_video_latent_length]
latent_model_input[:, :, valid_interval_start:predix_video_latent_length] = (
latent_model_input[:, :, valid_interval_start:predix_video_latent_length]
* (1.0 - noise_factor)
+ torch.randn_like(
latent_model_input[:, valid_interval_start:predix_video_latent_length]
latent_model_input[:, :, valid_interval_start:predix_video_latent_length]
)
* noise_factor
)
@ -417,18 +422,27 @@ class DTT2V:
del noise_pred_cond, noise_pred_uncond
for idx in range(valid_interval_start, valid_interval_end):
if update_mask_i[idx].item():
latents[:, idx] = sample_schedulers[idx].step(
noise_pred[:, idx - valid_interval_start],
latents[:, :, idx] = sample_schedulers[idx].step(
noise_pred[:, :, idx - valid_interval_start],
timestep_i[idx],
latents[:, idx],
latents[:, :, idx],
return_dict=False,
generator=generator,
)[0]
sample_schedulers_counter[idx] += 1
if callback is not None:
callback(i, latents.squeeze(0), False)
latents_preview = latents
if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2)
callback(i, latents_preview[0], False)
latents_preview = None
x0 = latents.unsqueeze(0)
videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]]
output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
return output_video
x0 =latents.unbind(dim=0)
videos = self.vae.decode(x0, VAE_tile_size)
if self.image_outputs:
videos = torch.cat(videos, dim=1) if len(videos) > 1 else videos[0]
else:
videos = videos[0] # return only first video
return videos

View File

@ -185,7 +185,7 @@ def pay_attention(
q,k,v = qkv_list
qkv_list.clear()
out_dtype = q.dtype
if q.dtype == torch.bfloat16 and not bfloat16_supported:
if q.dtype == torch.bfloat16 and not bfloat16_supported:
q = q.to(torch.float16)
k = k.to(torch.float16)
v = v.to(torch.float16)
@ -194,7 +194,9 @@ def pay_attention(
q = q.to(v.dtype)
k = k.to(v.dtype)
batch = len(q)
if len(k) != batch: k = k.expand(batch, -1, -1, -1)
if len(v) != batch: v = v.expand(batch, -1, -1, -1)
if attn == "chipmunk":
from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG

View File

@ -33,9 +33,10 @@ def sinusoidal_embedding_1d(dim, position):
def reshape_latent(latent, latent_frames):
if latent_frames == latent.shape[0]:
return latent
return latent.reshape(latent_frames, -1, latent.shape[-1] )
return latent.reshape(latent.shape[0], latent_frames, -1, latent.shape[-1] )
def restore_latent_shape(latent):
return latent.reshape(latent.shape[0], -1, latent.shape[-1] )
def identify_k( b: float, d: int, N: int):
@ -493,7 +494,7 @@ class WanAttentionBlock(nn.Module):
x_mod = reshape_latent(x_mod , latent_frames)
x_mod *= 1 + e[1]
x_mod += e[0]
x_mod = reshape_latent(x_mod , 1)
x_mod = restore_latent_shape(x_mod)
if cam_emb != None:
cam_emb = self.cam_encoder(cam_emb)
cam_emb = cam_emb.repeat(1, 2, 1)
@ -510,7 +511,7 @@ class WanAttentionBlock(nn.Module):
x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
x.addcmul_(y, e[2])
x, y = reshape_latent(x , 1), reshape_latent(y , 1)
x, y = restore_latent_shape(x), restore_latent_shape(y)
del y
y = self.norm3(x)
y = y.to(attention_dtype)
@ -542,7 +543,7 @@ class WanAttentionBlock(nn.Module):
y = reshape_latent(y , latent_frames)
y *= 1 + e[4]
y += e[3]
y = reshape_latent(y , 1)
y = restore_latent_shape(y)
y = y.to(attention_dtype)
ffn = self.ffn[0]
@ -562,7 +563,7 @@ class WanAttentionBlock(nn.Module):
y = y.to(dtype)
x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
x.addcmul_(y, e[5])
x, y = reshape_latent(x , 1), reshape_latent(y , 1)
x, y = restore_latent_shape(x), restore_latent_shape(y)
if hints_processed is not None:
for hint, scale in zip(hints_processed, context_scale):
@ -669,6 +670,8 @@ class VaceWanAttentionBlock(WanAttentionBlock):
hints[0] = None
if self.block_id == 0:
c = self.before_proj(c)
bz = x.shape[0]
if bz > c.shape[0]: c = c.repeat(bz, 1, 1 )
c += x
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
@ -707,7 +710,7 @@ class Head(nn.Module):
x = reshape_latent(x , latent_frames)
x *= (1 + e[1])
x += e[0]
x = reshape_latent(x , 1)
x = restore_latent_shape(x)
x= x.to(self.head.weight.dtype)
x = self.head(x)
return x
@ -1162,11 +1165,15 @@ class WanModel(ModelMixin, ConfigMixin):
x_list[i] = x_list[0].clone()
last_x_idx = i
else:
# image source
# image source
bz = len(x)
if y is not None:
x = torch.cat([x, y], dim=0)
y = y.unsqueeze(0)
if bz > 1: y = y.expand(bz, -1, -1, -1, -1)
x = torch.cat([x, y], dim=1)
# embeddings
x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
# x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
x = self.patch_embedding(x).to(modulation_dtype)
grid_sizes = x.shape[2:]
if chipmunk:
x = x.unsqueeze(-1)
@ -1204,7 +1211,7 @@ class WanModel(ModelMixin, ConfigMixin):
) # b, dim
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
if self.inject_sample_info:
if self.inject_sample_info and fps!=None:
fps = torch.tensor(fps, dtype=torch.long, device=device)
fps_emb = self.fps_embedding(fps).to(e.dtype)
@ -1402,7 +1409,7 @@ class WanModel(ModelMixin, ConfigMixin):
x_list[i] = self.unpatchify(x, grid_sizes)
del x
return [x[0].float() for x in x_list]
return [x.float() for x in x_list]
def unpatchify(self, x, grid_sizes):
r"""
@ -1427,7 +1434,10 @@ class WanModel(ModelMixin, ConfigMixin):
u = torch.einsum('fhwpqrc->cfphqwr', u)
u = u.reshape(c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
out.append(u)
return out
if len(x) == 1:
return out[0].unsqueeze(0)
else:
return torch.stack(out, 0)
def init_weights(self):
r"""

View File

@ -333,7 +333,7 @@ class SingleStreamMutiAttention(SingleStreamAttention):
human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), (self.rope_h1[0], self.rope_h1[1]))
human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), (self.rope_h2[0], self.rope_h2[1]))
back = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype).to(human1.device)
back = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype, device=human1.device)
max_indices = x_ref_attn_map.argmax(dim=0)
normalized_map = torch.stack([human1, human2, back], dim=1)
normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N
@ -351,7 +351,7 @@ class SingleStreamMutiAttention(SingleStreamAttention):
if self.qk_norm:
encoder_k = self.add_k_norm(encoder_k)
per_frame = torch.zeros(N_a, dtype=encoder_k.dtype).to(encoder_k.device)
per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device)
per_frame[:per_frame.size(0)//2] = (self.rope_h1[0] + self.rope_h1[1]) / 2
per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2
encoder_pos = torch.concat([per_frame]*N_t, dim=0)

View File

@ -184,6 +184,7 @@ def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combinat
def get_window_audio_embeddings(full_audio_embs, audio_start_idx=0, clip_length = 81, vae_scale = 4, audio_window = 5):
if full_audio_embs == None: return None
HUMAN_NUMBER = len(full_audio_embs)
audio_end_idx = audio_start_idx + clip_length
indices = (torch.arange(2 * 2 + 1) - 2) * 1
@ -271,6 +272,34 @@ def timestep_transform(
new_t = new_t * num_timesteps
return new_t
def parse_speakers_locations(speakers_locations):
bbox = {}
if speakers_locations is None or len(speakers_locations) == 0:
return None, ""
speakers = speakers_locations.split(" ")
if len(speakers) !=2:
error= "Two speakers locations should be defined"
return "", error
for i, speaker in enumerate(speakers):
location = speaker.strip().split(":")
if len(location) not in (2,4):
error = f"Invalid Speaker Location '{location}'. A Speaker Location should be defined in the format Left:Right or usuing a BBox Left:Top:Right:Bottom"
return "", error
try:
good = False
location_float = [ float(val) for val in location]
good = all( 0 <= val <= 100 for val in location_float)
except:
pass
if not good:
error = f"Invalid Speaker Location '{location}'. Each number should be between 0 and 100."
return "", error
if len(location_float) == 2:
location_float = [location_float[0], 0, location_float[1], 100]
bbox[f"human{i}"] = location_float
return bbox, ""
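# Illustrative usage sketch (hypothetical input strings): parse_speakers_locations expects two
# space-separated locations, each "Left:Right" or "Left:Top:Right:Bottom" with values in [0, 100].
bbox, err = parse_speakers_locations("0:50 50:100")
# err == ""   and   bbox == {"human0": [0.0, 0, 50.0, 100], "human1": [50.0, 0, 100.0, 100]}
bbox, err = parse_speakers_locations("0:10:50:90 50:10:100:90")
# the four-value form keeps the bounding boxes as given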
# construct human mask
def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05, bbox = None):
@ -285,7 +314,9 @@ def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05
assert len(bbox) == HUMAN_NUMBER, f"The number of target bboxes should be the same as the number of cond_audio entries"
background_mask = torch.zeros([src_h, src_w])
for _, person_bbox in bbox.items():
x_min, y_min, x_max, y_max = person_bbox
y_min, x_min, y_max, x_max = person_bbox
x_min, y_min, x_max, y_max = max(x_min,5), max(y_min, 5), min(x_max,95), min(y_max,95)
x_min, y_min, x_max, y_max = int(src_h * x_min / 100), int(src_w * y_min / 100), int(src_h * x_max / 100), int(src_w * y_max / 100)
human_mask = torch.zeros([src_h, src_w])
human_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
background_mask += human_mask
@ -305,7 +336,7 @@ def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05
human_masks = [human_mask1, human_mask2]
background_mask = torch.where(background_mask > 0, torch.tensor(0), torch.tensor(1))
human_masks.append(background_mask)
# toto = Image.fromarray(human_masks[2].mul_(255).unsqueeze(-1).repeat(1,1,3).to(torch.uint8).cpu().numpy())
ref_target_masks = torch.stack(human_masks, dim=0) #.to(self.device)
# resize and centercrop for ref_target_masks
# ref_target_masks = resize_and_centercrop(ref_target_masks, (target_h, target_w))

View File

@ -128,7 +128,7 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli
_, seq_lens, heads, _ = visual_q.shape
class_num, _ = ref_target_masks.shape
x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q.device).to(visual_q.dtype)
x_ref_attn_maps = torch.zeros(class_num, seq_lens, dtype=visual_q.dtype, device=visual_q.device)
split_chunk = heads // split_num
@ -350,4 +350,4 @@ def adaptive_projected_guidance(
diff = diff * scale_factor
diff_parallel, diff_orthogonal = project(diff, pred_cond)
normalized_update = diff_orthogonal + eta * diff_parallel
return normalized_update
return normalized_update

View File

@ -53,7 +53,7 @@ class FlowMatchScheduler():
else:
sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1)
prev_sample = sample + model_output * (sigma_ - sigma)
return prev_sample
return [prev_sample]
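# Returning a single-element list keeps FlowMatchScheduler call-compatible with the diffusers
# schedulers: the denoising loops above always index the result, e.g.
#     latents = sample_scheduler.step(noise_pred, t, latents, **scheduler_kwargs)[0]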
def add_noise(self, original_samples, noise, timestep):
"""

View File

@ -5,7 +5,8 @@ import os
import os.path as osp
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import cv2
import tempfile
import imageio
import torch
import decord
@ -32,7 +33,22 @@ def seed_everything(seed: int):
torch.cuda.manual_seed(seed)
if torch.backends.mps.is_available():
torch.mps.manual_seed(seed)
def expand_slist(slist, num_inference_steps ):
new_slist= []
inc = len(slist) / num_inference_steps
pos = 0
for i in range(num_inference_steps):
new_slist.append(slist[ int(pos)])
pos += inc
return new_slist
def update_loras_slists(trans, slists, num_inference_steps ):
from mmgp import offload
slists = [ expand_slist(slist, num_inference_steps ) if isinstance(slist, list) else slist for slist in slists ]
nos = [str(l) for l in range(len(slists))]
offload.activate_loras(trans, nos, slists )
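# Illustrative usage sketch (hypothetical multiplier list): expand_slist stretches a per-phase
# Lora multiplier list to one value per denoising step by nearest-index resampling, and
# update_loras_slists passes the expanded lists to offload.activate_loras.
phases = [1.0, 0.5]                 # full strength for the first half of the steps, half after
per_step = expand_slist(phases, 4)  # -> [1.0, 1.0, 0.5, 0.5]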
def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ):
import math
@ -101,6 +117,29 @@ def get_video_frame(file_name, frame_no):
img = Image.fromarray(frame.numpy().astype(np.uint8))
return img
def convert_image_to_video(image):
if image is None:
return None
# Convert PIL/numpy image to OpenCV format if needed
if isinstance(image, np.ndarray):
# Gradio images are typically RGB, OpenCV expects BGR
img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
else:
# Handle PIL Image
img_array = np.array(image)
img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
height, width = img_bgr.shape[:2]
# Create temporary video file (auto-cleaned by Gradio)
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_video:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(temp_video.name, fourcc, 30.0, (width, height))
out.write(img_bgr)
out.release()
return temp_video.name
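# Illustrative usage sketch (hypothetical dummy frame): convert_image_to_video wraps a single
# RGB image (numpy array or PIL Image) into a one-frame 30 fps .mp4 and returns the temporary
# file path, so image inputs can be fed to components that only accept videos.
import numpy as np
dummy_frame = np.zeros((64, 64, 3), dtype=np.uint8)
video_path = convert_image_to_video(dummy_frame)    # path to a one-frame mp4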
def resize_lanczos(img, h, w):
img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
img = img.resize((w,h), resample=Image.Resampling.LANCZOS)
@ -454,10 +493,10 @@ def extract_audio_tracks(source_video, verbose=False, query_only= False):
except ffmpeg.Error as e:
print(f"FFmpeg error during audio extraction: {e}")
return []
return 0 if query_only else []
except Exception as e:
print(f"Error during audio extraction: {e}")
return []
return 0 if query_only else []
def combine_video_with_audio_tracks(target_video, audio_tracks, output_video, verbose=False):
"""

wgp.py (1156 lines changed)

File diff suppressed because it is too large.